diff options
author | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
---|---|---|
committer | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
commit | b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch) | |
tree | cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/sunflow/models/autoencoder.py |
Initial commit
Diffstat (limited to 'code/sunlab/sunflow/models/autoencoder.py')
-rw-r--r-- | code/sunlab/sunflow/models/autoencoder.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/code/sunlab/sunflow/models/autoencoder.py b/code/sunlab/sunflow/models/autoencoder.py new file mode 100644 index 0000000..473d00d --- /dev/null +++ b/code/sunlab/sunflow/models/autoencoder.py @@ -0,0 +1,85 @@ +class Autoencoder: + """# Autoencoder Model + + Constructs an encoder-decoder model""" + + def __init__(self, model_base_directory): + """# Autoencoder Model Initialization + + - model_base_directory: The base folder directory where the model will + be saved/ loaded""" + self.model_base_directory = model_base_directory + + def init(self, encoder, decoder): + """# Initialize an Autoencoder + + - encoder: The encoder to use + - decoder: The decoder to use""" + from tensorflow import keras + + self.load_parameters() + self.model = keras.models.Sequential() + self.model.add(encoder.model) + self.model.add(decoder.model) + self.model._name = "Autoencoder" + return self + + def load(self): + """# Load an existing Autoencoder""" + from os import listdir + + if "autoencoder.keras" not in listdir(f"{self.model_base_directory}/portable/"): + return None + import tensorflow as tf + + self.model = tf.keras.models.load_model( + f"{self.model_base_directory}/portable/autoencoder.keras", compile=False + ) + self.model._name = "Autoencoder" + return self + + def save(self, overwrite=False): + """# Save the current Autoencoder + + - Overwrite: overwrite any existing autoencoder that has been saved""" + from os import listdir + + if overwrite: + self.model.save(f"{self.model_base_directory}/portable/autoencoder.keras") + return True + if "autoencoder.keras" in listdir(f"{self.model_base_directory}/portable/"): + return False + self.model.save(f"{self.model_base_directory}/portable/autoencoder.keras") + return True + + def load_parameters(self): + """# Load Autoencoder Model Parameters from File + The file needs to have the following parameters defined: + - data_size: int + - autoencoder_layer_size: int + - latent_size: int + - autoencoder_depth: int + - dropout: float (set to 0. if you don't want a dropout layer) + - use_leaky_relu: boolean""" + from pickle import load + + with open( + f"{self.model_base_directory}/portable/model_parameters.pkl", "rb" + ) as phandle: + parameters = load(phandle) + self.data_size = parameters["data_size"] + self.layer_size = parameters["autoencoder_layer_size"] + self.latent_size = parameters["latent_size"] + self.depth = parameters["autoencoder_depth"] + self.dropout = parameters["dropout"] + self.use_leaky_relu = parameters["use_leaky_relu"] + + def summary(self): + """# Returns the summary of the Autoencoder model""" + return self.model.summary() + + def __call__(self, *args, **kwargs): + """# Callable + + When calling the autoencoder class, return the model's output""" + return self.model(*args, **kwargs) |