Coverage for /opt/conda/lib/python3.13/site-packages/medil/vae.py: 95%
102 statements
coverage.py v7.14.1, created at 2026-06-01 15:11 +0000
1import math
3import torch
4from torch import nn
5from torch.nn import functional as F
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder pairing a dense ``Encoder`` with a masked ``Decoder``.

    The encoder produces ``num_latent * latent_width`` means and log-variances.
    A latent code is drawn from them (see ``latent_sample``), and the decoder
    maps that code back to a reconstruction of the ``num_meas`` inputs.

    Args:
        num_latent: Number of latent variables.
        num_meas: Number of measured (input) variables.
        num_hidden_layers: Number of hidden layers in the decoder.
        latent_width: Units per latent variable.
        meas_width: Decoder hidden units per measured variable.
        biadj: Optional adjacency, forwarded unchanged to ``Decoder``.
        encoder_hidden_dim: Hidden width of the encoder; when ``None`` it is
            ``max(num_meas, 64)``.
    """

    def __init__(
        self,
        num_latent,
        num_meas,
        num_hidden_layers,
        latent_width,
        meas_width,
        biadj=None,
        encoder_hidden_dim=None,
    ):
        super().__init__()

        width = max(num_meas, 64) if encoder_hidden_dim is None else encoder_hidden_dim

        # The encoder is registered before the decoder; this fixes the
        # submodule order (and therefore the state_dict key order).
        self.encoder = Encoder(
            num_meas=num_meas,
            num_latent=latent_width * num_latent,
            hidden_dim=width,
        )
        self.decoder = Decoder(
            num_meas=num_meas,
            num_latent=num_latent,
            latent_width=latent_width,
            meas_width=meas_width,
            num_hidden_layers=num_hidden_layers,
            biadj=biadj,
        )

    def forward(self, x):
        """Encode ``x``, draw a latent code, and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        return self.decoder(self.latent_sample(mu, logvar)), mu, logvar

    def latent_sample(self, mu, logvar):
        """Return a latent code for the given posterior parameters.

        In eval mode the mean ``mu`` is returned unchanged. In training mode
        the reparameterisation ``mu + eps * exp(0.5 * logvar)`` is used, with
        ``eps`` drawn by ``torch.randn_like``.
        """
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma
class Encoder(nn.Module):
    """Two-hidden-layer MLP mapping measurements to posterior parameters.

    Each hidden layer applies Linear -> BatchNorm1d -> GELU. Two separate
    linear heads then produce the mean and the log-variance.

    Args:
        num_latent: Output size of each head.
        num_meas: Number of input features.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, num_latent, num_meas, hidden_dim=64):
        super().__init__()
        # Attribute names and creation order are kept fixed: they define the
        # state_dict keys and the order in which parameters are initialised.
        self.enc1 = nn.Linear(in_features=num_meas, out_features=hidden_dim)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_dim)
        self.enc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.bn2 = nn.BatchNorm1d(num_features=hidden_dim)
        self.activation = nn.GELU()
        self.fc_mu = nn.Linear(in_features=hidden_dim, out_features=num_latent)
        self.fc_logvar = nn.Linear(in_features=hidden_dim, out_features=num_latent)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from the batch ``x``."""
        features = x
        for linear, norm in ((self.enc1, self.bn1), (self.enc2, self.bn2)):
            features = self.activation(norm(linear(features)))
        return self.fc_mu(features), self.fc_logvar(features)
