Coverage for /opt/conda/lib/python3.13/site-packages/medil/models.py: 89%
225 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-01 15:11 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-01 15:11 +0000
1"""MeDIL causal model base class and a preconfigured NCFA class."""
3import copy
4import os
5import pickle
6import random
7import warnings
8from datetime import datetime
9from pathlib import Path
11import numpy as np
12import numpy.typing as npt
13import torch
14import torch.nn.functional as F
15from numpy.random import default_rng
16from scipy.optimize import minimize
17from sklearn.model_selection import train_test_split
18from torch.utils.data import DataLoader, TensorDataset
19from tqdm import tqdm
21from .ecc_algorithms import find_heuristic_1pc
22from .independence_testing import estimate_UDG
23from .vae import VariationalAutoencoder
class MedilCausalModel(object):
    """Base class establishing the common interface shared by the
    derived parametric estimators.

    Subclasses override :meth:`fit` and :meth:`sample`; the versions
    defined here only raise ``NotImplementedError``.
    """

    def __init__(
        self,
        biadj: npt.NDArray = np.array([]),
        udg: npt.NDArray = np.array([]),
        one_pure_child: bool = True,
        rng=None,
    ) -> None:
        """Store the structural inputs and the random generator.

        Parameters
        ----------
        biadj
            Biadjacency matrix of the model. An empty array (the
            default) is what subclasses check for before estimating it.
        udg
            Undirected dependence graph. An empty array (the default)
            is what subclasses check for before estimating it.
        one_pure_child
            Stored as-is on the instance.
        rng
            Random generator used for sampling. When ``None`` a fresh
            ``default_rng(0)`` is created for this instance.
        """
        self.biadj = biadj
        self.udg = udg
        self.one_pure_child = one_pure_child
        # A generator built in the signature would be created once and then
        # shared (state included) by every instance relying on the default,
        # so it is created per instance here instead.
        self.rng = default_rng(0) if rng is None else rng

    def fit(self, dataset: npt.NDArray) -> "MedilCausalModel":
        """Fit the model to ``dataset``; must be implemented by subclasses."""
        raise NotImplementedError

    def sample(self, sample_size: int) -> npt.NDArray:
        """Draw ``sample_size`` samples; must be implemented by subclasses."""
        raise NotImplementedError
class Parameters(object):
    """Container for the parameters of one MeDIL causal model
    parameterization.

    The attributes created depend on ``parameterization``:

    * ``"Gaussian"``: ``error_means``, ``error_variances`` and
      ``biadj_weights``, each starting as an empty array.
    * ``"VAE"``: ``weights`` (empty array) and ``vae`` (``None``).

    Any other value only records the ``parameterization`` name.
    """

    def __init__(self, parameterization: str) -> None:
        self.parameterization = parameterization

        if parameterization == "VAE":
            self.weights = np.array([])
            self.vae = None
        elif parameterization == "Gaussian":
            self.error_means = np.array([])
            self.error_variances = np.array([])
            self.biadj_weights = np.array([])

    def __str__(self) -> str:
        """Return one ``parameters.<name>: <value>`` line per attribute."""
        lines = []
        for name, value in self.__dict__.items():
            lines.append(f"parameters.{name}: {value}")
        return "\n".join(lines)
class GaussianMCM(MedilCausalModel):
    """A linear MeDIL causal model with Gaussian random variables."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base constructor and attach an empty
        Gaussian parameter container."""
        super().__init__(**kwargs)
        self.parameters = Parameters("Gaussian")

    def fit(self, dataset: npt.NDArray) -> "GaussianMCM":
        """Fit a Gaussian MCM to a dataset with constraint-based
        structure learning and least squares parameter estimation.

        If no biadjacency matrix was supplied it is estimated first.
        The edge weights and error variances are then chosen to minimise
        the squared difference between the sample covariance and
        ``W.T @ W + diag(error_variances)``, where ``W`` holds the
        weights on the edges of the biadjacency matrix.

        Parameters
        ----------
        dataset
            Samples in rows, variables in columns.

        Returns
        -------
        GaussianMCM
            ``self``, with ``parameters`` filled in.

        Warns
        -----
        UserWarning
            If the optimizer reports failure.
        """
        self.dataset = dataset
        if self.biadj.size == 0:
            self._compute_biadj()

        self.parameters.error_means = self.dataset.mean(0)
        cov = np.cov(self.dataset, rowvar=False)

        # Normalise to a boolean mask: indexing with a 0/1 int or float
        # array would be integer fancy-indexing (or an error), not masking,
        # and the edge count must be an int to be usable as a slice bound.
        mask = np.asarray(self.biadj, dtype=bool)
        num_weights = int(mask.sum())
        num_err_vars = mask.shape[1]

        def _objective(weights_and_err_vars):
            # The flat parameter vector is [edge weights..., error variances...].
            biadj_weights = np.zeros(mask.shape, dtype=float)
            biadj_weights[mask] = weights_and_err_vars[:num_weights]
            err_vars = weights_and_err_vars[num_weights:]

            residual = cov - biadj_weights.T @ biadj_weights - np.diagflat(err_vars)
            return (residual**2).sum()

        result = minimize(_objective, np.ones(num_weights + num_err_vars))
        if not result.success:
            warnings.warn(f"Optimization failed: {result.message}")

        self.parameters.error_variances = result.x[num_weights:]
        self.parameters.biadj_weights = np.zeros(mask.shape, dtype=float)
        self.parameters.biadj_weights[mask] = result.x[:num_weights]

        return self

    def _compute_biadj(self):
        """Derive the biadjacency matrix from the undirected dependence
        graph, estimating that graph first if none was supplied."""
        if self.udg.size == 0:
            self._estimate_udg()
        self.biadj = find_heuristic_1pc(self.udg)

    def _estimate_udg(self):
        """Constraint-based structure learning.

        Marks a pair of variables as dependent when
        ``log(1 - cov * corr)`` falls below ``-log(n) / n``, with ``n``
        the number of samples, and stores the symmetric boolean result
        in ``self.udg``.
        """
        samp_size = len(self.dataset)
        cov = np.cov(self.dataset, rowvar=False)
        corr = np.corrcoef(self.dataset, rowvar=False)
        # NOTE(review): cov * corr equals corr ** 2 only for unit-variance
        # data; confirm whether the dataset is assumed to be standardised.
        inner_numerator = 1 - cov * corr  # should never be <= 0?
        # Keep the logarithm finite if the value is not positive.
        inner_numerator = inner_numerator.clip(min=0.00001)
        # log(1) == 0 never passes the (negative) threshold, so the diagonal
        # and lower triangle are excluded from the test.
        inner_numerator[np.tril_indices_from(inner_numerator)] = 1
        udg_triu = np.log(inner_numerator) < (-np.log(samp_size) / samp_size)
        # Mirror the upper triangle to obtain a symmetric graph.
        self.udg = udg_triu | udg_triu.T

    def sample(self, sample_size: int, include_latent: bool = False) -> npt.NDArray:
        """Sample a dataset from a GaussianMCM, after structure and
        parameters have been specified or estimated.

        Latent variables are drawn as independent standard normals (one
        per row of ``biadj``), mapped through the fitted weights, and
        added to Gaussian errors drawn with the fitted means and
        variances.

        Parameters
        ----------
        sample_size
            Number of samples to draw.
        include_latent
            When true, also return the latent draws.

        Returns
        -------
        The sampled dataset, or the tuple ``(sample, latent_sample)`` when
        ``include_latent`` is true.
        """
        num_latent = len(self.biadj)
        latent_sample = self.rng.multivariate_normal(
            np.zeros(num_latent), np.eye(num_latent), sample_size
        )
        error_sample = self.rng.multivariate_normal(
            self.parameters.error_means,
            np.diagflat(self.parameters.error_variances),
            sample_size,
        )
        sample = latent_sample @ self.parameters.biadj_weights + error_sample

        return (sample, latent_sample) if include_latent else sample
class NeuroCausalFactorAnalysis(MedilCausalModel):
    """A MeDIL causal model represented by a masked variational autoencoder.

    Training settings live in the ``hyperparams`` dict and can be edited
    before calling :meth:`fit`:

    * ``method``, ``alpha``: passed to ``estimate_UDG`` as ``method`` and
      ``significance_level``.
    * ``batch_size``, ``shuffle``: passed to the data loaders.
    * ``num_epochs``, ``lr``: maximum number of epochs and AdamW learning rate.
    * ``beta``: weight of the KL term in the loss.
    * ``latent_width``, ``meas_width``, ``num_hidden_layers``,
      ``encoder_hidden_dim``: forwarded to ``VariationalAutoencoder``.
    * ``early_stopping``, ``patience``, ``min_delta``: stop once the
      validation loss has not improved by more than ``min_delta`` for
      ``patience`` consecutive epochs.
    """

    def __init__(
        self,
        seed: int = 0,
        log_path: str = "",
        verbose: bool = False,
        **kwargs,
    ):
        """Configure logging, seeding and default hyperparameters.

        Parameters
        ----------
        seed
            Seed applied to the Python, NumPy and PyTorch generators at
            the start of :meth:`fit`, and to the train/validation split.
        log_path
            Directory for the training log and saved artifacts; created
            if missing. An empty string disables writing to disk.
        verbose
            When true, log entries are also printed.
        **kwargs
            Forwarded to ``MedilCausalModel.__init__``.
        """
        super().__init__(**kwargs)

        if log_path:
            # parents=True so that a nested log directory can be created too.
            Path(log_path).mkdir(parents=True, exist_ok=True)

        self.log_path = log_path
        self.verbose = verbose
        self.seed = seed

        self.hyperparams = {
            "method": "xicor",
            "alpha": 0.05,
            "batch_size": 128,
            "num_epochs": 200,
            "lr": 1e-3,
            "beta": 1.0,
            "latent_width": 2,
            "meas_width": 2,
            "num_hidden_layers": 1,
            "encoder_hidden_dim": 64,
            "shuffle": True,
            "early_stopping": True,
            "patience": 20,
            "min_delta": 1e-4,
        }

        self.parameters = Parameters("VAE")
        self.loss = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def log(self, entry: str) -> None:
        """Time-stamp ``entry`` and append it to ``training.log`` under
        ``log_path`` and/or print it, depending on configuration."""
        if not (self.log_path or self.verbose):
            return

        time_stamped_entry = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {entry}"

        if self.log_path:
            with open(os.path.join(self.log_path, "training.log"), "a") as log_file:
                log_file.write(time_stamped_entry + "\n")

        if self.verbose:
            print(time_stamped_entry)

    def _set_deterministic_seed(self):
        """Seed every random source used here and request deterministic
        PyTorch behaviour.

        This changes process-wide state: the ``PYTHONHASHSEED``
        environment variable, the ``random``, NumPy and PyTorch global
        generators, and the cuDNN / deterministic-algorithm switches.
        """
        os.environ["PYTHONHASHSEED"] = str(self.seed)
        random.seed(self.seed)
        np.random.seed(self.seed)

        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)

    def fit(self, dataset: npt.NDArray, split_idcs=None) -> "NeuroCausalFactorAnalysis":
        """Learn the structure if needed, then train the masked VAE.

        Parameters
        ----------
        dataset
            Samples in rows, measurement variables in columns.
        split_idcs
            Optional pair ``(train_indices, valid_indices)`` used to
            index ``dataset``. When ``None`` a 70/30 split seeded with
            ``self.seed`` is drawn.

        Returns
        -------
        NeuroCausalFactorAnalysis
            ``self``, with the trained model in ``parameters.vae`` and
            the per-epoch training curves in ``loss``.
        """
        self._set_deterministic_seed()
        self.dataset = dataset

        if self.biadj.size == 0:
            self._compute_biadj()

        if split_idcs is None:
            train_split, valid_split = train_test_split(
                dataset, train_size=0.7, random_state=self.seed
            )
        else:
            train_split = dataset[split_idcs[0]]
            valid_split = dataset[split_idcs[1]]

        train_loader = self._data_loader(train_split)
        valid_loader = self._data_loader(valid_split)

        model_recon, loss_recon, error_recon = self._train_vae(
            train_loader, valid_loader
        )

        if self.log_path:
            # Persist the trained weights and both training curves.
            torch.save(
                model_recon.state_dict(), os.path.join(self.log_path, "model_recon.pt")
            )
            with open(os.path.join(self.log_path, "loss_recon.pkl"), "wb") as handle:
                pickle.dump(loss_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(os.path.join(self.log_path, "error_recon.pkl"), "wb") as handle:
                pickle.dump(error_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)

        self.parameters.vae = model_recon
        self.loss = {
            "elbo_train": loss_recon[0],
            "elbo_valid": loss_recon[1],
            "recon_train": error_recon[0],
            "recon_valid": error_recon[1],
        }

        return self

    def _compute_biadj(self):
        """Derive the biadjacency matrix from the undirected dependence
        graph, estimating that graph first if none was supplied."""
        if self.udg.size == 0:
            self._estimate_udg()
        self.biadj = find_heuristic_1pc(self.udg)

    def _estimate_udg(self):
        """Estimate the undirected dependence graph from ``self.dataset``
        using the ``method`` and ``alpha`` hyperparameters."""
        self.udg, _ = estimate_UDG(
            self.dataset,
            method=self.hyperparams["method"],
            significance_level=self.hyperparams["alpha"],
        )

    def _data_loader(self, sample):
        """Wrap ``sample`` (cast to float32) in a ``DataLoader`` using the
        ``batch_size`` and ``shuffle`` hyperparameters."""
        sample_x = sample.astype(np.float32)
        dataset = TensorDataset(torch.tensor(sample_x))
        return DataLoader(
            dataset,
            batch_size=self.hyperparams["batch_size"],
            shuffle=self.hyperparams["shuffle"],
            num_workers=0,
        )

    def _train_vae(self, train_loader, valid_loader):
        """Train the masked VAE and return the best model with its curves.

        After every epoch the loss and reconstruction error are measured
        on both loaders in evaluation mode. The weights with the lowest
        validation loss are restored before returning.

        Returns
        -------
        tuple
            ``(model, [train_loss, valid_loss], [train_error, valid_error])``
            where each curve is a NumPy array with one entry per epoch run.
        """
        num_meas = self.dataset.shape[1]
        # The decoder mask is the transposed biadjacency matrix.
        biadj = torch.tensor(self.biadj.T, dtype=torch.float32)

        model = VariationalAutoencoder(
            num_latent=biadj.shape[1],
            num_meas=num_meas,
            num_hidden_layers=self.hyperparams["num_hidden_layers"],
            latent_width=self.hyperparams["latent_width"],
            meas_width=self.hyperparams["meas_width"],
            biadj=biadj,
            encoder_hidden_dim=self.hyperparams["encoder_hidden_dim"],
        ).to(self.device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=self.hyperparams["lr"])

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.log(f"Number of parameters: {num_params}")

        train_elbo, train_error = [], []
        valid_elbo, valid_error = [], []

        best_valid = float("inf")
        best_state = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0

        pbar = tqdm(
            range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch"
        )

        for epoch in pbar:
            model.train()

            for (x_batch,) in train_loader:
                x_batch = x_batch.to(self.device)

                x_recon, mu, logvar = model(x_batch)
                loss = self._vae_loss(
                    x_batch, x_recon, mu, logvar, beta=self.hyperparams["beta"]
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The reported metrics come from a separate evaluation pass over
            # each loader, so nothing is accumulated while training.
            train_lb, train_er = self._eval_loss(model, train_loader)
            train_elbo.append(train_lb)
            train_error.append(train_er)

            valid_lb, valid_er = self._eval_loss(model, valid_loader)
            valid_elbo.append(valid_lb)
            valid_error.append(valid_er)

            pbar.set_postfix({"train": train_lb, "valid": valid_lb})

            improved = valid_lb < (best_valid - self.hyperparams["min_delta"])
            if improved:
                best_valid = valid_lb
                best_state = copy.deepcopy(model.state_dict())
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if (
                self.hyperparams["early_stopping"]
                and epochs_without_improvement >= self.hyperparams["patience"]
            ):
                self.log(f"Early stopping at epoch {epoch}")
                break

        # Breaking out of the loop leaves the bar open; close it explicitly.
        pbar.close()

        model.load_state_dict(best_state)

        return (
            model,
            [np.array(train_elbo), np.array(valid_elbo)],
            [np.array(train_error), np.array(valid_error)],
        )

    def _eval_loss(self, model, loader):
        """Evaluate ``model`` on ``loader`` without gradients.

        Returns
        -------
        tuple of float
            ``(loss, reconstruction_error)``: the per-batch values summed
            over the loader and divided by the number of samples seen.

        Raises
        ------
        ValueError
            If ``loader`` yields no samples.
        """
        model.eval()
        total_loss = 0.0
        total_recon = 0.0
        n = 0

        with torch.no_grad():
            for (x_batch,) in loader:
                x_batch = x_batch.to(self.device)
                x_recon, mu, logvar = model(x_batch)
                loss = self._vae_loss(
                    x_batch, x_recon, mu, logvar, beta=self.hyperparams["beta"]
                )
                recon = self._recon_error(x_batch, x_recon)

                total_loss += loss.item()
                total_recon += recon.item()
                n += x_batch.shape[0]

        if n == 0:
            raise ValueError("Cannot evaluate the loss on an empty data loader.")

        return total_loss / n, total_recon / n

    @staticmethod
    def _vae_loss(x, x_recon, mu, logvar, beta=1.0):
        """Return the summed squared reconstruction error plus ``beta``
        times the KL divergence between the Gaussian described by
        ``mu``/``logvar`` and a standard normal."""
        recon_loss = F.mse_loss(x_recon, x, reduction="sum")
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + beta * kl_div

    @staticmethod
    def _recon_error(x, x_recon):
        """Return the order-2 norm of the residual ``x - x_recon``."""
        # NOTE(review): for a 2-D batch, torch.linalg.norm with ord=2 is
        # documented as the matrix 2-norm (largest singular value) rather
        # than the Frobenius norm -- confirm that this is intended.
        return torch.linalg.norm(x - x_recon, ord=2)

    def set_full_decoder_mask(self, num_meas=None):
        """Replace ``biadj`` with an all-ones square matrix, using as many
        latent variables as measurement variables.

        Parameters
        ----------
        num_meas
            Number of measurement variables. When ``None`` it is taken
            from the second dimension of ``self.dataset``.

        Raises
        ------
        ValueError
            If ``num_meas`` is ``None`` and no dataset has been set.
        """
        if num_meas is None:
            if not hasattr(self, "dataset"):
                raise ValueError("Provide num_meas or set dataset first.")
            num_meas = self.dataset.shape[1]

        num_latent = num_meas
        self.biadj = np.ones((num_latent, num_meas), dtype=float)