Coverage for /opt/conda/lib/python3.12/site-packages/medil/interv_vae2.py: 92%

123 statements  

coverage.py v7.8.0, created at 2025-04-25 05:42 +0000

import math
import warnings

import torch
from torch import nn
from torch.nn.parameter import Parameter


class VariationalAutoencoder(nn.Module):
    def __init__(
        self, num_meas, meas_width, meas_depth, num_latent, latent_width, latent_depth
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_meas, num_latent, latent_width)
        self.decoder = Decoder(
            num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
        )

    def forward(self, x, interv_idx):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent, interv_idx)

        return x_recon, logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        # the re-parameterization trick
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

class Encoder(nn.Module):
    def __init__(self, num_meas, num_latent, latent_width):
        super(Encoder, self).__init__()

        # first encoder layer
        self.enc1 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # second encoder layer
        self.enc2 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # map to mu and variance
        num_vae_latent = num_latent * latent_width
        self.fc_mu = nn.Linear(in_features=num_meas, out_features=num_vae_latent)
        self.fc_logvar = nn.Linear(in_features=num_meas, out_features=num_vae_latent)

    def forward(self, x):
        activation = torch.nn.GELU()
        # encoder layers
        x = activation(self.enc1(x))
        x = activation(self.enc2(x))

        # calculate mu & logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        return mu, logvar

class Decoder(nn.Module):
    def __init__(
        self, num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
    ):
        super(Decoder, self).__init__()
        num_vae_latent = num_latent * latent_width
        num_vae_meas = num_meas * meas_width
        if meas_depth == 0 and meas_width > 1:
            num_vae_meas = num_meas
            warnings.warn(
                f"Reduced architecture complexity: `meas_width` set to 1 rather than {meas_width} since `meas_depth`={meas_depth}."
            )

        # hidden latent layers
        hidden_block = torch.ones(latent_width, latent_width)
        hidden_blocks = [hidden_block for _ in range(num_latent)]
        hidden_mask = torch.block_diag(*hidden_blocks)
        self.mean_hidden_latent = {
            layer_idx: SparseLinear(
                in_features=num_vae_latent,
                out_features=num_vae_latent,
                mask=hidden_mask,
            )
            for layer_idx in range(latent_depth)
        }
        self.logcov_hidden_latent = {
            layer_idx: SparseLinear(
                in_features=num_vae_latent,
                out_features=num_vae_latent,
                mask=hidden_mask,
            )
            for layer_idx in range(latent_depth)
        }

        # causal layer
        self.mean_causal = {
            interv_idx: Intervenable(
                in_features=num_vae_latent,
                out_features=num_vae_latent,
                width=latent_width,
            )
            for interv_idx in range(-1, num_latent)
        }

        self.logcov_causal = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_latent
        )

        # mixture layer
        self.mean_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )
        self.logcov_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )

        # additional mixture layers
        self.mean_hidden_mix = {
            layer_idx: SparseLinear(
                in_features=num_vae_meas,
                out_features=num_vae_meas,
            )
            for layer_idx in range(meas_depth - 1)
        }
        self.mean_hidden_mix[meas_depth] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )
        self.logcov_hidden_mix = {
            layer_idx: SparseLinear(
                in_features=num_vae_meas,
                out_features=num_vae_meas,
            )
            for layer_idx in range(meas_depth - 1)
        }
        self.logcov_hidden_mix[meas_depth] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )

        self.activation = torch.nn.GELU()

    def forward(self, z, interv_idx):
        # hidden layers for latent exogenous variables
        mean = z.clone()
        logcov = z.clone()
        for hidden_layer in self.mean_hidden_latent.values():
            mean = hidden_layer(mean)
            mean = self.activation(mean)
        for hidden_layer in self.logcov_hidden_latent.values():
            logcov = hidden_layer(logcov)
            logcov = self.activation(logcov)

        # connect exogenous variables to latent causal DAG
        # (every sample in the batch must share the same interventional regime)
        interv_idx = torch.unique(interv_idx).to(int)
        assert len(interv_idx) == 1
        obs_weight = self.mean_causal[-1].weight
        mean = self.mean_causal[int(interv_idx)](mean, obs_weight, interv_idx)
        mean = self.activation(mean)
        logcov = self.logcov_causal(logcov)
        logcov = self.activation(logcov)

        # mix latent causal vars into measurements
        mean = self.mean_mix(mean)
        logcov = self.logcov_mix(logcov)

        # hidden layers for mixture
        for hidden_layer in self.mean_hidden_mix.values():
            mean = self.activation(mean)
            mean = hidden_layer(mean)
        for hidden_layer in self.logcov_hidden_mix.values():
            logcov = self.activation(logcov)
            logcov = hidden_layer(logcov)

        return mean, logcov

class SparseLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mask = mask
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.orthogonal_(self.weight)
        # nn.init.sparse_(self.weight, 2 / 3)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )

class Intervenable(SparseLinear):
    def __init__(self, width, **kwargs):
        super().__init__(**kwargs)
        self.width = width

    def forward(self, input, obs_weight, interv_idx):
        if interv_idx != -1:
            # take the elementwise minimum with the observational weights, then
            # zero out all incoming edges to the intervened latent variable,
            # keeping only its own diagonal block
            min_weight = torch.minimum(self.weight, obs_weight)
            num_vars = len(self.weight) // self.width
            interv_mask = torch.ones(num_vars, num_vars)
            interv_mask[interv_idx] = 0
            interv_mask[interv_idx, interv_idx] = 1
            # expand the variable-level mask to width-sized blocks
            interv_mask = interv_mask.kron(torch.ones(self.width, self.width))
            self.weight.data = min_weight * interv_mask
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)
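
For orientation, here is a minimal usage sketch; it is not part of interv_vae2.py, and the batch size and hyperparameter values are illustrative assumptions rather than values taken from the package. It runs one forward pass under the observational regime (interv_idx of -1).

# Illustrative usage sketch (assumed hyperparameters; not part of the module).
import torch
from medil.interv_vae2 import VariationalAutoencoder

num_meas, num_latent = 6, 3
vae = VariationalAutoencoder(
    num_meas=num_meas,
    meas_width=2,
    meas_depth=2,
    num_latent=num_latent,
    latent_width=2,
    latent_depth=1,
)

x = torch.randn(32, num_meas)  # batch of 32 measurement vectors
interv_idx = torch.full((32,), -1, dtype=torch.long)  # -1 marks the observational regime
x_recon, logcov, mu, logvar = vae(x, interv_idx)

A nonnegative index between 0 and num_latent - 1, constant across the batch, instead routes the mean path through the corresponding Intervenable layer in the decoder.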