Coverage for /opt/conda/lib/python3.12/site-packages/medil/interv_vae.py: 93%

121 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-25 05:42 +0000

1import math 

2import warnings 

3 

4import torch 

5from torch import nn 

6from torch.nn.parameter import Parameter 

7 

8 

class VariationalAutoencoder(nn.Module):
    """VAE whose decoder routes latents through an intervenable causal layer.

    The encoder produces Gaussian posterior parameters; the decoder maps a
    latent sample (optionally under per-sample interventions) back to
    measurement-space mean and log-covariance.
    """

    def __init__(
        self, num_meas, meas_width, meas_depth, num_latent, latent_width, latent_depth
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_meas, num_latent, latent_width)
        self.decoder = Decoder(
            num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
        )

    def forward(self, x, interv_idx):
        """Encode ``x``, sample a latent, and decode under ``interv_idx``.

        Returns (x_recon, logcov, mu, logvar).
        """
        mu, logvar = self.encoder(x)
        z = self.latent_sample(mu, logvar)
        recon_mean, recon_logcov = self.decoder(z, interv_idx)
        return recon_mean, recon_logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        """Re-parameterization trick: sample in training, posterior mean at eval."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

34 

35 

class Encoder(nn.Module):
    """Two-hidden-layer MLP encoder producing Gaussian posterior parameters.

    Maps a (batch, num_meas) input to the mean and log-variance of the
    approximate posterior over the flattened latent space of size
    ``num_latent * latent_width``.
    """

    def __init__(self, num_meas, num_latent, latent_width):
        super(Encoder, self).__init__()
        num_vae_latent = num_latent * latent_width

        # two measurement-width hidden layers
        self.enc1 = nn.Linear(in_features=num_meas, out_features=num_meas)
        self.enc2 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # separate heads for the posterior mean and log-variance
        self.fc_mu = nn.Linear(in_features=num_meas, out_features=num_vae_latent)
        self.fc_logvar = nn.Linear(in_features=num_meas, out_features=num_vae_latent)

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, num_latent * latent_width)."""
        hidden = nn.functional.gelu(self.enc1(x))
        hidden = nn.functional.gelu(self.enc2(hidden))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

62 

63 

class Decoder(nn.Module):
    """Decoder mapping VAE latents through a latent causal layer to measurements.

    Two parallel stacks (mean and log-covariance) share the same shape:
    per-latent hidden layers (block-diagonal masks keep each latent variable's
    ``latent_width`` units isolated), a causal layer over the latent space
    (intervenable on the mean path), a mixture layer into measurement space,
    and additional measurement-space hidden layers.
    """

    def __init__(
        self, num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
    ):
        super(Decoder, self).__init__()
        num_vae_latent = num_latent * latent_width
        num_vae_meas = num_meas * meas_width
        if meas_depth == 0 and meas_width > 1:
            num_vae_meas = num_meas
            warnings.warn(
                f"Reduced architecture complexity: `meas_width` set to 1 rather than {meas_width} since `meas_depth`={meas_depth}."
            )

        # hidden latent layers
        # NOTE: these must live in nn.ModuleList (not a plain dict) so the
        # sub-layers are registered as submodules; otherwise their parameters
        # are invisible to .parameters()/.to()/state_dict() and never train.
        hidden_block = torch.ones(latent_width, latent_width)
        hidden_blocks = [hidden_block for _ in range(num_latent)]
        hidden_mask = torch.block_diag(*hidden_blocks)
        self.mean_hidden_latent = nn.ModuleList(
            [
                SparseLinear(
                    in_features=num_vae_latent,
                    out_features=num_vae_latent,
                    mask=hidden_mask,
                )
                for _ in range(latent_depth)
            ]
        )
        self.logcov_hidden_latent = nn.ModuleList(
            [
                SparseLinear(
                    in_features=num_vae_latent,
                    out_features=num_vae_latent,
                    mask=hidden_mask,
                )
                for _ in range(latent_depth)
            ]
        )

        # causal layer (only the mean path supports interventions)
        self.mean_causal = Intervenable(
            in_features=num_vae_latent, out_features=num_vae_latent, width=latent_width
        )
        self.logcov_causal = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_latent
        )

        # mixture layer
        self.mean_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )
        self.logcov_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )

        # additional mixture layers: meas_depth - 1 width-preserving hidden
        # layers followed by a final projection back down to num_meas
        # (for meas_depth == 0 only the final projection is created)
        self.mean_hidden_mix = nn.ModuleList(
            [
                SparseLinear(in_features=num_vae_meas, out_features=num_vae_meas)
                for _ in range(meas_depth - 1)
            ]
        )
        self.mean_hidden_mix.append(
            SparseLinear(in_features=num_vae_meas, out_features=num_meas)
        )
        self.logcov_hidden_mix = nn.ModuleList(
            [
                SparseLinear(in_features=num_vae_meas, out_features=num_vae_meas)
                for _ in range(meas_depth - 1)
            ]
        )
        self.logcov_hidden_mix.append(
            SparseLinear(in_features=num_vae_meas, out_features=num_meas)
        )

        self.activation = torch.nn.GELU()

    def forward(self, z, interv_idx):
        """Decode latent ``z`` into measurement-space (mean, logcov).

        :param z: (batch, num_latent * latent_width) latent sample.
        :param interv_idx: per-sample intervention indices forwarded to the
            intervenable causal layer; -1 marks observational samples.
        """
        # hidden layers for latent exogenous variables
        mean = z.clone()
        logcov = z.clone()
        for hidden_layer in self.mean_hidden_latent:
            mean = hidden_layer(mean)
            mean = self.activation(mean)
        for hidden_layer in self.logcov_hidden_latent:
            logcov = hidden_layer(logcov)
            logcov = self.activation(logcov)

        # connect exogenous variables to latent causal DAG
        mean = self.mean_causal(mean, interv_idx)
        mean = self.activation(mean)
        logcov = self.logcov_causal(logcov)
        logcov = self.activation(logcov)

        # mix latent causal vars into measurements
        mean = self.mean_mix(mean)
        logcov = self.logcov_mix(logcov)

        # hidden layers for mixture (activation applied before each layer,
        # so the final layer's output is unsquashed)
        for hidden_layer in self.mean_hidden_mix:
            mean = self.activation(mean)
            mean = hidden_layer(mean)
        for hidden_layer in self.logcov_hidden_mix:
            logcov = self.activation(logcov)
            logcov = hidden_layer(logcov)

        return mean, logcov

