Coverage for /opt/conda/lib/python3.12/site-packages/medil/interv_vae.py: 93%
121 statements
coverage.py v7.8.0, created at 2025-04-25 05:42 +0000
import math
import warnings

import torch
from torch import nn
from torch.nn.parameter import Parameter

class VariationalAutoencoder(nn.Module):
    def __init__(
        self, num_meas, meas_width, meas_depth, num_latent, latent_width, latent_depth
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_meas, num_latent, latent_width)
        self.decoder = Decoder(
            num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
        )

    def forward(self, x, interv_idx):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent, interv_idx)
        return x_recon, logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        # the re-parameterization trick: z = mu + exp(logvar / 2) * eps, eps ~ N(0, I)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
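
# Illustrative usage sketch, not part of the original module: builds a small model and
# runs one forward pass. All sizes below are assumptions chosen for demonstration, and
# the encoding of `interv_idx` (-1 for "no intervention", otherwise the first coordinate
# of the intervened latent block) is inferred from how `Intervenable.forward` consumes it.
def _example_vae_forward():
    torch.manual_seed(0)
    vae = VariationalAutoencoder(
        num_meas=5, meas_width=2, meas_depth=1,
        num_latent=3, latent_width=2, latent_depth=1,
    )
    x = torch.randn(8, 5)                # batch of 8 samples with 5 measurements each
    interv_idx = torch.full((8,), -1)    # -1: purely observational sample
    interv_idx[0] = 0                    # intervene on the first latent block of sample 0
    x_recon, logcov, mu, logvar = vae(x, interv_idx)
    # reconstruction lives in measurement space, mu/logvar in the expanded latent space
    assert x_recon.shape == (8, 5) and logcov.shape == (8, 5)
    assert mu.shape == (8, 6) and logvar.shape == (8, 6)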

class Encoder(nn.Module):
    def __init__(self, num_meas, num_latent, latent_width):
        super(Encoder, self).__init__()

        # first encoder layer
        self.enc1 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # second encoder layer
        self.enc2 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # map to mu and log-variance
        num_vae_latent = num_latent * latent_width
        self.fc_mu = nn.Linear(in_features=num_meas, out_features=num_vae_latent)
        self.fc_logvar = nn.Linear(in_features=num_meas, out_features=num_vae_latent)

    def forward(self, x):
        activation = torch.nn.GELU()
        # encoder layers
        x = activation(self.enc1(x))
        x = activation(self.enc2(x))

        # calculate mu & logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        return mu, logvar
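
# Illustrative check, not part of the original module: the encoder maps a batch of
# `num_meas`-dimensional measurements to `num_latent * latent_width`-dimensional mu and
# logvar vectors. The sizes below are assumptions chosen for demonstration.
def _example_encoder_shapes():
    enc = Encoder(num_meas=5, num_latent=3, latent_width=2)
    mu, logvar = enc(torch.randn(4, 5))
    assert mu.shape == (4, 6) and logvar.shape == (4, 6)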

class Decoder(nn.Module):
    def __init__(
        self, num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
    ):
        super(Decoder, self).__init__()
        num_vae_latent = num_latent * latent_width
        num_vae_meas = num_meas * meas_width
        if meas_depth == 0 and meas_width > 1:
            num_vae_meas = num_meas
            warnings.warn(
                f"Reduced architecture complexity: `meas_width` set to 1 rather than {meas_width} since `meas_depth`={meas_depth}."
            )

        # hidden latent layers (block-diagonal mask: each latent's units connect only to themselves)
        hidden_block = torch.ones(latent_width, latent_width)
        hidden_blocks = [hidden_block for _ in range(num_latent)]
        hidden_mask = torch.block_diag(*hidden_blocks)
        self.mean_hidden_latent = {
            layer_idx: SparseLinear(
                in_features=num_vae_latent,
                out_features=num_vae_latent,
                mask=hidden_mask,
            )
            for layer_idx in range(latent_depth)
        }
        self.logcov_hidden_latent = {
            layer_idx: SparseLinear(
                in_features=num_vae_latent,
                out_features=num_vae_latent,
                mask=hidden_mask,
            )
            for layer_idx in range(latent_depth)
        }

        # causal layer
        self.mean_causal = Intervenable(
            in_features=num_vae_latent, out_features=num_vae_latent, width=latent_width
        )
        self.logcov_causal = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_latent
        )

        # mixture layer
        self.mean_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )
        self.logcov_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )

        # additional mixture layers
        self.mean_hidden_mix = {
            layer_idx: SparseLinear(
                in_features=num_vae_meas,
                out_features=num_vae_meas,
            )
            for layer_idx in range(meas_depth - 1)
        }
        self.mean_hidden_mix[meas_depth] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )
        self.logcov_hidden_mix = {
            layer_idx: SparseLinear(
                in_features=num_vae_meas,
                out_features=num_vae_meas,
            )
            for layer_idx in range(meas_depth - 1)
        }
        self.logcov_hidden_mix[meas_depth] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )

        self.activation = torch.nn.GELU()

    def forward(self, z, interv_idx):
        # hidden layers for latent exogenous variables
        mean = z.clone()
        logcov = z.clone()
        for hidden_layer in self.mean_hidden_latent.values():
            mean = hidden_layer(mean)
            mean = self.activation(mean)
        for hidden_layer in self.logcov_hidden_latent.values():
            logcov = hidden_layer(logcov)
            logcov = self.activation(logcov)

        # connect exogenous variables to latent causal DAG
        mean = self.mean_causal(mean, interv_idx)
        mean = self.activation(mean)
        logcov = self.logcov_causal(logcov)
        logcov = self.activation(logcov)

        # mix latent causal vars into measurements
        mean = self.mean_mix(mean)
        logcov = self.logcov_mix(logcov)

        # hidden layers for mixture
        for hidden_layer in self.mean_hidden_mix.values():
            mean = self.activation(mean)
            mean = hidden_layer(mean)
        for hidden_layer in self.logcov_hidden_mix.values():
            logcov = self.activation(logcov)
            logcov = hidden_layer(logcov)

        return mean, logcov
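
# Illustrative sketch, not part of the original module: the decoder maps exogenous latent
# samples through block-diagonally masked hidden layers, a causal layer (intervenable on
# the mean path), and mixing layers down to measurement space. The sizes below are
# assumptions chosen for demonstration.
def _example_decoder_forward():
    dec = Decoder(
        num_latent=3, latent_width=2, latent_depth=1,
        num_meas=5, meas_width=2, meas_depth=1,
    )
    z = torch.randn(4, 3 * 2)            # one sample per row in the expanded latent space
    interv_idx = torch.full((4,), -1)    # no interventions
    mean, logcov = dec(z, interv_idx)
    assert mean.shape == (4, 5) and logcov.shape == (4, 5)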

class SparseLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mask = mask
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.orthogonal_(self.weight)
        # nn.init.sparse_(self.weight, 2 / 3)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
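
# Illustrative check, not part of the original module: the mask zeroes selected weights
# in the forward pass, so masking every weight attached to one input feature makes the
# output independent of that feature. The sizes below are assumptions for demonstration.
def _example_sparse_linear_mask():
    mask = torch.ones(3, 3)
    mask[:, 2] = 0.0                     # cut every connection from input feature 2
    layer = SparseLinear(in_features=3, out_features=3, mask=mask)
    x = torch.randn(1, 3)
    x_perturbed = x.clone()
    x_perturbed[0, 2] += 10.0            # change only the masked-out feature
    assert torch.allclose(layer(x), layer(x_perturbed))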

class Intervenable(SparseLinear):
    def __init__(self, width, **kwargs):
        super().__init__(**kwargs)
        self.width = width

    def forward(self, input, interv_idx):
        interv_idx = interv_idx.squeeze().to(int)  # flatten and cast to integer indices
        observed = nn.functional.linear(input, self.weight, self.bias)

        # an interv_idx of -1 indicates no intervention, so remove those samples
        sample_idx = torch.arange(len(input))
        obs_mask = interv_idx == -1
        sample_idx = sample_idx[~obs_mask]
        interv_idx = interv_idx[~obs_mask]

        # expand indices according to the width of the layer in the VAE
        expander = torch.arange(self.width).tile(len(sample_idx))
        sample_idx = sample_idx.tile(self.width)
        interv_idx = interv_idx.tile(self.width) + expander

        # perform intervention: copy intervened coordinates straight from the input
        intervened = observed
        intervened[sample_idx, interv_idx] = input[sample_idx, interv_idx]
        return intervened
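
# Illustrative check, not part of the original module: for a sample with interv_idx >= 0,
# the intervened coordinate(s) are copied straight from the layer's input, bypassing the
# linear map, while interv_idx == -1 leaves the sample fully observational. The sizes
# below are assumptions chosen for demonstration.
def _example_intervenable():
    layer = Intervenable(in_features=4, out_features=4, width=1)
    z = torch.randn(2, 4)
    interv_idx = torch.tensor([-1, 2])   # sample 0: observational; sample 1: intervene on coordinate 2
    out = layer(z, interv_idx)
    assert torch.equal(out[1, 2], z[1, 2])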