| author | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
|---|---|---|
| committer | Christian C <cc@localhost> | 2024-11-11 12:29:32 -0800 |
| commit | b85ee9d64a536937912544c7bbd5b98b635b7e8d (patch) | |
| tree | cef7bc17d7b29f40fc6b1867d0ce0a742d5583d0 /code/sunlab/suntorch/models/adversarial_autoencoder.py | |
Initial commit
Diffstat (limited to 'code/sunlab/suntorch/models/adversarial_autoencoder.py')
-rw-r--r-- | code/sunlab/suntorch/models/adversarial_autoencoder.py | 165 |
1 file changed, 165 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/adversarial_autoencoder.py b/code/sunlab/suntorch/models/adversarial_autoencoder.py
new file mode 100644
index 0000000..e3e8a06
--- /dev/null
+++ b/code/sunlab/suntorch/models/adversarial_autoencoder.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable  # unused here; deprecated in modern PyTorch
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .discriminator import Discriminator
+from .common import *  # expected to provide to_var() and EPS
+
+
+def dummy_distribution(*args):
+    raise NotImplementedError("Provide a distribution sampler, e.g. torch.randn")
+
+
+class AdversarialAutoencoder:
+    """# Adversarial Autoencoder Model"""
+
+    def __init__(
+        self,
+        data_dim,
+        latent_dim,
+        enc_dec_size,
+        disc_size,
+        negative_slope=0.3,
+        dropout=0.0,
+        distribution=dummy_distribution,
+    ):
+        self.encoder = Encoder(
+            data_dim,
+            enc_dec_size,
+            latent_dim,
+            negative_slope=negative_slope,
+            dropout=dropout,
+        )
+        self.decoder = Decoder(
+            data_dim,
+            enc_dec_size,
+            latent_dim,
+            negative_slope=negative_slope,
+            dropout=dropout,
+        )
+        self.discriminator = Discriminator(
+            disc_size, latent_dim, negative_slope=negative_slope, dropout=dropout
+        )
+        self.data_dim = data_dim
+        self.latent_dim = latent_dim
+        self.p = dropout
+        self.negative_slope = negative_slope
+        self.distribution = distribution
+        return
+
+    def parameters(self):
+        return (
+            *self.encoder.parameters(),
+            *self.decoder.parameters(),
+            *self.discriminator.parameters(),
+        )
+
+    def train(self):
+        self.encoder.train(True)
+        self.decoder.train(True)
+        self.discriminator.train(True)
+        return self
+
+    def eval(self):
+        self.encoder.train(False)
+        self.decoder.train(False)
+        self.discriminator.train(False)
+        return self
+
+    def encode(self, data):
+        return self.encoder(data)
+
+    def decode(self, data):
+        return self.decoder(data)
+
+    def __call__(self, data):
+        return self.decode(self.encode(data))
+
+    def save(self, base="./"):
+        torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+        torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+        torch.save(self.discriminator.state_dict(), base + "weights_discriminator.pt")
+        return self
+
+    def load(self, base="./"):
+        self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+        self.encoder.eval()
+        self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+        self.decoder.eval()
+        self.discriminator.load_state_dict(
+            torch.load(base + "weights_discriminator.pt")
+        )
+        self.discriminator.eval()
+        return self
+
+    def to(self, device):
+        self.encoder.to(device)
+        self.decoder.to(device)
+        self.discriminator.to(device)
+        return self
+
+    def cuda(self):
+        if torch.cuda.is_available():
+            self.encoder = self.encoder.cuda()
+            self.decoder = self.decoder.cuda()
+            self.discriminator = self.discriminator.cuda()
+        return self
+
+    def cpu(self):
+        self.encoder = self.encoder.cpu()
+        self.decoder = self.decoder.cpu()
+        self.discriminator = self.discriminator.cpu()
+        return self
+
+    def init_optimizers(self, recon_lr=1e-4, adv_lr=5e-5):
+        self.optim_E_gen = torch.optim.Adam(self.encoder.parameters(), lr=adv_lr)  # encoder as generator
+        self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)  # encoder as autoencoder half
+        self.optim_D_dec = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+        self.optim_D_dis = torch.optim.Adam(self.discriminator.parameters(), lr=adv_lr)
+        return self
+
+    def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+        self.recon_loss_fn = recon_loss_fn
+        return self
+
+    def train_step(self, raw_data, scale=1.0):
+        data = to_var(raw_data.view(raw_data.size(0), -1))  # flatten each sample
+
+        self.encoder.zero_grad()
+        self.decoder.zero_grad()
+        self.discriminator.zero_grad()
+
+        z = self.encoder(data)  # phase 1: reconstruction update
+        X = self.decoder(z)
+        self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+        self.recon_loss.backward()
+        self.optim_E_enc.step()
+        self.optim_D_dec.step()
+
+        self.encoder.eval()  # phase 2: discriminator, prior samples vs. encodings
+        z_gaussian = to_var(self.distribution(data.size(0), self.latent_dim) * scale)
+        z_gaussian_fake = self.encoder(data).detach()  # keep encoder grads clean for phase 3
+        y_gaussian = self.discriminator(z_gaussian)
+        y_gaussian_fake = self.discriminator(z_gaussian_fake)
+        self.D_loss = -torch.mean(
+            torch.log(y_gaussian + EPS) + torch.log(1 - y_gaussian_fake + EPS)
+        )
+        self.D_loss.backward()
+        self.optim_D_dis.step()
+
+        self.encoder.train()  # phase 3: generator, encoder tries to fool D
+        z_gaussian = self.encoder(data)
+        y_gaussian = self.discriminator(z_gaussian)
+        self.G_loss = -torch.mean(torch.log(y_gaussian + EPS))
+        self.G_loss.backward()
+        self.optim_E_gen.step()
+        return
+
+    def losses(self):
+        try:
+            return self.recon_loss, self.D_loss, self.G_loss
+        except AttributeError:
+            # losses exist only after the first train_step() call
+            return None
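For reference, a minimal usage sketch of the class introduced above; it is not part of the commit. The feature count, the stand-in dataset, and `torch.randn` as the prior sampler are illustrative assumptions, and the import path assumes `code/` is on `PYTHONPATH`. The only contracts the class imposes are that `distribution(batch_size, latent_dim)` returns a tensor and that data lies in `[0, 1]` when the default `F.binary_cross_entropy` reconstruction loss is used.

```python
# Hypothetical usage sketch -- dimensions and data are illustrative.
import torch
from torch.utils.data import DataLoader

from sunlab.suntorch.models.adversarial_autoencoder import AdversarialAutoencoder

data = torch.rand(512, 13)  # stand-in dataset: 512 samples, 13 features in [0, 1]
loader = DataLoader(data, batch_size=64, shuffle=True)

model = (
    AdversarialAutoencoder(
        data_dim=13,
        latent_dim=2,
        enc_dec_size=64,
        disc_size=64,
        distribution=torch.randn,  # prior sampler: (batch, latent_dim) -> tensor
    )
    .init_optimizers(recon_lr=1e-4, adv_lr=5e-5)
    .init_losses()  # defaults to F.binary_cross_entropy
)

for epoch in range(10):
    model.train()
    for batch in loader:
        model.train_step(batch)
    recon, d_loss, g_loss = model.losses()
    print(f"epoch {epoch}: recon={recon:.4f} D={d_loss:.4f} G={g_loss:.4f}")

model.eval().save("./")  # writes weights_{encoder,decoder,discriminator}.pt
```

Note the split optimizer design: the encoder participates in both the reconstruction update (`optim_E_enc`) and the adversarial generator update (`optim_E_gen`), which is why `init_optimizers()` creates two Adam instances over the same encoder parameters at different learning rates.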