Diffstat (limited to 'code/sunlab/sunflow/models/utilities.py')
-rw-r--r-- | code/sunlab/sunflow/models/utilities.py | 93
1 file changed, 93 insertions, 0 deletions
diff --git a/code/sunlab/sunflow/models/utilities.py b/code/sunlab/sunflow/models/utilities.py
new file mode 100644
index 0000000..ab0c2a6
--- /dev/null
+++ b/code/sunlab/sunflow/models/utilities.py
@@ -0,0 +1,93 @@
+# Higher-level functions
+
+from sunlab.common.distribution.adversarial_distribution import AdversarialDistribution
+from sunlab.common.scaler.adversarial_scaler import AdversarialScaler
+from sunlab.common.data.utilities import import_dataset
+from .adversarial_autoencoder import AdversarialAutoencoder
+
+
+def create_aae(
+    dataset_file_name,
+    model_directory,
+    normalization_scaler: AdversarialScaler,
+    distribution: AdversarialDistribution or None,
+    magnification=10,
+    latent_size=2,
+):
+    """# Create Adversarial Autoencoder
+
+    - dataset_file_name: str = Path to the dataset file
+    - model_directory: str = Path to save the model in
+    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
+    - distribution: AdversarialDistribution = Distribution for the Adversary
+    - magnification: int = The Magnification of the Dataset"""
+    dataset = import_dataset(dataset_file_name, magnification)
+    model = AdversarialAutoencoder(
+        model_directory, distribution, normalization_scaler
+    ).init(dataset.dataset, latent_size=latent_size)
+    return model
+
+
+def create_aae_and_dataset(
+    dataset_file_name,
+    model_directory,
+    normalization_scaler: AdversarialScaler,
+    distribution: AdversarialDistribution or None,
+    magnification=10,
+    batch_size=1024,
+    shuffle=True,
+    val_split=0.1,
+    latent_size=2,
+):
+    """# Create Adversarial Autoencoder and Load the Dataset
+
+    - dataset_file_name: str = Path to the dataset file
+    - model_directory: str = Path to save the model in
+    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
+    - distribution: AdversarialDistribution = Distribution for the Adversary
+    - magnification: int = The Magnification of the Dataset"""
+    model = create_aae(
+        dataset_file_name,
+        model_directory,
+        normalization_scaler,
+        distribution,
+        magnification=magnification,
+        latent_size=latent_size,
+    )
+    dataset = import_dataset(
+        dataset_file_name,
+        magnification,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        val_split=val_split,
+        scaler=model.scaler,
+    )
+    return model, dataset
+
+
+def load_aae(model_directory, normalization_scaler: AdversarialScaler):
+    """# Load Adversarial Autoencoder
+
+    - model_directory: str = Path to save the model in
+    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
+    """
+    return AdversarialAutoencoder(model_directory, None, normalization_scaler).load()
+
+
+def load_aae_and_dataset(
+    dataset_file_name,
+    model_directory,
+    normalization_scaler: AdversarialScaler,
+    magnification=10,
+):
+    """# Load Adversarial Autoencoder
+
+    - dataset_file_name: str = Path to the dataset file
+    - model_directory: str = Path to save the model in
+    - normalization_scaler: AdversarialScaler = Data normalization Scaler Model
+    - magnification: int = The Magnification of the Dataset"""
+    model = load_aae(model_directory, normalization_scaler)
+    dataset = import_dataset(
+        dataset_file_name, magnification=magnification, scaler=model.scaler
+    )
+    return model, dataset
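
For context, a short usage sketch follows. It is not part of the commit: SomeScaler and SomeDistribution stand in for whatever concrete AdversarialScaler and AdversarialDistribution subclasses the sunlab.common package provides, and the file paths are placeholders; only create_aae_and_dataset, load_aae_and_dataset, and model.scaler come from the code added above.

    # Usage sketch (assumptions noted): SomeScaler / SomeDistribution and the
    # paths below are hypothetical; the helper functions are the ones in this diff.
    from sunlab.sunflow.models.utilities import create_aae_and_dataset, load_aae_and_dataset

    # Training-time setup: build a fresh AAE and import the scaled dataset in one call.
    model, dataset = create_aae_and_dataset(
        "data/cells_10x.h5",    # placeholder dataset path
        "models/aae_10x/",      # placeholder model directory
        SomeScaler,             # stand-in for an AdversarialScaler subclass
        SomeDistribution,       # stand-in for an AdversarialDistribution subclass (or None)
        magnification=10,
        batch_size=1024,
        latent_size=2,
    )

    # Inference-time setup: reload the saved model, then re-import the dataset.
    model, dataset = load_aae_and_dataset(
        "data/cells_10x.h5",
        "models/aae_10x/",
        SomeScaler,
        magnification=10,
    )

Both helpers pass scaler=model.scaler to import_dataset, so the dataset they return is normalized by the same scaler the model was initialized or loaded with.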