aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/suntorch/models/variational/autoencoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'code/sunlab/suntorch/models/variational/autoencoder.py')
-rw-r--r--code/sunlab/suntorch/models/variational/autoencoder.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/variational/autoencoder.py b/code/sunlab/suntorch/models/variational/autoencoder.py
new file mode 100644
index 0000000..e335704
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/autoencoder.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .common import *
+
+
+class VariationalAutoencoder:
+ """# Variational Autoencoder Model"""
+
+ def __init__(
+ self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0
+ ):
+ self.encoder = Encoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.decoder = Decoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.data_dim = data_dim
+ self.latent_dim = latent_dim
+ self.p = dropout
+ self.negative_slope = negative_slope
+ return
+
+ def parameters(self):
+ return (*self.encoder.parameters(), *self.decoder.parameters())
+
+ def train(self):
+ self.encoder.train(True)
+ self.decoder.train(True)
+ return self
+
+ def eval(self):
+ self.encoder.train(False)
+ self.decoder.train(False)
+ return self
+
+ def encode(self, data):
+ return self.encoder(data)
+
+ def decode(self, data):
+ return self.decoder(data)
+
+ def reparameterization(self, mean, var):
+ epsilon = torch.randn_like(var)
+ if torch.cuda.is_available():
+ epsilon = epsilon.cuda()
+ z = mean + var * epsilon
+ return z
+
+ def forward(self, x):
+ mean, log_var = self.encoder(x)
+ z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ X = self.decoder(z)
+ return X, mean, log_var
+
+ def __call__(self, data):
+ return self.forward(data)
+
+ def save(self, base="./"):
+ torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+ torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+ return self
+
+ def load(self, base="./"):
+ self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+ self.encoder.eval()
+ self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+ self.decoder.eval()
+ return self
+
+ def to(self, device):
+ self.encoder.to(device)
+ self.decoder.to(device)
+ return self
+
+ def cuda(self):
+ if torch.cuda.is_available():
+ self.encoder = self.encoder.cuda()
+ self.decoder = self.decoder.cuda()
+ return self
+
+ def cpu(self):
+ self.encoder = self.encoder.cpu()
+ self.decoder = self.decoder.cpu()
+ return self
+
+ def init_optimizers(self, recon_lr=1e-4):
+ self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
+ self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+ return self
+
+ def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+ self.recon_loss_fn = recon_loss_fn
+ return self
+
+ def train_step(self, raw_data):
+ data = to_var(raw_data.view(raw_data.size(0), -1))
+
+ self.encoder.zero_grad()
+ self.decoder.zero_grad()
+ X, _, _ = self.forward(data)
+ # mean, log_var = self.encoder(data)
+ # z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ # X = self.decoder(z)
+ self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+ self.recon_loss.backward()
+ self.optim_E_enc.step()
+ self.optim_D.step()
+ return
+
+ def losses(self):
+ try:
+ return self.recon_loss
+ except:
+ ...
+ return