Coverage for /opt/conda/lib/python3.12/site-packages/medil/vae.py: 98%

97 statements

coverage.py v7.8.0, created at 2025-04-25 05:42 +0000

import math

import torch
from torch import nn
from torch.nn.parameter import Parameter


class VariationalAutoencoder(nn.Module):
    """VAE whose decoder can be constrained by a prior biadjacency matrix."""

    def __init__(
        self,
        num_vae_latent,
        num_meas,
        num_hidden_layers,
        width_per_meas,
        prior_biadj=None,
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_vae_latent, num_meas)
        self.decoder = Decoder(
            num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj
        )

    def forward(self, x):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent)

        return x_recon, logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        # the re-parameterization trick: z = mu + std * eps with eps ~ N(0, I)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu


class Block(nn.Module):
    """Base class holding the shared input, latent, hidden, and output dimensions."""

    def __init__(self, num_vae_latent, num_meas, width_per_meas=1):
        super(Block, self).__init__()
        self.input_dim = num_meas
        self.latent_dim = num_vae_latent
        self.hidden_dim = num_meas * width_per_meas
        self.output_dim = num_meas


class Encoder(Block):
    """Fully connected encoder producing the posterior mean and log-variance."""

    def __init__(self, num_vae_latent, num_meas):
        super(Encoder, self).__init__(num_vae_latent, num_meas)

        # first encoder layer
        self.inter_dim = self.input_dim
        self.enc1 = nn.Linear(in_features=self.input_dim, out_features=self.inter_dim)

        # second encoder layer
        self.enc2 = nn.Linear(in_features=self.inter_dim, out_features=self.inter_dim)

        # batch normalization shared by both encoder layers
        self.norm = nn.BatchNorm1d(self.inter_dim)

        # map to mu and log-variance
        self.fc_mu = nn.Linear(in_features=self.inter_dim, out_features=self.latent_dim)
        self.fc_logvar = nn.Linear(
            in_features=self.inter_dim, out_features=self.latent_dim
        )

    def forward(self, x):
        # encoder layers
        inter = torch.relu(self.norm(self.enc1(x)))
        inter = torch.relu(self.norm(self.enc2(inter)))

        # calculate mu & logvar
        mu = self.fc_mu(inter)
        logvar = self.fc_logvar(inter)

        return mu, logvar


class Decoder(Block):
    """Masked decoder producing per-measurement means and log-covariances."""

    def __init__(
        self, num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj
    ):
        super(Decoder, self).__init__(num_vae_latent, num_meas, width_per_meas)

        # # decoder layer -- estimate mean
        # self.dec_mean = SparseLinear(
        #     in_features=self.latent_dim, out_features=self.output_dim
        # )

        # # decoder layer -- estimate log-covariance
        # self.fc_logcov = SparseLinear(
        #     in_features=self.latent_dim, out_features=self.output_dim
        # )

        # new arch: input layers from the latent space; the mean path is masked by
        # the prior biadjacency matrix (if given), the covariance path is dense
        self.mean_linear_fulcon = SparseLinear(
            in_features=self.latent_dim, out_features=self.hidden_dim, mask=prior_biadj
        )
        self.cov_linear_fulcon = SparseLinear(
            in_features=self.latent_dim, out_features=self.hidden_dim
        )

        # block-diagonal mask: each measurement keeps its own width_per_meas hidden units
        hidden_block = torch.ones(width_per_meas, width_per_meas)
        hidden_blocks = [hidden_block for _ in range(num_meas)]
        hidden_mask = torch.block_diag(*hidden_blocks)

        self.mean_linear_hidden = nn.ModuleList(
            [
                SparseLinear(
                    in_features=self.hidden_dim,
                    out_features=self.hidden_dim,
                    mask=hidden_mask,
                )
                for _ in range(num_hidden_layers)
            ]
        )
        self.cov_linear_hidden = nn.ModuleList(
            [
                SparseLinear(
                    in_features=self.hidden_dim,
                    out_features=self.hidden_dim,
                    mask=hidden_mask,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        # output mask: each measurement reads only from its own hidden units
        output_block = torch.ones(1, width_per_meas)
        output_blocks = [output_block for _ in range(num_meas)]
        output_mask = torch.block_diag(*output_blocks)
        self.mean_linear_output = SparseLinear(
            in_features=self.hidden_dim, out_features=self.output_dim, mask=output_mask
        )
        self.cov_linear_output = SparseLinear(
            in_features=self.hidden_dim, out_features=self.output_dim, mask=output_mask
        )

        self.norm = nn.BatchNorm1d(self.hidden_dim)
        self.activation = torch.nn.GELU()

    def forward(self, z):
        # linear layer
        # mean = self.dec_mean(z)
        # logcov = self.fc_logcov(z)

        # new arch: mean path
        mean = self.mean_linear_fulcon(z)
        mean = self.activation(self.norm(mean))
        for hidden_layer in self.mean_linear_hidden:
            mean = hidden_layer(self.norm(mean))
            mean = self.activation(mean)
        mean = self.mean_linear_output(mean)

        # log-covariance path
        logcov = self.cov_linear_fulcon(z)
        logcov = self.activation(self.norm(logcov))
        for hidden_layer in self.cov_linear_hidden:
            logcov = hidden_layer(logcov)
            logcov = self.activation(self.norm(logcov))
        logcov = self.cov_linear_output(logcov)

        return mean, logcov

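
# Illustration of the block-diagonal masks built in Decoder.__init__, for a small
# hypothetical configuration with num_meas = 2 and width_per_meas = 3 (values chosen
# only for this sketch). Each measurement keeps its own 3 hidden units; off-diagonal
# blocks are zero, so nothing mixes across measurements after the fulcon layer.
#
#     hidden_mask = torch.block_diag(torch.ones(3, 3), torch.ones(3, 3))  # shape (6, 6)
#     output_mask = torch.block_diag(torch.ones(1, 3), torch.ones(1, 3))  # shape (2, 6)
#
#     hidden_mask              output_mask
#     [[1 1 1 0 0 0]           [[1 1 1 0 0 0]
#      [1 1 1 0 0 0]            [0 0 0 1 1 1]]
#      [1 1 1 0 0 0]
#      [0 0 0 1 1 1]
#      [0 0 0 1 1 1]
#      [0 0 0 1 1 1]]
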

class SparseLinear(nn.Module):
    """Linear layer whose weight is elementwise-multiplied by a fixed mask."""

    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if mask is None:
            # a scalar all-ones mask broadcasts to an ordinary dense layer
            self.mask = torch.ones(1)
        else:
            self.mask = mask
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.orthogonal_(self.weight)
        # nn.init.sparse_(self.weight, 2 / 3)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer: zeroed mask entries remove the corresponding weights
        return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
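

# Minimal usage sketch under assumed dimensions (the sizes and the all-ones prior
# mask below are hypothetical, chosen only so the shapes line up): build the VAE,
# run one forward pass on random data, and check the output shapes. The prior mask
# must be broadcastable against the first masked layer's weight, which has shape
# (num_meas * width_per_meas, num_vae_latent).
if __name__ == "__main__":
    num_vae_latent, num_meas, width_per_meas = 3, 5, 2
    prior_biadj = torch.ones(num_meas * width_per_meas, num_vae_latent)  # stand-in mask

    vae = VariationalAutoencoder(
        num_vae_latent=num_vae_latent,
        num_meas=num_meas,
        num_hidden_layers=2,
        width_per_meas=width_per_meas,
        prior_biadj=prior_biadj,
    )
    vae.train()

    x = torch.randn(16, num_meas)  # batch of 16 observations
    x_recon, logcov, mu, logvar = vae(x)
    print(x_recon.shape, logcov.shape, mu.shape, logvar.shape)
    # expected: torch.Size([16, 5]) torch.Size([16, 5]) torch.Size([16, 3]) torch.Size([16, 3])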