Coverage for /opt/conda/lib/python3.12/site-packages/medil/vae.py: 98% (97 statements)
import math

import torch
from torch import nn
from torch.nn.parameter import Parameter

class VariationalAutoencoder(nn.Module):
    def __init__(
        self,
        num_vae_latent,
        num_meas,
        num_hidden_layers,
        width_per_meas,
        prior_biadj=None,
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_vae_latent, num_meas)
        self.decoder = Decoder(
            num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj
        )

    def forward(self, x):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent)

        return x_recon, logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        # the re-parameterization trick
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
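
# Hedged illustration (added for this write-up, not part of the original medil
# module): latent_sample above implements the re-parameterization trick, drawing
# z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I) so that gradients can flow
# through mu and logvar. The tensor shapes below are made up for illustration.
def _reparameterization_sketch():
    mu = torch.zeros(4, 3)        # hypothetical batch of 4, latent dimension 3
    logvar = torch.zeros(4, 3)    # log-variance of 0, i.e. unit variance
    eps = torch.randn_like(mu)
    z = mu + torch.exp(0.5 * logvar) * eps
    return z                      # distributed as N(mu, diag(exp(logvar)))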

class Block(nn.Module):
    def __init__(self, num_vae_latent, num_meas, width_per_meas=1):
        super(Block, self).__init__()
        self.input_dim = num_meas
        self.latent_dim = num_vae_latent
        self.hidden_dim = num_meas * width_per_meas
        self.output_dim = num_meas

class Encoder(Block):
    def __init__(self, num_vae_latent, num_meas):
        super(Encoder, self).__init__(num_vae_latent, num_meas)

        # first encoder layer
        self.inter_dim = self.input_dim
        self.enc1 = nn.Linear(in_features=self.input_dim, out_features=self.inter_dim)

        # second encoder layer
        self.enc2 = nn.Linear(in_features=self.inter_dim, out_features=self.inter_dim)

        # map to mu and log-variance
        self.fc_mu = nn.Linear(in_features=self.inter_dim, out_features=self.latent_dim)
        self.fc_logvar = nn.Linear(
            in_features=self.inter_dim, out_features=self.latent_dim
        )

    def forward(self, x):
        # encoder layers; note that the BatchNorm1d is constructed inside
        # forward, so its affine parameters and running statistics are
        # re-initialised on every call rather than being learned
        norm = nn.BatchNorm1d(self.inter_dim)
        inter = torch.relu(norm(self.enc1(x)))
        inter = torch.relu(norm(self.enc2(inter)))

        # calculate mu & logvar
        mu = self.fc_mu(inter)
        logvar = self.fc_logvar(inter)

        return mu, logvar

class Decoder(Block):
    def __init__(
        self, num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj
    ):
        super(Decoder, self).__init__(num_vae_latent, num_meas, width_per_meas)

        # # decoder layer -- estimate mean
        # self.dec_mean = SparseLinear(
        #     in_features=self.latent_dim, out_features=self.output_dim
        # )

        # # decoder layer -- estimate log-covariance
        # self.fc_logcov = SparseLinear(
        #     in_features=self.latent_dim, out_features=self.output_dim
        # )

        # new arch: fully connected layers from the latent space into the
        # per-measurement hidden width; the mean path is masked by prior_biadj
        # when one is supplied
        self.mean_linear_fulcon = SparseLinear(
            in_features=self.latent_dim, out_features=self.hidden_dim, mask=prior_biadj
        )
        self.cov_linear_fulcon = SparseLinear(
            in_features=self.latent_dim, out_features=self.hidden_dim
        )

        # block-diagonal mask: each measurement keeps its own width_per_meas-wide
        # sub-network, with no connections between measurements
        hidden_block = torch.ones(width_per_meas, width_per_meas)
        hidden_blocks = [hidden_block for _ in range(num_meas)]
        hidden_mask = torch.block_diag(*hidden_blocks)

        # note: plain dicts, so these SparseLinear layers are not registered as
        # submodules of the Decoder
        self.mean_linear_hidden = {
            layer_idx: SparseLinear(
                in_features=self.hidden_dim,
                out_features=self.hidden_dim,
                mask=hidden_mask,
            )
            for layer_idx in range(num_hidden_layers)
        }
        self.cov_linear_hidden = {
            layer_idx: SparseLinear(
                in_features=self.hidden_dim,
                out_features=self.hidden_dim,
                mask=hidden_mask,
            )
            for layer_idx in range(num_hidden_layers)
        }

        # collapse each measurement's sub-network back to a single output unit
        output_block = torch.ones(1, width_per_meas)
        output_blocks = [output_block for _ in range(num_meas)]
        output_mask = torch.block_diag(*output_blocks)
        self.mean_linear_output = SparseLinear(
            in_features=self.hidden_dim, out_features=self.output_dim, mask=output_mask
        )
        self.cov_linear_output = SparseLinear(
            in_features=self.hidden_dim, out_features=self.output_dim, mask=output_mask
        )

        self.activation = torch.nn.GELU()

    def forward(self, z):
        # linear layer
        # mean = self.dec_mean(z)
        # logcov = self.fc_logcov(z)

        # new arch (BatchNorm1d is constructed per call, as in the Encoder)
        norm = nn.BatchNorm1d(self.hidden_dim)
        mean = self.mean_linear_fulcon(z)
        mean = self.activation(norm(mean))
        for hidden_layer in self.mean_linear_hidden.values():
            mean = hidden_layer(norm(mean))
            mean = self.activation(mean)
        mean = self.mean_linear_output(mean)

        logcov = self.cov_linear_fulcon(z)
        logcov = self.activation(norm(logcov))
        for hidden_layer in self.cov_linear_hidden.values():
            logcov = hidden_layer(logcov)
            logcov = self.activation(norm(logcov))
        logcov = self.cov_linear_output(logcov)

        return mean, logcov
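
# Hedged illustration (added here, not part of the original module): what the
# block-diagonal masks built in Decoder.__init__ look like for a toy case with
# num_meas=2 and width_per_meas=3. Each measurement variable gets its own
# private 3-unit sub-network; hidden_mask is 6x6 with two all-ones 3x3 blocks,
# and output_mask is 2x6 with each output row connected only to its own block.
def _decoder_mask_sketch(num_meas=2, width_per_meas=3):
    hidden_mask = torch.block_diag(
        *[torch.ones(width_per_meas, width_per_meas) for _ in range(num_meas)]
    )
    output_mask = torch.block_diag(
        *[torch.ones(1, width_per_meas) for _ in range(num_meas)]
    )
    return hidden_mask, output_mask  # shapes (6, 6) and (2, 6) for the defaults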

class SparseLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if mask is None:
            # a scalar ones mask broadcasts over the full weight matrix,
            # i.e. the layer behaves like an ordinary dense nn.Linear
            self.mask = torch.ones(1)
        else:
            self.mask = mask
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.orthogonal_(self.weight)
        # nn.init.sparse_(self.weight, 2 / 3)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer: zero entries in mask remove the corresponding
        # weight from the affine map
        return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
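
# Hedged usage sketch (added for this write-up, not part of the original medil
# module): SparseLinear multiplies its weight element-wise by `mask` on every
# forward pass, so zero mask entries remove connections while the underlying
# parameters stay in place. All dimensions below are made-up examples.
if __name__ == "__main__":
    mask = torch.tensor([[1.0, 0.0, 0.0],
                         [0.0, 1.0, 1.0]])   # output 0 sees only input 0, etc.
    layer = SparseLinear(in_features=3, out_features=2, mask=mask)
    y = layer(torch.randn(5, 3))             # shape (5, 2)

    # end-to-end pass through the VAE with hypothetical dimensions
    vae = VariationalAutoencoder(
        num_vae_latent=3, num_meas=6, num_hidden_layers=2, width_per_meas=4
    )
    x = torch.randn(8, 6)                    # batch of 8, 6 measurement variables
    x_recon, logcov, mu, logvar = vae(x)     # shapes (8, 6), (8, 6), (8, 3), (8, 3)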