Coverage for /opt/conda/lib/python3.13/site-packages/medil/vae.py: 95%

102 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-01 15:11 +0000

1import math 

2 

3import torch 

4from torch import nn 

5from torch.nn import functional as F 

6 

7 

class VariationalAutoencoder(nn.Module):
    """Variational autoencoder pairing a dense ``Encoder`` with a masked ``Decoder``.

    Args:
        num_latent: Number of latent variables.
        num_meas: Number of measured variables (input features of the encoder).
        num_hidden_layers: Hidden-layer count forwarded to the decoder.
        latent_width: Units per latent variable; the encoder therefore emits
            ``num_latent * latent_width`` means and log-variances.
        meas_width: Decoder hidden units per measured variable.
        biadj: Optional connectivity matrix, forwarded unchanged to ``Decoder``.
        encoder_hidden_dim: Encoder hidden width; ``None`` selects
            ``max(num_meas, 64)``.
    """

    def __init__(
        self,
        num_latent,
        num_meas,
        num_hidden_layers,
        latent_width,
        meas_width,
        biadj=None,
        encoder_hidden_dim=None,
    ):
        super().__init__()

        hidden_width = (
            max(num_meas, 64) if encoder_hidden_dim is None else encoder_hidden_dim
        )

        # The encoder is built before the decoder on purpose: construction order
        # fixes how the global RNG is consumed by parameter initialisation, and
        # the attribute names are the state_dict keys.
        self.encoder = Encoder(
            num_latent=num_latent * latent_width,
            num_meas=num_meas,
            hidden_dim=hidden_width,
        )
        self.decoder = Decoder(
            num_latent=num_latent,
            num_meas=num_meas,
            num_hidden_layers=num_hidden_layers,
            latent_width=latent_width,
            meas_width=meas_width,
            biadj=biadj,
        )

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar)``, where ``mu``/``logvar`` are the
            encoder outputs used to draw the latent code.
        """
        mu, logvar = self.encoder(x)
        return self.decoder(self.latent_sample(mu, logvar)), mu, logvar

    def latent_sample(self, mu, logvar):
        """Draw a latent code from ``mu`` and ``logvar``.

        In training mode this returns ``mu + eps * exp(0.5 * logvar)`` with
        ``eps`` standard-normal noise (the reparameterisation trick). In eval
        mode it deterministically returns ``mu``.
        """
        if not self.training:
            return mu
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

51 

52 

class Encoder(nn.Module):
    """Two-stage MLP mapping measurements to ``(mu, logvar)`` heads.

    Each stage is ``Linear -> BatchNorm1d -> GELU``; two independent linear
    heads then read the final hidden representation.

    Args:
        num_latent: Output size of each head (``mu`` and ``logvar``).
        num_meas: Number of input features.
        hidden_dim: Width of both hidden stages.
    """

    def __init__(self, num_latent, num_meas, hidden_dim=64):
        super().__init__()
        # Creation order and attribute names are kept stable: the former fixes
        # RNG consumption during initialisation, the latter are state_dict keys.
        self.enc1 = nn.Linear(num_meas, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.enc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.activation = nn.GELU()
        self.fc_mu = nn.Linear(hidden_dim, num_latent)
        self.fc_logvar = nn.Linear(hidden_dim, num_latent)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch ``x`` of shape ``(batch, num_meas)``.

        Note: in training mode ``BatchNorm1d`` needs more than one sample per
        batch.
        """
        hidden = x
        for linear, norm in ((self.enc1, self.bn1), (self.enc2, self.bn2)):
            hidden = self.activation(norm(linear(hidden)))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

70 

71 

class Decoder(nn.Module):
    """Structure-masked MLP mapping latent codes to reconstructed measurements.

    Each measured variable owns a block of ``meas_width`` hidden units and each
    latent variable owns a block of ``latent_width`` input units. Connectivity
    is enforced with fixed weight masks (``SparseLinear``):

    * input layer: measurement block ``i`` sees latent block ``j`` only where
      ``biadj[i, j]`` is nonzero;
    * hidden layers: block-diagonal, so a measurement's units only connect to
      that same measurement's units;
    * output layer: each measurement's block feeds exactly one output unit.

    Args:
        num_latent: Number of latent variables.
        num_meas: Number of measured (output) variables.
        num_hidden_layers: Number of masked hidden layers after the input layer.
        latent_width: Units per latent variable in the input ``z``.
        meas_width: Hidden units per measured variable.
        biadj: Optional array-like of shape ``(num_meas, num_latent)``;
            ``None`` means fully connected. Its expansion is multiplied into
            the weights, so it is presumably meant to contain 0/1 entries.

    Raises:
        ValueError: If ``biadj`` does not have shape ``(num_meas, num_latent)``.
    """

    def __init__(
        self,
        num_latent,
        num_meas,
        num_hidden_layers,
        latent_width,
        meas_width,
        biadj=None,
    ):
        super().__init__()

        self.num_latent = num_latent
        self.num_meas = num_meas
        self.latent_width = latent_width
        self.meas_width = meas_width

        self.latent_dim = num_latent * latent_width
        self.hidden_dim = num_meas * meas_width

        if biadj is None:
            biadj = torch.ones(num_meas, num_latent)
        else:
            biadj = torch.as_tensor(biadj, dtype=torch.float32)
            # Validate up front. Otherwise a mis-shaped biadj surfaces either
            # as an IndexError from repeat_interleave (1-D input) or as a
            # SparseLinear error that quotes the *expanded* mask shape.
            if biadj.shape != (num_meas, num_latent):
                raise ValueError(
                    f"biadj must have shape {(num_meas, num_latent)}, "
                    f"got {tuple(biadj.shape)}"
                )

        first_mask = self._expand_biadj(biadj, meas_width, latent_width)
        hidden_mask = self._make_hidden_block_mask(num_meas, meas_width)
        output_mask = self._make_output_mask(num_meas, meas_width)

        # Submodule creation order is kept as-is (RNG consumption at init).
        self.linear_in = SparseLinear(
            in_features=self.latent_dim,
            out_features=self.hidden_dim,
            mask=first_mask,
        )

        self.bn_in = nn.BatchNorm1d(self.hidden_dim)

        self.hidden_layers = nn.ModuleList(
            [
                SparseLinear(
                    in_features=self.hidden_dim,
                    out_features=self.hidden_dim,
                    mask=hidden_mask,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        self.hidden_bns = nn.ModuleList(
            [nn.BatchNorm1d(self.hidden_dim) for _ in range(num_hidden_layers)]
        )

        self.linear_out = SparseLinear(
            in_features=self.hidden_dim,
            out_features=self.num_meas,
            mask=output_mask,
        )

        self.activation = nn.GELU()

    @staticmethod
    def _expand_biadj(biadj, meas_width, latent_width):
        """Repeat each row ``meas_width`` times and each column ``latent_width`` times.

        A ``(num_meas, num_latent)`` matrix becomes a
        ``(num_meas * meas_width, num_latent * latent_width)`` mask.
        """
        return biadj.repeat_interleave(meas_width, dim=0).repeat_interleave(
            latent_width, dim=1
        )

    @staticmethod
    def _make_hidden_block_mask(num_meas, width_per_meas):
        """Block-diagonal mask of ``num_meas`` all-ones ``width x width`` blocks."""
        block = torch.ones(width_per_meas, width_per_meas)
        blocks = [block for _ in range(num_meas)]
        return torch.block_diag(*blocks)

    @staticmethod
    def _make_output_mask(num_meas, width_per_meas):
        """``(num_meas, num_meas * width)`` mask; row ``i`` selects block ``i``."""
        block = torch.ones(1, width_per_meas)
        blocks = [block for _ in range(num_meas)]
        return torch.block_diag(*blocks)

    def forward(self, z):
        """Decode latent codes ``z`` (last dim ``latent_dim``) to ``num_meas`` outputs."""
        h = self.activation(self.bn_in(self.linear_in(z)))

        for layer, bn in zip(self.hidden_layers, self.hidden_bns):
            h = self.activation(bn(layer(h)))

        # No activation on the output layer.
        x_recon = self.linear_out(h)
        return x_recon

158 

159 

class SparseLinear(nn.Module):
    """Linear layer ``y = x @ (weight * mask).T + bias`` with a fixed weight mask.

    The mask is applied inside ``forward``. Weights whose mask entry is 0
    therefore contribute nothing to the output and receive zero gradient.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        mask: Optional array-like of shape ``(out_features, in_features)``.
            ``None`` means fully connected. It is converted to the layer's
            dtype (and to ``device`` when given) and stored as the ``mask``
            buffer.
        bias: Whether to learn an additive bias.
        device: Device for the weight, bias and mask.
        dtype: Floating dtype for the weight, bias and mask. Defaults to
            ``torch.get_default_dtype()``.

    Raises:
        ValueError: If ``mask`` does not have shape ``(out_features, in_features)``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.in_features = in_features
        self.out_features = out_features

        # The mask must share the weight's device and dtype. Otherwise
        # ``weight * mask`` fails on a device mismatch, or it silently promotes
        # low-precision weights to float32 and then clashes with the bias.
        mask_dtype = dtype if dtype is not None else torch.get_default_dtype()
        if mask is None:
            mask = torch.ones(
                out_features, in_features, device=device, dtype=mask_dtype
            )
        else:
            mask = torch.as_tensor(mask)
            if mask.shape != (out_features, in_features):
                raise ValueError(
                    f"mask must have shape {(out_features, in_features)}, "
                    f"got {tuple(mask.shape)}"
                )
            mask = mask.to(device=device, dtype=mask_dtype)

        # A buffer (not a Parameter): it moves with .to() and is saved in
        # state_dict, but it is never trained.
        self.register_buffer("mask", mask)

        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters.

        The full (unmasked) weight gets an orthogonal init. The bias is drawn
        from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, with ``fan_in`` taken from
        the unmasked weight.
        """
        nn.init.orthogonal_(self.weight)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the masked affine map to ``x`` (last dim ``in_features``)."""
        return F.linear(x, self.weight * self.mask, self.bias)

    def extra_repr(self):
        """Summarise the layer configuration, mirroring ``nn.Linear``."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )