Diffstat (limited to 'code/sunlab/sunflow/models/adversarial_autoencoder.py')
-rw-r--r--  code/sunlab/sunflow/models/adversarial_autoencoder.py  344
1 file changed, 344 insertions(+), 0 deletions(-)
diff --git a/code/sunlab/sunflow/models/adversarial_autoencoder.py b/code/sunlab/sunflow/models/adversarial_autoencoder.py
new file mode 100644
index 0000000..4cbb2f8
--- /dev/null
+++ b/code/sunlab/sunflow/models/adversarial_autoencoder.py
@@ -0,0 +1,344 @@
+from sunlab.common.data.dataset import Dataset
+from sunlab.common.scaler.adversarial_scaler import AdversarialScaler
+from sunlab.common.distribution.adversarial_distribution import AdversarialDistribution
+from .encoder import Encoder
+from .decoder import Decoder
+from .discriminator import Discriminator
+from .encoder_discriminator import EncoderDiscriminator
+from .autoencoder import Autoencoder
+from tensorflow.keras import optimizers, metrics, losses
+import tensorflow as tf
+from numpy import ones, zeros, float32, NaN
+from typing import Optional, Type
+
+
+class AdversarialAutoencoder:
+    """# Adversarial Autoencoder
+    - distribution: The prior distribution the adversary trains the latent
+      space to match"""
+
+ def __init__(
+ self,
+ model_base_directory,
+        distribution: Optional[Type[AdversarialDistribution]] = None,
+        scaler: Optional[Type[AdversarialScaler]] = None,
+ ):
+ """# Adversarial Autoencoder Model Initialization
+
+        - model_base_directory: The base folder directory where the model
+          will be saved/loaded
+ - distribution: The distribution the adversary will use
+ - scaler: The scaling function the model will assume on the data"""
+ self.model_base_directory = model_base_directory
+        self.distribution = distribution
+ if scaler is not None:
+ self.scaler = scaler(self.model_base_directory)
+ else:
+ self.scaler = None
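+
+    # Hypothetical usage sketch (GaussianDistribution and MaxAbsScaler are
+    # illustrative placeholders, not names defined in this file; any
+    # AdversarialDistribution and AdversarialScaler subclasses fit here):
+    #
+    #     aae = AdversarialAutoencoder(
+    #         "models/my_aae",
+    #         distribution=GaussianDistribution,
+    #         scaler=MaxAbsScaler,
+    #     )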
+
+ def init(
+ self,
+ data=None,
+ data_size=13,
+ autoencoder_layer_size=16,
+ adversary_layer_size=8,
+ latent_size=2,
+ autoencoder_depth=2,
+ dropout=0.0,
+ use_leaky_relu=False,
+ **kwargs,
+ ):
+ """# Initialize AAE model parameters
+ - data_size: int
+ - autoencoder_layer_size: int
+ - adversary_layer_size: int
+ - latent_size: int
+ - autoencoder_depth: int
+ - dropout: float
+ - use_leaky_relu: boolean"""
+ self.data_size = data_size
+ self.autoencoder_layer_size = autoencoder_layer_size
+ self.adversary_layer_size = adversary_layer_size
+ self.latent_size = latent_size
+ self.autoencoder_depth = autoencoder_depth
+ self.dropout = dropout
+ self.use_leaky_relu = use_leaky_relu
+ self.save_parameters()
+ self.encoder = Encoder(self.model_base_directory).init()
+ self.decoder = Decoder(self.model_base_directory).init()
+ self.autoencoder = Autoencoder(self.model_base_directory).init(
+ self.encoder, self.decoder
+ )
+ self.discriminator = Discriminator(self.model_base_directory).init()
+ self.encoder_discriminator = EncoderDiscriminator(
+ self.model_base_directory
+ ).init(self.encoder, self.discriminator)
+ if self.distribution is not None:
+ self.distribution = self.distribution(self.latent_size)
+ if (data is not None) and (self.scaler is not None):
+ self.scaler = self.scaler.init(data)
+ self.init_optimizers_and_metrics(**kwargs)
+ return self
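+
+    # A minimal setup sketch for a fresh model (`training_data` is assumed to
+    # be an (N, data_size) array; the call below is illustrative, not the
+    # only valid configuration):
+    #
+    #     aae = aae.init(data=training_data, data_size=13, latent_size=2)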
+
+ def init_optimizers_and_metrics(
+ self,
+ optimizer=optimizers.Adam,
+ ae_metric=metrics.MeanAbsoluteError,
+ adv_metric=metrics.BinaryCrossentropy,
+ ae_lr=7e-4,
+ adv_lr=3e-4,
+ loss_fn=losses.BinaryCrossentropy,
+ **kwargs,
+ ):
+ """# Set the optimizer, loss function, and metrics"""
+ self.ae_optimizer = optimizer(learning_rate=ae_lr)
+ self.adv_optimizer = optimizer(learning_rate=adv_lr)
+ self.gan_optimizer = optimizer(learning_rate=adv_lr)
+ self.train_ae_metric = ae_metric()
+ self.val_ae_metric = ae_metric()
+ self.train_adv_metric = adv_metric()
+ self.val_adv_metric = adv_metric()
+ self.train_gan_metric = adv_metric()
+ self.val_gan_metric = adv_metric()
+ self.loss_fn = loss_fn()
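+
+    # Because `init` forwards its **kwargs here, the optimizer class and
+    # learning rates can be overridden at setup time, e.g. (one possible
+    # choice, using the stock Keras RMSprop optimizer):
+    #
+    #     aae.init(data=training_data, optimizer=optimizers.RMSprop,
+    #              ae_lr=1e-3, adv_lr=5e-4)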
+
+ def load(self):
+ """# Load the models from their respective files"""
+ self.load_parameters()
+ self.encoder = Encoder(self.model_base_directory).load()
+ self.decoder = Decoder(self.model_base_directory).load()
+ self.autoencoder = Autoencoder(self.model_base_directory).load()
+ self.discriminator = Discriminator(self.model_base_directory).load()
+ self.encoder_discriminator = EncoderDiscriminator(
+ self.model_base_directory
+ ).load()
+ if self.scaler is not None:
+ self.scaler = self.scaler.load()
+ return self
+
+ def save(self, overwrite=False):
+ """# Save each model in the AAE"""
+ self.encoder.save(overwrite=overwrite)
+ self.decoder.save(overwrite=overwrite)
+ self.autoencoder.save(overwrite=overwrite)
+ self.discriminator.save(overwrite=overwrite)
+ self.encoder_discriminator.save(overwrite=overwrite)
+ if self.scaler is not None:
+ self.scaler.save()
+
+ def save_parameters(self):
+ """# Save the AAE parameters in a file"""
+ from pickle import dump
+ from os import makedirs
+
+ makedirs(self.model_base_directory + "/portable/", exist_ok=True)
+ parameters = {
+ "data_size": self.data_size,
+ "autoencoder_layer_size": self.autoencoder_layer_size,
+ "adversary_layer_size": self.adversary_layer_size,
+ "latent_size": self.latent_size,
+ "autoencoder_depth": self.autoencoder_depth,
+ "dropout": self.dropout,
+ "use_leaky_relu": self.use_leaky_relu,
+ }
+ with open(
+ f"{self.model_base_directory}/portable/model_parameters.pkl", "wb"
+ ) as phandle:
+ dump(parameters, phandle)
+
+ def load_parameters(self):
+ """# Load the AAE parameters from a file"""
+ from pickle import load
+
+ with open(
+ f"{self.model_base_directory}/portable/model_parameters.pkl", "rb"
+ ) as phandle:
+ parameters = load(phandle)
+ self.data_size = parameters["data_size"]
+ self.autoencoder_layer_size = parameters["autoencoder_layer_size"]
+ self.adversary_layer_size = parameters["adversary_layer_size"]
+ self.latent_size = parameters["latent_size"]
+ self.autoencoder_depth = parameters["autoencoder_depth"]
+ self.dropout = parameters["dropout"]
+ self.use_leaky_relu = parameters["use_leaky_relu"]
+ return parameters
+
+ def summary(self):
+ """# Summarize each model in the AAE"""
+ self.encoder.summary()
+ self.decoder.summary()
+ self.autoencoder.summary()
+ self.discriminator.summary()
+ self.encoder_discriminator.summary()
+
+ @tf.function
+ def train_step(self, x, y):
+ """# Training Step
+
+ 1. Train the Autoencoder
+        2. (If a distribution is given) Train the discriminator
+        3. (If a distribution is given) Train the encoder_discriminator"""
+ # Autoencoder Training
+ with tf.GradientTape() as tape:
+ decoded_vector = self.autoencoder(x, training=True)
+ ae_loss_value = self.loss_fn(y, decoded_vector)
+ grads = tape.gradient(ae_loss_value, self.autoencoder.model.trainable_weights)
+ self.ae_optimizer.apply_gradients(
+ zip(grads, self.autoencoder.model.trainable_weights)
+ )
+ self.train_ae_metric.update_state(y, decoded_vector)
+ if self.distribution is not None:
+            # Adversary Training
+ with tf.GradientTape() as tape:
+ latent_vector = self.encoder(x)
+ fakepred = self.distribution(x.shape[0])
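+                # Label convention: draws from the target prior count as
+                # "real" (1); the encoder's latent codes count as "fake" (0)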
+ discbatch_x = tf.concat([latent_vector, fakepred], axis=0)
+ discbatch_y = tf.concat([zeros(x.shape[0]), ones(x.shape[0])], axis=0)
+ adversary_vector = self.discriminator(discbatch_x, training=True)
+ adv_loss_value = self.loss_fn(discbatch_y, adversary_vector)
+ grads = tape.gradient(
+ adv_loss_value, self.discriminator.model.trainable_weights
+ )
+ self.adv_optimizer.apply_gradients(
+ zip(grads, self.discriminator.model.trainable_weights)
+ )
+ self.train_adv_metric.update_state(discbatch_y, adversary_vector)
+            # GAN Training: update only the encoder so its latent codes are
+            # classified as "real" by the discriminator
+ with tf.GradientTape() as tape:
+ gan_vector = self.encoder_discriminator(x, training=True)
+ adv_vector = tf.convert_to_tensor(ones((x.shape[0], 1), dtype=float32))
+                gan_loss_value = self.loss_fn(adv_vector, gan_vector)
+ grads = tape.gradient(gan_loss_value, self.encoder.model.trainable_weights)
+ self.gan_optimizer.apply_gradients(
+ zip(grads, self.encoder.model.trainable_weights)
+ )
+ self.train_gan_metric.update_state(adv_vector, gan_vector)
+ return (ae_loss_value, adv_loss_value, gan_loss_value)
+ return (ae_loss_value, None, None)
+
+ @tf.function
+ def test_step(self, x, y):
+ """# Test Step - On validation data
+
+ 1. Evaluate the Autoencoder
+        2. (If a distribution is given) Evaluate the discriminator
+        3. (If a distribution is given) Evaluate the encoder_discriminator"""
+ val_decoded_vector = self.autoencoder(x, training=False)
+ self.val_ae_metric.update_state(y, val_decoded_vector)
+
+ if self.distribution is not None:
+ latent_vector = self.encoder(x)
+ fakepred = self.distribution(x.shape[0])
+ discbatch_x = tf.concat([latent_vector, fakepred], axis=0)
+ discbatch_y = tf.concat([zeros(x.shape[0]), ones(x.shape[0])], axis=0)
+ adversary_vector = self.discriminator(discbatch_x, training=False)
+ self.val_adv_metric.update_state(discbatch_y, adversary_vector)
+
+ gan_vector = self.encoder_discriminator(x, training=False)
+ self.val_gan_metric.update_state(ones(x.shape[0]), gan_vector)
+
+ # Garbage Collect at the end of each epoch
+ def on_epoch_end(self, _epoch, logs=None):
+ """# Cleanup environment to prevent memory leaks each epoch"""
+ import gc
+ from tensorflow.keras import backend as k
+
+ gc.collect()
+ k.clear_session()
+
+ def train(
+ self,
+ dataset: Dataset,
+ epoch_count: int = 1,
+ output=False,
+ output_freq=1,
+ fmt="%i[%.3f]: %.2e %.2e %.2e %.2e %.2e %.2e",
+ ):
+ """# Train the model on a dataset
+
+        - dataset: Dataset = The dataset to train the model on, which has the
+          training and validation iterators set up
+        - epoch_count: int = The number of epochs to train
+        - output: boolean = Whether or not to output training information
+        - output_freq: int = The number of epochs between each output
+        - fmt: str = printf-style format for the epoch number, elapsed time,
+          and the six train/validation metric values"""
+ from time import time
+ from numpy import array as narray
+
+ def fmtter(x):
+ return x if x is not None else -1
+
+ epoch_data = []
+ dataset.reset_iterators()
+
+ self.test_step(dataset.dataset, dataset.dataset)
+ val_ae = self.val_ae_metric.result()
+ val_adv = self.val_adv_metric.result()
+ val_gan = self.val_gan_metric.result()
+ self.val_ae_metric.reset_states()
+ self.val_adv_metric.reset_states()
+ self.val_gan_metric.reset_states()
+ print(
+ fmt
+ % (
+ 0,
+ NaN,
+ val_ae,
+ fmtter(val_adv),
+ fmtter(val_gan),
+ NaN,
+ NaN,
+ NaN,
+ )
+ )
+ for epoch in range(epoch_count):
+ start_time = time()
+
+ for step, (x_batch_train, y_batch_train) in enumerate(dataset.training):
+ ae_lv, adv_lv, gan_lv = self.train_step(x_batch_train, x_batch_train)
+
+ train_ae = self.train_ae_metric.result()
+ train_adv = self.train_adv_metric.result()
+ train_gan = self.train_gan_metric.result()
+ self.train_ae_metric.reset_states()
+ self.train_adv_metric.reset_states()
+ self.train_gan_metric.reset_states()
+
+ for step, (x_batch_val, y_batch_val) in enumerate(dataset.validation):
+ self.test_step(x_batch_val, x_batch_val)
+
+ val_ae = self.val_ae_metric.result()
+ val_adv = self.val_adv_metric.result()
+ val_gan = self.val_gan_metric.result()
+ self.val_ae_metric.reset_states()
+ self.val_adv_metric.reset_states()
+ self.val_gan_metric.reset_states()
+
+ epoch_data.append(
+ (
+ epoch,
+ train_ae,
+ val_ae,
+ fmtter(train_adv),
+ fmtter(val_adv),
+ fmtter(train_gan),
+ fmtter(val_gan),
+ )
+ )
+ if output and (epoch + 1) % output_freq == 0:
+ print(
+ fmt
+ % (
+ epoch + 1,
+ time() - start_time,
+ train_ae,
+ fmtter(train_adv),
+ fmtter(train_gan),
+ val_ae,
+ fmtter(val_adv),
+ fmtter(val_gan),
+ )
+ )
+ self.on_epoch_end(epoch)
+ dataset.reset_iterators()
+ return narray(epoch_data)
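+
+
+# End-to-end usage sketch (illustrative only: `ds` is assumed to be a
+# Dataset with its training/validation iterators prepared, and the
+# distribution/scaler classes are the same placeholders as above):
+#
+#     aae = AdversarialAutoencoder("models/aae", GaussianDistribution,
+#                                  MaxAbsScaler)
+#     aae.init(data=ds.dataset, latent_size=2)
+#     history = aae.train(ds, epoch_count=50, output=True, output_freq=5)
+#     aae.save(overwrite=True)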