class Decoder(nn.Module):
    """Masked MLP decoder mapping latent codes to measurement reconstructions.

    Each latent variable owns ``latent_width`` consecutive input units and
    each measurement owns ``meas_width`` consecutive hidden units.
    Connectivity is enforced with ``SparseLinear`` masks:

    * input layer: the hidden block of measurement ``i`` is connected to the
      input block of latent ``j`` through ``biadj[i, j]`` (zero disconnects);
    * hidden layers: block-diagonal, so a measurement's hidden units only
      feed that same measurement's hidden units;
    * output layer: output ``i`` reads only from measurement ``i``'s block.

    Args:
        num_latent: Number of latent variables.
        num_meas: Number of measured variables (output size).
        num_hidden_layers: Number of masked hidden-to-hidden layers.
        latent_width: Input units per latent variable.
        meas_width: Hidden units per measured variable.
        biadj: Optional array-like of shape ``(num_meas, num_latent)``.
            When ``None``, every latent connects to every measurement.

    Raises:
        ValueError: If ``biadj`` is given and its shape is not
            ``(num_meas, num_latent)``.
    """

    def __init__(
        self,
        num_latent,
        num_meas,
        num_hidden_layers,
        latent_width,
        meas_width,
        biadj=None,
    ):
        super().__init__()

        self.num_latent = num_latent
        self.num_meas = num_meas
        self.latent_width = latent_width
        self.meas_width = meas_width

        self.latent_dim = num_latent * latent_width  # decoder input size
        self.hidden_dim = num_meas * meas_width  # width of every hidden layer

        if biadj is None:
            biadj = torch.ones(num_meas, num_latent)
        else:
            biadj = torch.as_tensor(biadj, dtype=torch.float32)
            # Check up front. Otherwise a 1-D biadj crashes inside
            # repeat_interleave, and a wrongly shaped 2-D one only fails
            # later with an error about the expanded (hidden_dim, latent_dim) mask.
            if tuple(biadj.shape) != (num_meas, num_latent):
                raise ValueError(
                    f"biadj must have shape {(num_meas, num_latent)}, "
                    f"got {tuple(biadj.shape)}"
                )

        first_mask = self._expand_biadj(biadj, meas_width, latent_width)
        hidden_mask = self._make_hidden_block_mask(num_meas, meas_width)
        output_mask = self._make_output_mask(num_meas, meas_width)

        self.linear_in = SparseLinear(
            in_features=self.latent_dim,
            out_features=self.hidden_dim,
            mask=first_mask,
        )

        self.bn_in = nn.BatchNorm1d(self.hidden_dim)

        self.hidden_layers = nn.ModuleList(
            [
                SparseLinear(
                    in_features=self.hidden_dim,
                    out_features=self.hidden_dim,
                    mask=hidden_mask,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        self.hidden_bns = nn.ModuleList(
            [nn.BatchNorm1d(self.hidden_dim) for _ in range(num_hidden_layers)]
        )

        self.linear_out = SparseLinear(
            in_features=self.hidden_dim,
            out_features=self.num_meas,
            mask=output_mask,
        )

        self.activation = nn.GELU()

    @staticmethod
    def _expand_biadj(biadj, meas_width, latent_width):
        """Blow each ``biadj[i, j]`` up into a ``meas_width x latent_width`` block."""
        return biadj.repeat_interleave(meas_width, dim=0).repeat_interleave(
            latent_width, dim=1
        )

    @staticmethod
    def _make_hidden_block_mask(num_meas, width_per_meas):
        """Block-diagonal mask with one dense ``width x width`` block per measurement."""
        block = torch.ones(width_per_meas, width_per_meas)
        blocks = [block for _ in range(num_meas)]
        return torch.block_diag(*blocks)

    @staticmethod
    def _make_output_mask(num_meas, width_per_meas):
        """Mask whose row ``i`` selects only measurement ``i``'s hidden block."""
        block = torch.ones(1, width_per_meas)
        blocks = [block for _ in range(num_meas)]
        return torch.block_diag(*blocks)

    def forward(self, z):
        """Decode latent codes ``z`` into a ``num_meas``-wide reconstruction."""
        h = self.activation(self.bn_in(self.linear_in(z)))

        for layer, bn in zip(self.hidden_layers, self.hidden_bns):
            h = self.activation(bn(layer(h)))

        # No activation on the output layer.
        x_recon = self.linear_out(h)
        return x_recon
class SparseLinear(nn.Module):
    """Linear layer whose weight is multiplied elementwise by a fixed mask.

    Computes ``F.linear(x, weight * mask, bias)``. Weight entries where the
    mask is zero contribute nothing to the output and get zero gradient.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        mask: Optional array-like of shape ``(out_features, in_features)``.
            It defaults to all ones, which gives a dense layer. The mask is
            stored as a buffer on the same device and dtype as ``weight``.
        bias: Whether to learn an additive bias.
        device: Device for the parameters and the mask.
        dtype: Floating dtype for the parameters and the mask.

    Raises:
        ValueError: If ``mask`` does not have shape
            ``(out_features, in_features)``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.in_features = in_features
        self.out_features = out_features

        if mask is None:
            mask = torch.ones(out_features, in_features)
        else:
            # as_tensor also accepts lists / numpy arrays, which have no .float().
            mask = torch.as_tensor(mask)
            if mask.shape != (out_features, in_features):
                raise ValueError(
                    f"mask must have shape {(out_features, in_features)}, "
                    f"got {tuple(mask.shape)}"
                )

        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        # Put the mask on the weight's device and dtype. A CPU float32 mask
        # would break forward() for layers built with device="cuda" (device
        # mismatch) or dtype=torch.float16 (weight * mask promotes to float32).
        self.register_buffer(
            "mask", mask.to(device=self.weight.device, dtype=self.weight.dtype)
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonally initialise ``weight`` and uniformly initialise ``bias``."""
        nn.init.orthogonal_(self.weight)
        if self.bias is not None:
            # The bound uses the dense fan-in of ``weight`` and ignores the mask.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the masked affine map to ``x``."""
        return F.linear(x, self.weight * self.mask, self.bias)

    def extra_repr(self):
        """Summarise the layer's configuration for ``repr``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )