aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/suntorch/models/variational
diff options
context:
space:
mode:
Diffstat (limited to 'code/sunlab/suntorch/models/variational')
-rw-r--r--code/sunlab/suntorch/models/variational/autoencoder.py128
-rw-r--r--code/sunlab/suntorch/models/variational/common.py12
-rw-r--r--code/sunlab/suntorch/models/variational/decoder.py33
-rw-r--r--code/sunlab/suntorch/models/variational/encoder.py34
4 files changed, 207 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/models/variational/autoencoder.py b/code/sunlab/suntorch/models/variational/autoencoder.py
new file mode 100644
index 0000000..e335704
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/autoencoder.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+from .encoder import Encoder
+from .decoder import Decoder
+from .common import *
+
+
+class VariationalAutoencoder:
+ """# Variational Autoencoder Model"""
+
+ def __init__(
+ self, data_dim, latent_dim, enc_dec_size, negative_slope=0.3, dropout=0.0
+ ):
+ self.encoder = Encoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.decoder = Decoder(
+ data_dim,
+ enc_dec_size,
+ latent_dim,
+ negative_slope=negative_slope,
+ dropout=dropout,
+ )
+ self.data_dim = data_dim
+ self.latent_dim = latent_dim
+ self.p = dropout
+ self.negative_slope = negative_slope
+ return
+
+ def parameters(self):
+ return (*self.encoder.parameters(), *self.decoder.parameters())
+
+ def train(self):
+ self.encoder.train(True)
+ self.decoder.train(True)
+ return self
+
+ def eval(self):
+ self.encoder.train(False)
+ self.decoder.train(False)
+ return self
+
+ def encode(self, data):
+ return self.encoder(data)
+
+ def decode(self, data):
+ return self.decoder(data)
+
+ def reparameterization(self, mean, var):
+ epsilon = torch.randn_like(var)
+ if torch.cuda.is_available():
+ epsilon = epsilon.cuda()
+ z = mean + var * epsilon
+ return z
+
+ def forward(self, x):
+ mean, log_var = self.encoder(x)
+ z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ X = self.decoder(z)
+ return X, mean, log_var
+
+ def __call__(self, data):
+ return self.forward(data)
+
+ def save(self, base="./"):
+ torch.save(self.encoder.state_dict(), base + "weights_encoder.pt")
+ torch.save(self.decoder.state_dict(), base + "weights_decoder.pt")
+ return self
+
+ def load(self, base="./"):
+ self.encoder.load_state_dict(torch.load(base + "weights_encoder.pt"))
+ self.encoder.eval()
+ self.decoder.load_state_dict(torch.load(base + "weights_decoder.pt"))
+ self.decoder.eval()
+ return self
+
+ def to(self, device):
+ self.encoder.to(device)
+ self.decoder.to(device)
+ return self
+
+ def cuda(self):
+ if torch.cuda.is_available():
+ self.encoder = self.encoder.cuda()
+ self.decoder = self.decoder.cuda()
+ return self
+
+ def cpu(self):
+ self.encoder = self.encoder.cpu()
+ self.decoder = self.decoder.cpu()
+ return self
+
+ def init_optimizers(self, recon_lr=1e-4):
+ self.optim_E_enc = torch.optim.Adam(self.encoder.parameters(), lr=recon_lr)
+ self.optim_D = torch.optim.Adam(self.decoder.parameters(), lr=recon_lr)
+ return self
+
+ def init_losses(self, recon_loss_fn=F.binary_cross_entropy):
+ self.recon_loss_fn = recon_loss_fn
+ return self
+
+ def train_step(self, raw_data):
+ data = to_var(raw_data.view(raw_data.size(0), -1))
+
+ self.encoder.zero_grad()
+ self.decoder.zero_grad()
+ X, _, _ = self.forward(data)
+ # mean, log_var = self.encoder(data)
+ # z = self.reparameterization(mean, torch.exp(0.5 * log_var))
+ # X = self.decoder(z)
+ self.recon_loss = self.recon_loss_fn(X + EPS, data + EPS)
+ self.recon_loss.backward()
+ self.optim_E_enc.step()
+ self.optim_D.step()
+ return
+
+ def losses(self):
+ try:
+ return self.recon_loss
+ except:
+ ...
+ return
diff --git a/code/sunlab/suntorch/models/variational/common.py b/code/sunlab/suntorch/models/variational/common.py
new file mode 100644
index 0000000..f10930e
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/common.py
@@ -0,0 +1,12 @@
+from torch.autograd import Variable
+
+EPS = 1e-15
+
+
+def to_var(x):
+ """# Convert to variable"""
+ import torch
+
+ if torch.cuda.is_available():
+ x = x.cuda()
+ return Variable(x)
diff --git a/code/sunlab/suntorch/models/variational/decoder.py b/code/sunlab/suntorch/models/variational/decoder.py
new file mode 100644
index 0000000..2eeb7a4
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/decoder.py
@@ -0,0 +1,33 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import sigmoid
+
+
+class Decoder(nn.Module):
+ """# Decoder Neural Network
+ X_dim: Output dimension shape
+ N: Inner neuronal layer size
+ z_dim: Input dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Decoder, self).__init__()
+ self.lin1 = nn.Linear(z_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3 = nn.Linear(N, X_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin3(x)
+ return sigmoid(x)
diff --git a/code/sunlab/suntorch/models/variational/encoder.py b/code/sunlab/suntorch/models/variational/encoder.py
new file mode 100644
index 0000000..b08202f
--- /dev/null
+++ b/code/sunlab/suntorch/models/variational/encoder.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+ """# Encoder Neural Network
+ X_dim: Input dimension shape
+ N: Inner neuronal layer size
+ z_dim: Output dimension shape
+ """
+
+ def __init__(self, X_dim, N, z_dim, dropout=0.0, negative_slope=0.3):
+ super(Encoder, self).__init__()
+ self.lin1 = nn.Linear(X_dim, N)
+ self.lin2 = nn.Linear(N, N)
+ self.lin3mu = nn.Linear(N, z_dim)
+ self.lin3sigma = nn.Linear(N, z_dim)
+ self.p = dropout
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ x = self.lin1(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ x = self.lin2(x)
+ if self.p > 0.0:
+ x = F.dropout(x, p=self.p, training=self.training)
+ x = F.leaky_relu(x, negative_slope=self.negative_slope)
+
+ mu = self.lin3mu(x)
+ sigma = self.lin3sigma(x)
+ return mu, sigma