Coverage for /opt/conda/lib/python3.12/site-packages/medil/interv_vae2.py: 92%
123 statements
import math
import warnings

import torch
from torch import nn
from torch.nn.parameter import Parameter


class VariationalAutoencoder(nn.Module):
    def __init__(
        self, num_meas, meas_width, meas_depth, num_latent, latent_width, latent_depth
    ):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(num_meas, num_latent, latent_width)
        self.decoder = Decoder(
            num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
        )

    def forward(self, x, interv_idx):
        mu, logvar = self.encoder(x)
        latent = self.latent_sample(mu, logvar)
        x_recon, logcov = self.decoder(latent, interv_idx)

        return x_recon, logcov, mu, logvar

    def latent_sample(self, mu, logvar):
        # the re-parameterization trick
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
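

# Illustrative usage sketch (not part of the original module): builds a small
# model and runs one forward pass so the shapes are concrete. All argument
# values below are assumptions chosen only for the example.
def _example_vae_forward():
    model = VariationalAutoencoder(
        num_meas=6, meas_width=2, meas_depth=1,
        num_latent=3, latent_width=2, latent_depth=1,
    )
    x = torch.randn(8, 6)  # batch of 8 observations with 6 measurements each
    interv_idx = torch.full((8,), -1)  # -1 denotes the observational regime
    x_recon, logcov, mu, logvar = model(x, interv_idx)
    # mu and logvar each have num_latent * latent_width = 6 columns;
    # x_recon and logcov are projected back to num_meas = 6 columns
    return x_recon.shape, mu.shape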


class Encoder(nn.Module):
    def __init__(self, num_meas, num_latent, latent_width):
        super(Encoder, self).__init__()

        # first encoder layer
        self.enc1 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # second encoder layer
        self.enc2 = nn.Linear(in_features=num_meas, out_features=num_meas)

        # map to mu and log-variance
        num_vae_latent = num_latent * latent_width
        self.fc_mu = nn.Linear(in_features=num_meas, out_features=num_vae_latent)
        self.fc_logvar = nn.Linear(in_features=num_meas, out_features=num_vae_latent)

    def forward(self, x):
        activation = torch.nn.GELU()
        # encoder layers
        x = activation(self.enc1(x))
        x = activation(self.enc2(x))

        # calculate mu & logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        return mu, logvar
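

# Illustrative sketch (not part of the original module): the encoder maps each
# observation of `num_meas` measurements to `num_latent * latent_width`
# posterior means and log-variances. The sizes below are assumptions.
def _example_encoder_shapes():
    enc = Encoder(num_meas=6, num_latent=3, latent_width=2)
    mu, logvar = enc(torch.randn(5, 6))
    assert mu.shape == logvar.shape == (5, 6)  # 3 latents * width 2 = 6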


class Decoder(nn.Module):
    def __init__(
        self, num_latent, latent_width, latent_depth, num_meas, meas_width, meas_depth
    ):
        super(Decoder, self).__init__()
        num_vae_latent = num_latent * latent_width
        num_vae_meas = num_meas * meas_width
        if meas_depth == 0 and meas_width > 1:
            num_vae_meas = num_meas
            warnings.warn(
                f"Reduced architecture complexity: `meas_width` effectively set to 1 rather than {meas_width} since `meas_depth`={meas_depth}."
            )

        # hidden latent layers; nn.ModuleDict (string keys) registers the
        # layers as submodules so their parameters are trained and moved
        # together with the model
        hidden_block = torch.ones(latent_width, latent_width)
        hidden_blocks = [hidden_block for _ in range(num_latent)]
        hidden_mask = torch.block_diag(*hidden_blocks)
        self.mean_hidden_latent = nn.ModuleDict(
            {
                str(layer_idx): SparseLinear(
                    in_features=num_vae_latent,
                    out_features=num_vae_latent,
                    mask=hidden_mask,
                )
                for layer_idx in range(latent_depth)
            }
        )
        self.logcov_hidden_latent = nn.ModuleDict(
            {
                str(layer_idx): SparseLinear(
                    in_features=num_vae_latent,
                    out_features=num_vae_latent,
                    mask=hidden_mask,
                )
                for layer_idx in range(latent_depth)
            }
        )

        # causal layer (one Intervenable per regime; key "-1" is observational)
        self.mean_causal = nn.ModuleDict(
            {
                str(interv_idx): Intervenable(
                    in_features=num_vae_latent,
                    out_features=num_vae_latent,
                    width=latent_width,
                )
                for interv_idx in range(-1, num_latent)
            }
        )
        self.logcov_causal = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_latent
        )

        # mixture layer
        self.mean_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )
        self.logcov_mix = SparseLinear(
            in_features=num_vae_latent, out_features=num_vae_meas
        )

        # additional mixture layers; the final layer projects back down to
        # `num_meas` outputs
        self.mean_hidden_mix = nn.ModuleDict(
            {
                str(layer_idx): SparseLinear(
                    in_features=num_vae_meas,
                    out_features=num_vae_meas,
                )
                for layer_idx in range(meas_depth - 1)
            }
        )
        self.mean_hidden_mix[str(meas_depth)] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )
        self.logcov_hidden_mix = nn.ModuleDict(
            {
                str(layer_idx): SparseLinear(
                    in_features=num_vae_meas,
                    out_features=num_vae_meas,
                )
                for layer_idx in range(meas_depth - 1)
            }
        )
        self.logcov_hidden_mix[str(meas_depth)] = SparseLinear(
            in_features=num_vae_meas,
            out_features=num_meas,
        )

        self.activation = torch.nn.GELU()

    def forward(self, z, interv_idx):
        # hidden layers for latent exogenous variables
        mean = z.clone()
        logcov = z.clone()
        for hidden_layer in self.mean_hidden_latent.values():
            mean = hidden_layer(mean)
            mean = self.activation(mean)
        for hidden_layer in self.logcov_hidden_latent.values():
            logcov = hidden_layer(logcov)
            logcov = self.activation(logcov)

        # connect exogenous variables to latent causal DAG
        interv_idx = torch.unique(interv_idx).to(int)
        assert len(interv_idx) == 1
        obs_weight = self.mean_causal["-1"].weight
        mean = self.mean_causal[str(int(interv_idx))](mean, obs_weight, interv_idx)
        mean = self.activation(mean)
        logcov = self.logcov_causal(logcov)
        logcov = self.activation(logcov)

        # mix latent causal vars into measurements
        mean = self.mean_mix(mean)
        logcov = self.logcov_mix(logcov)

        # hidden layers for mixture
        for hidden_layer in self.mean_hidden_mix.values():
            mean = self.activation(mean)
            mean = hidden_layer(mean)
        for hidden_layer in self.logcov_hidden_mix.values():
            logcov = self.activation(logcov)
            logcov = hidden_layer(logcov)

        return mean, logcov
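

# Illustrative sketch (not part of the original module): the hidden latent
# layers use a block-diagonal mask, so each latent variable's `latent_width`
# units only connect to themselves. Shown here for 3 latents of width 2
# (the sizes are assumptions chosen for illustration).
def _example_hidden_mask():
    latent_width, num_latent = 2, 3
    hidden_mask = torch.block_diag(
        *[torch.ones(latent_width, latent_width) for _ in range(num_latent)]
    )
    # hidden_mask is 6x6 with three 2x2 blocks of ones on the diagonal and
    # zeros elsewhere; SparseLinear multiplies its weight by this mask
    return hidden_mask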


class SparseLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        mask=None,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mask = mask
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.orthogonal_(self.weight)
        # nn.init.sparse_(self.weight, 2 / 3)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # masked linear layer
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
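

# Illustrative sketch (not part of the original module): with a mask, the
# effective weight is `weight * mask`, so zero entries in the mask remove the
# corresponding input-output connections. The mask below is an assumption.
def _example_sparse_linear():
    mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # keep only diagonal connections
    layer = SparseLinear(in_features=2, out_features=2, mask=mask)
    out = layer(torch.randn(4, 2))
    # out[:, 0] depends only on input[:, 0]; out[:, 1] only on input[:, 1]
    return out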


class Intervenable(SparseLinear):
    def __init__(self, width, **kwargs):
        super().__init__(**kwargs)
        self.width = width

    def forward(self, input, obs_weight, interv_idx):
        if interv_idx != -1:
            min_weight = torch.minimum(self.weight, obs_weight)
            num_vars = len(self.weight) // self.width
            interv_mask = torch.ones(num_vars, num_vars)
            interv_mask[interv_idx] = 0
            interv_mask[interv_idx, interv_idx] = 1
            interv_mask = interv_mask.kron(torch.ones(self.width, self.width))
            self.weight.data = min_weight * interv_mask
        if self.mask is None:
            return nn.functional.linear(input, self.weight, self.bias)
        else:
            return nn.functional.linear(input, self.weight * self.mask, self.bias)
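

# Illustrative sketch (not part of the original module): intervening on
# variable `interv_idx` zeroes the weights flowing from all other variables
# into it while keeping its own diagonal block, mimicking a do-intervention
# that cuts incoming causal edges. The sizes below are assumptions.
def _example_intervention_mask():
    num_vars, width, interv_idx = 3, 2, 1
    interv_mask = torch.ones(num_vars, num_vars)
    interv_mask[interv_idx] = 0  # cut all parents of variable 1
    interv_mask[interv_idx, interv_idx] = 1  # keep its own block
    interv_mask = interv_mask.kron(torch.ones(width, width))  # expand to unit level
    return interv_mask  # 6x6 mask applied to the causal-layer weights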