170 

171 

class SparseLinear(nn.Module):
    """Linear layer with an optional fixed mask on the weight matrix.

    Zero entries in the mask permanently sever the corresponding
    input-to-output connections. The mask is registered as a non-persistent
    buffer so it follows the module across ``.to()``/``.cuda()`` calls — a
    plain tensor attribute would stay on its original device and break the
    masked matmul after moving the model.
    """

    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        """
        :param mask: optional (out_features, in_features) tensor fixing the
            sparsity pattern; None yields a dense layer.
        :param bias: whether to include a learnable bias term.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # non-persistent: the mask is structural, not learned state, so it is
        # excluded from state_dict and checkpoints stay backward compatible
        self.register_buffer("mask", mask, persistent=False)
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # orthogonal init for the weight; bias init mirrors nn.Linear
        nn.init.orthogonal_(self.weight)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer: multiply the mask in at call time so the
        # underlying weight remains fully trainable
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )

217 

218 

class Intervenable(SparseLinear):
    """Linear layer supporting per-sample hard interventions.

    For intervened samples, the layer's output at the intervened coordinates
    is overwritten with the corresponding input values, severing the weights'
    influence on those coordinates.
    """

    def __init__(self, width, **kwargs):
        """
        :param width: units per latent variable; an intervention overwrites a
            block of ``width`` consecutive output coordinates.
        :param kwargs: forwarded to ``SparseLinear``.
        """
        super().__init__(**kwargs)
        self.width = width

    def forward(self, input, interv_idx):
        """
        :param input: (batch, in_features) tensor.
        :param interv_idx: per-sample intervention index; -1 means no
            intervention. NOTE(review): offsets 0..width-1 are added without
            scaling by ``width``, so these indices appear to address columns
            of the widened latent space directly — confirm against callers.
        """
        interv_idx = interv_idx.squeeze().to(int)  # reformat
        observed = nn.functional.linear(input, self.weight, self.bias)

        # interv idx of -1 indicates no intervention, so remove those
        sample_idx = torch.arange(len(input), device=input.device)
        obs_mask = interv_idx == -1
        sample_idx = sample_idx[~obs_mask]
        interv_idx = interv_idx[~obs_mask]

        # expand idcs according to width of layer in VAE: each intervened
        # sample must pair with every offset 0..width-1, so the sample (and
        # base column) indices repeat each entry `width` times via
        # repeat_interleave while the offsets cycle via tile. (Tiling the
        # sample indices as well mis-pairs samples with offsets whenever
        # gcd(num_intervened, width) > 1.)
        expander = torch.arange(self.width, device=input.device).tile(len(sample_idx))
        sample_idx = sample_idx.repeat_interleave(self.width)
        interv_idx = interv_idx.repeat_interleave(self.width) + expander

        # perform intervention on a copy, leaving the linear output untouched
        intervened = observed.clone()
        intervened[sample_idx, interv_idx] = input[sample_idx, interv_idx]
        return intervened