import math
import torch
from torch import nn
from torch.nn.parameter import Parameter
class VariationalAutoencoder(nn.Module):
    """VAE pairing a dense Gaussian encoder with a masked (sparse) Gaussian decoder."""
    def __init__(self, m, n, mask):
        # m: latent dimension, n: observed (input/output) dimension
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(m, n)
        self.decoder = Decoder(m, n, mask)
    def forward(self, x):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent)
        return x_recon, logcov, mu, logvar
    def latent_sample(self, mu, logvar):
        # re-parameterization trick: z = mu + sigma * eps, with
        # sigma = exp(0.5 * logvar) and eps ~ N(0, I), so the sample stays
        # differentiable with respect to mu and logvar
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            # at evaluation time, return the posterior mean deterministically
            return mu
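# Illustrative sketch (not part of the original module): one way to turn the
# four outputs of VariationalAutoencoder.forward into a training objective,
# assuming `logcov` holds the per-feature log-variance of a diagonal Gaussian
# p(x|z). The name `negative_elbo` is this sketch's own.
def negative_elbo(x, x_recon, logcov, mu, logvar):
    # reconstruction term: -log N(x; x_recon, exp(logcov)) summed over
    # features, dropping the additive 0.5 * log(2 * pi) constant
    recon = 0.5 * (logcov + (x - x_recon).pow(2) / logcov.exp()).sum(dim=-1)
    # closed-form KL( N(mu, exp(logvar)) || N(0, I) )
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)
    return (recon + kl).mean()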
class Block(nn.Module):
    """Base class holding the dimensions shared by encoder and decoder."""
    def __init__(self, m, n):
        super(Block, self).__init__()
        self.input_dim = n
        self.latent_dim = m
        self.output_dim = n
class Encoder(Block):
    """Two-layer MLP encoder producing the mean and log-variance of q(z|x)."""
    def __init__(self, m, n):
        super(Encoder, self).__init__(m, n)
        # hidden width equals the input width
        self.inter_dim = self.input_dim
        # first encoder layer
        self.enc1 = nn.Linear(in_features=self.input_dim, out_features=self.inter_dim)
        # second encoder layer
        self.enc2 = nn.Linear(in_features=self.inter_dim, out_features=self.inter_dim)
        # map to the mean and log-variance of the latent Gaussian
        self.fc_mu = nn.Linear(in_features=self.inter_dim, out_features=self.latent_dim)
        self.fc_logvar = nn.Linear(
            in_features=self.inter_dim, out_features=self.latent_dim
        )
    def forward(self, x):
        # encoder layers
        inter = torch.relu(self.enc1(x))
        inter = torch.relu(self.enc2(inter))
        # calculate mu & logvar
        mu = self.fc_mu(inter)
        logvar = self.fc_logvar(inter)
        return mu, logvar
class Decoder(Block):
    """Masked linear decoder producing the mean and log-covariance of p(x|z)."""
    def __init__(self, m, n, mask):
        super(Decoder, self).__init__(m, n)
        # decoder layer -- estimate mean
        self.dec_mean = SparseLinear(
            in_features=self.latent_dim, out_features=self.output_dim, mask=mask
        )
        # decoder layer -- estimate log-covariance
        self.fc_logcov = SparseLinear(
            in_features=self.latent_dim, out_features=self.output_dim, mask=mask
        )
    def forward(self, z):
        # masked linear layers
        mean = self.dec_mean(z)
        logcov = self.fc_logcov(z)
        return mean, logcov
class SparseLinear(nn.Module):
    """Linear layer whose weight is multiplied elementwise by a fixed binary mask."""
    def __init__(
        self, in_features, out_features, mask, bias=True, device=None, dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # register the mask as a buffer so it follows the module under .to(device)
        self.register_buffer("mask", mask)
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
    def reset_parameters(self):
        # same initialization scheme as torch.nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, input):
        # masked linear layer
        return nn.functional.linear(input, self.weight * self.mask, self.bias)
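# Minimal usage sketch (illustrative; the dimensions and sparsity level are
# assumptions). The mask must be broadcastable to the decoder weight shape
# (out_features, in_features) = (n, m).
if __name__ == "__main__":
    n, m = 20, 5                              # observed dim, latent dim
    mask = (torch.rand(n, m) < 0.3).float()   # fixed binary sparsity pattern
    model = VariationalAutoencoder(m, n, mask)
    x = torch.randn(8, n)                     # batch of 8 observations
    x_recon, logcov, mu, logvar = model(x)
    loss = negative_elbo(x, x_recon, logcov, mu, logvar)
    loss.backward()
    print(x_recon.shape, mu.shape, loss.item())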