Coverage for /opt/conda/lib/python3.12/site-packages/medil/models.py: 84% of 740 statements
1"""MeDIL causal model base class and a preconfigured NCFA class."""
3import os
4import pickle
5import warnings
6from datetime import datetime
7from itertools import chain
8from pathlib import Path
9from typing import Iterator, List
11import numpy as np
12import numpy.typing as npt
13import torch
14from numpy.random import default_rng
15from scipy.linalg import norm
16from scipy.optimize import minimize
17from sklearn.model_selection import train_test_split
18from torch import nn
19from torch.nn.functional import lp_pool2d
20from torch.nn.parameter import Parameter
21from torch.utils.data import DataLoader, TensorDataset
22from tqdm import tqdm
24from .ecc_algorithms import find_heuristic_1pc
25from .independence_testing import estimate_UDG
26from .interv_vae import VariationalAutoencoder as InterVAE
27from .interv_vae2 import VariationalAutoencoder as InterVAE2
28from .vae import VariationalAutoencoder
class MedilCausalModel(object):
"""Base class defining the common interface shared by the derived
parametric estimators.
"""
36 def __init__(
37 self,
38 biadj: npt.NDArray = np.array([]),
39 udg: npt.NDArray = np.array([]),
40 one_pure_child: bool = True,
41 rng=default_rng(0),
42 ) -> None:
43 self.biadj = biadj
44 self.udg = udg
45 self.one_pure_child = one_pure_child
46 self.rng = rng
48 def fit(self, dataset: npt.NDArray) -> "MedilCausalModel":
49 raise NotImplementedError
51 def sample(self, sample_size: int) -> npt.NDArray:
52 raise NotImplementedError
class Parameters(object):
"Different parameterizations of MeDIL causal models."
58 def __init__(self, parameterization: str) -> None:
59 self.parameterization = parameterization
61 if parameterization == "Gaussian":
62 self.error_means = np.array([])
63 self.error_variances = np.array([])
64 self.biadj_weights = np.array([])
65 elif parameterization == "VAE":
66 self.weights = np.array([])
67 with warnings.catch_warnings(action="ignore"):
68 self.vae = VariationalAutoencoder(0, 0, 0, 0)
69 elif parameterization == "InterVAE":
70 self.weights = np.array([])
71 with warnings.catch_warnings(action="ignore"):
72 self.vae = InterVAE(0, 0, 0, 0, 0, 0)
73 elif parameterization == "InterVAE2":
74 self.weights = np.array([])
75 with warnings.catch_warnings(action="ignore"):
76 self.vae = InterVAE2(0, 0, 0, 0, 0, 0)
78 def __str__(self) -> str:
79 return "\n".join(
80 f"parameters.{attr}: {val}" for attr, val in vars(self).items()
81 )
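# --- Illustrative sketch (not part of the library API) ---------------------
# Parameters is a plain attribute container; printing an instance lists the
# attributes of the chosen parameterization. A minimal assumed usage:
def _example_parameters_repr():
    params = Parameters("Gaussian")
    # the first line reads "parameters.parameterization: Gaussian"
    return str(params).splitlines()[0]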
84class GaussianMCM(MedilCausalModel):
85 """A linear MeDIL causal model with Gaussian random variables."""
87 def __init__(self, **kwargs):
88 super().__init__(**kwargs)
89 self.parameters = Parameters("Gaussian")
91 def fit(self, dataset: npt.NDArray) -> "GaussianMCM":
92 """Fit a Gaussian MCM to a dataset with constraint-based
93 structure learning and least squares parameter estimation."""
94 self.dataset = dataset
95 if self.biadj.size == 0:
96 self._compute_biadj()
98 self.parameters.error_means = self.dataset.mean(0)
100 cov = np.cov(self.dataset, rowvar=False)
102 num_weights = self.biadj.sum()
103 num_err_vars = self.biadj.shape[1]
105 def _objective(weights_and_err_vars):
106 weights = weights_and_err_vars[:num_weights]
107 err_vars = weights_and_err_vars[num_weights:]
109 biadj_weights = np.zeros_like(self.biadj, float)
110 biadj_weights[self.biadj] = weights
112 return (
113 (cov - biadj_weights.T @ biadj_weights - np.diagflat(err_vars)) ** 2
114 ).sum()
116 result = minimize(_objective, np.ones(num_weights + num_err_vars))
117 if not result.success:
118 warnings.warn(f"Optimization failed: {result.message}")
120 self.parameters.error_variances = result.x[num_weights:]
122 self.parameters.biadj_weights = np.zeros_like(self.biadj, float)
123 self.parameters.biadj_weights[self.biadj] = result.x[:num_weights]
125 return self
def _compute_biadj(self):
"""Recover the biadjacency matrix from the (estimated) UDG."""
129 if self.udg.size == 0:
130 self._estimate_udg()
131 self.biadj = find_heuristic_1pc(self.udg)
133 def _estimate_udg(self):
134 """Constraint-based structure learning."""
135 samp_size = len(self.dataset)
136 cov = np.cov(self.dataset, rowvar=False)
137 corr = np.corrcoef(self.dataset, rowvar=False)
inner_numerator = 1 - cov * corr
# guard against numerically non-positive values before taking the log
inner_numerator = inner_numerator.clip(min=0.00001)
140 inner_numerator[np.tril_indices_from(inner_numerator)] = 1
141 udg_triu = np.log(inner_numerator) < (-np.log(samp_size) / samp_size)
142 udg = udg_triu + udg_triu.T
143 self.udg = udg
145 def sample(self, sample_size: int) -> npt.NDArray:
146 """Sample a dataset from a GaussianMCM, after structure and
147 parameters have been specified or estimated."""
148 num_latent = len(self.biadj)
149 latent_sample = self.rng.multivariate_normal(
150 np.zeros(num_latent), np.eye(num_latent), sample_size
151 )
152 error_sample = self.rng.multivariate_normal(
153 self.parameters.error_means,
154 np.diagflat(self.parameters.error_variances),
155 sample_size,
156 )
157 sample = latent_sample @ self.parameters.biadj_weights + error_sample
158 return sample
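# --- Illustrative usage sketch (not part of the library API) ---------------
# A minimal example of the GaussianMCM workflow defined above: specify a small
# biadjacency matrix and parameters by hand, draw a sample, and refit on that
# sample. The 2x3 structure and all numeric values are assumptions made only
# for illustration, and the refit step assumes the structure-learning helpers
# imported at the top of this module behave as used elsewhere in it.
def _example_gaussian_mcm():
    biadj = np.array([[True, True, False], [False, True, True]])
    mcm = GaussianMCM(biadj=biadj)
    mcm.parameters.biadj_weights = np.where(biadj, 1.0, 0.0)
    mcm.parameters.error_means = np.zeros(3)
    mcm.parameters.error_variances = np.full(3, 0.1)
    data = mcm.sample(1000)  # (1000, 3) array of measurement samples
    refit = GaussianMCM().fit(data)  # structure + least-squares parameters
    return refit.parameters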
161class NeuroCausalFactorAnalysis(MedilCausalModel):
162 """A MeDIL causal model represented by a deep generative model."""
164 def __init__(
165 self,
166 seed: int = 0,
167 log_path: str = "",
168 verbose: bool = False,
169 **kwargs,
170 ):
171 super().__init__(**kwargs)
172 if log_path:
173 Path(log_path).mkdir(exist_ok=True)
174 self.log_path = log_path
175 self.verbose = verbose
176 self.seed = seed
177 self.hyperparams = {
178 "heuristic": True,
179 "method": "xicor",
180 "alpha": 0.05,
181 "batch_size": 128,
182 "num_epochs": 200,
183 "lr": 0.005,
184 "beta": 1,
185 "num_valid": 1000,
186 "mu": 0.01,
187 "lambda": 0.01,
188 "deg_of_free": 2,
189 "width_per_meas": 2,
190 "num_hidden_layers": 1,
191 "prior_biadj": None,
192 }
193 self.parameters = Parameters("VAE")
194 self.loss = None
195 self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
197 def log(self, entry: str) -> None:
198 if not (self.log_path or self.verbose):
199 return
200 time_stamped_entry = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {entry}"
201 if self.log_path:
with open(os.path.join(self.log_path, "training.log"), "a") as log_file:
203 log_file.write(time_stamped_entry + "\n")
204 if self.verbose:
205 print(time_stamped_entry)
207 def fit(self, dataset: npt.NDArray, split_idcs=None) -> "NeuroCausalFactorAnalysis":
208 self.dataset = dataset
210 # random train/val split if explicit indices not provided
211 if split_idcs is None:
212 train_split, valid_split = train_test_split(
213 dataset, train_size=0.7, random_state=self.seed
214 )
215 else:
216 train_split = dataset[split_idcs[0]]
217 valid_split = dataset[split_idcs[1]]
219 train_loader = self._data_loader(train_split)
220 valid_loader = self._data_loader(valid_split)
222 np.random.seed(self.seed)
224 model_recon, loss_recon, error_recon = self._train_vae(
225 train_loader, valid_loader
226 )
227 if self.log_path:
228 torch.save(model_recon, os.path.join(self.log_path, "model_recon.pt"))
229 with open(os.path.join(self.log_path, "loss_recon.pkl"), "wb") as handle:
230 pickle.dump(loss_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)
231 with open(os.path.join(self.log_path, "error_recon.pkl"), "wb") as handle:
232 pickle.dump(error_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)
233 # self.parameters.weights = (
234 # model_recon.decoder.mean_linear_fulcon.weight.detach().numpy().T
235 # )
236 self.parameters.vae = model_recon
237 self.loss = {
238 "elbo_train": loss_recon[0],
239 "elbo_valid": loss_recon[1],
240 "recon_train": error_recon[0],
241 "recon_valid": error_recon[1],
242 }
243 return self
245 def _compute_biadj(self):
246 if self.udg.size == 0:
247 self._estimate_udg()
248 self.biadj = find_heuristic_1pc(self.udg)
250 def _estimate_udg(self):
251 self.udg, pvals = estimate_UDG(
252 self.dataset,
253 method=self.hyperparams["method"],
254 significance_level=self.hyperparams["alpha"],
255 )
257 def _data_loader(self, sample):
258 sample_x = sample.astype(np.float32)
259 sample_z = np.empty(shape=(sample_x.shape[0], 0)).astype(np.float32)
260 dataset = TensorDataset(torch.tensor(sample_x), torch.tensor(sample_z))
261 data_loader = DataLoader(
262 dataset, batch_size=self.hyperparams["batch_size"], shuffle=False
263 )
264 return data_loader
def _train_vae(self, train_loader, valid_loader):
"""Train the VAE on the training data and track validation performance.
:param train_loader: training data loader
:param valid_loader: validation data loader
:return: trained model, ELBO history, and reconstruction-error history
"""
277 num_meas = self.dataset.shape[1]
278 self.num_meas = num_meas
279 num_vae_latent = self.hyperparams["deg_of_free"] * num_meas
280 num_hidden_layers = self.hyperparams["num_hidden_layers"]
281 width_per_meas = self.hyperparams["width_per_meas"]
282 prior_biadj = self.hyperparams["prior_biadj"]
284 # building VAE
285 model = VariationalAutoencoder(
286 num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj
287 )
288 model = model.to(self.device)
289 optimizer = torch.optim.AdamW(
290 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5
291 )
292 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90)
293 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
294 optimizer, patience=10, factor=0.5
295 )
296 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
297 self.log(f"Number of parameters: {num_params}")
299 # training loop
300 model.train()
301 train_elbo, train_error = [], []
302 valid_elbo, valid_error = [], []
304 pbar = tqdm(
305 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch"
306 )
307 for idx in pbar:
308 self.log(f"Training on epoch {idx}...")
309 train_lb, train_er, nbatch = 0.0, 0.0, 0
311 for x_batch, _ in train_loader:
312 batch_size = x_batch.shape[0]
313 x_batch = x_batch.to(self.device)
314 recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch)
315 weight_batch = model.decoder.mean_linear_fulcon.weight
316 loss = self._elbo_gaussian(
317 x_batch,
318 recon_batch,
319 logcov_batch,
320 mu_batch,
321 logvar_batch,
322 weight_batch,
323 self.hyperparams["beta"],
324 )
325 error = self._recon_error(
326 x_batch, recon_batch, logcov_batch, weighted=False
327 )
328 optimizer.zero_grad()
329 loss.backward()
330 optimizer.step()
332 # update loss and nbatch
333 train_lb += loss.item() / batch_size
334 train_er += error.item() / batch_size
335 nbatch += 1
337 # finish training epoch
338 # scheduler.step()
339 train_lb = train_lb / nbatch
340 train_er = train_er / nbatch
341 train_elbo.append(train_lb)
342 train_error.append(train_er)
343 self.log(f"Finish training epoch {idx} with loss {train_lb}")
345 # append validation loss
346 valid_lb, valid_er = self._valid_vae(model, valid_loader)
347 valid_elbo.append(valid_lb)
348 valid_error.append(valid_er)
350 # decrease learning rate if validation plateaus
351 scheduler.step(valid_lb)
353 # update tqdm progress bar
354 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb
356 train_elbo, train_error = np.array(train_elbo), np.array(train_error)
357 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error)
358 elbo = [train_elbo, valid_elbo]
359 error = [train_error, valid_error]
361 return model, elbo, error
def _valid_vae(self, model, valid_loader):
"""Evaluate the VAE on the validation set.
:param model: VAE model being trained
:param valid_loader: validation data loader
:return: validation ELBO and reconstruction error
"""
369 # set to evaluation mode
370 model.eval()
371 valid_lb, valid_er, nbatch = 0.0, 0.0, 0
373 for x_batch, _ in valid_loader:
374 with torch.no_grad():
375 batch_size = x_batch.shape[0]
376 x_batch = x_batch.to(self.device)
377 recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch)
378 loss = self._elbo_gaussian(
379 x_batch,
380 recon_batch,
381 logcov_batch,
382 mu_batch,
383 logvar_batch,
384 None,
385 self.hyperparams["beta"],
386 )
387 error = self._recon_error(
388 x_batch, recon_batch, logcov_batch, weighted=False
389 )
391 # update loss and nbatch
392 valid_lb += loss.item() / batch_size
393 valid_er += error.item() / batch_size
394 nbatch += 1
396 # report validation loss
397 valid_lb = valid_lb / nbatch
398 valid_er = valid_er / nbatch
399 self.log(f"Finish validation with loss {valid_lb}")
401 return valid_lb, valid_er
def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, weight, beta):
"""Compute the penalized negative ELBO for the variational autoencoder.
:param x: original data batch
:param x_recon: reconstruction from the decoder
:param logcov: log of the covariance of the data distribution
:param mu: mean of the fitted variational distribution
:param logvar: log of the variance of the variational distribution
:param weight: decoder weight matrix to penalize (None to skip penalties)
:param beta: weight of the KL term
:return: negative ELBO, plus sparsity penalties when weight is given
"""
414 # KL-divergence
415 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
416 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
417 # https://arxiv.org/pdf/1312.6114.pdf
418 kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
420 # reconstruction loss
421 cov = torch.exp(logcov)
422 cov = self._apply_along_axis(torch.diag, cov, axis=0)
423 cov = cov.mean(axis=0)
425 diff = x - x_recon
426 recon_loss = torch.sum(
427 torch.det(cov)
428 + torch.diagonal(
429 torch.mm(
430 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1)
431 )
432 )
433 ).mul(-1 / 2)
435 # elbo
436 loss = -beta * kl_div + recon_loss
437 if weight is not None:
# mu_reg avoids shadowing the variational mean `mu` passed to this method
llambda, mu_reg = self.hyperparams["lambda"], self.hyperparams["mu"]
439 norm_type = 2
440 kernel_size = (
441 self.hyperparams["width_per_meas"],
442 self.hyperparams["deg_of_free"],
443 )
444 weight = weight[None, None, :, :]
445 mu_weight = lp_pool2d(
446 weight, norm_type, kernel_size
447 ).squeeze() # penalize num edges
self.parameters.biadj = mu_weight.detach().cpu().numpy().T
449 ll_kernel_size = (
450 self.hyperparams["width_per_meas"] * self.num_meas,
451 self.hyperparams["deg_of_free"],
452 )
453 ll_weight = lp_pool2d(
454 weight, norm_type, ll_kernel_size
455 ).squeeze() # penalize num latents
return -loss + llambda * ll_weight.norm(1) + mu_reg * mu_weight.norm(1)
457 return -loss
def _recon_error(self, x, x_recon, logcov, weighted):
"""Reconstruction error between x and x_recon.
:param x: original data batch
:param x_recon: reconstruction from the decoder
:param logcov: log of the covariance of the data distribution
:param weighted: whether to weight the error by the inverse covariance
:return: reconstruction error
"""
471 # reconstruction loss
472 cov = torch.exp(logcov)
473 cov = self._apply_along_axis(torch.diag, cov, axis=0)
474 cov = cov.mean(axis=0)
476 diff = x - x_recon
477 if weighted:
478 error = torch.sum(
479 torch.det(cov)
480 + torch.diagonal(
481 torch.mm(
482 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1)
483 )
484 )
485 ).mul(-1 / 2)
486 else:
487 error = torch.linalg.norm(diff, ord=2)
489 return error
491 @staticmethod
def _apply_along_axis(function, x, axis=0):
"""Apply `function` to each slice of `x` along `axis` and stack the results.
:param function: function applied to each slice
:param x: input tensor
:param axis: axis along which `x` is unbound and restacked
:return: stacked outputs along the same axis
"""
505 return torch.stack(
506 [function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis
507 )
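# --- Illustrative usage sketch (not part of the library API) ---------------
# How NeuroCausalFactorAnalysis is typically driven, assuming synthetic data
# and a shortened epoch budget, and assuming the bundled vae module trains on
# the available device. fit() makes a 70/30 train/validation split, trains
# the VAE parameterization, and stores the ELBO and reconstruction-error
# curves in self.loss.
def _example_ncfa_fit():
    rng = np.random.default_rng(0)
    data = rng.standard_normal((500, 6)).astype(np.float32)
    ncfa = NeuroCausalFactorAnalysis(seed=0, verbose=False)
    ncfa.hyperparams["num_epochs"] = 5  # assumption: tiny budget for a demo
    ncfa.fit(data)
    return ncfa.loss["elbo_valid"], ncfa.loss["recon_valid"]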
class DevMedil(MedilCausalModel):
"""MeDIL causal model estimated by penalized MLE or penalized LSE."""
512 def __init__(self, **kwargs):
513 super().__init__(**kwargs)
514 self.init_W = None
516 # penalized MLE
517 def fit_penalized_mle(
518 self,
519 dataset: npt.NDArray,
520 lambda_reg: float = 0.1,
521 mu_reg: float = 0.1,
522 ) -> "DevMedil":
523 num_meas = dataset.shape[1]
524 if self.one_pure_child:
525 num_latent = num_meas
526 else:
527 num_latent = (num_meas**2) // 4
528 Sigma_hat = np.cov(dataset.T)
530 def penalized_mle_loss(W_and_D):
531 W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas)
532 D = np.diag(W_and_D[num_latent * num_meas :])
534 Sigma = self.compute_sigma(W, D)
535 Sigma_inv = np.linalg.inv(Sigma)
536 sign, logdet = np.linalg.slogdet(Sigma_inv)
538 if sign <= 0:
539 return np.inf
541 loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet
542 loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W)
544 return loss
546 initial_W = (
547 self.rng.standard_normal((num_latent, num_meas))
548 if self.init_W is None
549 else self.init_W
550 )
551 initial_D = self.rng.random(num_meas)
552 initial_W_and_D = np.hstack([initial_W.flatten(), initial_D])
554 result = minimize(penalized_mle_loss, initial_W_and_D, method="BFGS")
555 self.result = result
556 self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas)
557 self.D_hat = np.diag(result.x[num_latent * num_meas :])
558 self.convergence_success_mle = result.success
559 self.convergence_message_mle = result.message
561 return self
563 def validation_mle(self, lambda_reg, mu_reg, data):
564 W = self.W_hat
565 D = self.D_hat
566 Sigma_hat = np.cov(data, rowvar=False)
567 Sigma = self.compute_sigma(W, D)
568 Sigma_inv = np.linalg.inv(Sigma)
569 sign, logdet = np.linalg.slogdet(Sigma_inv)
571 if sign <= 0:
572 return np.inf
574 loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet
575 loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W)
576 return loss
578 # penalized LSE
579 def fit_penalized_lse(
580 self,
581 dataset: npt.NDArray,
582 lambda_reg: float = 0.1,
583 mu_reg: float = 0.1,
584 ) -> "DevMedil":
585 num_meas = dataset.shape[1]
586 if self.one_pure_child:
587 num_latent = num_meas
588 else:
589 num_latent = (num_meas**2) // 4
590 Sigma_hat = np.cov(dataset.T)
592 def penalized_lse_loss(W_and_D):
593 W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas)
594 D = np.diag(W_and_D[num_latent * num_meas :])
596 loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2
597 # nuclear norm for the first penalty term
598 loss += lambda_reg * self.rho(W)
599 # L1 norm for the second penalty function
600 loss += mu_reg * self.sigma(W)
602 return loss
604 initial_W = self.rng.standard_normal((num_latent, num_meas))
605 initial_D = np.abs(self.rng.standard_normal(num_meas))
606 initial_params = np.concatenate([initial_W.flatten(), initial_D])
608 result = minimize(penalized_lse_loss, initial_params, method="BFGS")
609 self.result = result
610 self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas)
611 self.D_hat = np.diag(np.abs(result.x[num_latent * num_meas :]))
612 self.convergence_success_lse = result.success
613 self.convergence_message_lse = result.message
615 return self
617 def validation_lse(self, lambda_reg, mu_reg, data):
618 W = self.W_hat
619 D = self.D_hat
620 Sigma_hat = np.cov(data, rowvar=False)
621 loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2
622 loss += lambda_reg * self.rho(W)
623 loss += mu_reg * self.sigma(W)
624 return loss
626 # compute sigma
627 def compute_sigma(self, W: npt.NDArray, D: npt.NDArray) -> npt.NDArray:
628 Sigma = np.dot(W.T, W) + D
629 return Sigma
631 # ρ(W)
632 def rho(self, W: npt.NDArray) -> float:
633 return norm(W, "nuc")
635 # σ(W), the sum of absolute values of elements (L1 norm)
636 def sigma(self, W: npt.NDArray) -> float:
637 return np.sum(np.abs(W))
639 def sample(self, sample_size: int, method: str = "mle") -> npt.NDArray:
640 if method not in ["mle", "lse"]:
641 raise ValueError("Method must be either 'mle' or 'lse'")
# both fit_penalized_mle and fit_penalized_lse store their estimates as
# self.W_hat and self.D_hat, so sample from whichever fit ran last
if not hasattr(self, "W_hat") or not hasattr(self, "D_hat"):
raise ValueError(f"The {method} model must be fitted before sampling")
W_hat, D_hat = self.W_hat, self.D_hat
652 k, n = W_hat.shape
653 L = self.rng.standard_normal((sample_size, k))
654 epsilon = self.rng.multivariate_normal(np.zeros(n), D_hat, sample_size)
655 return np.dot(L, W_hat) + epsilon
657 def fit(
658 self, dataset: npt.NDArray, method: str = "mle", lambda_reg=0.1, mu_reg=0.1
659 ) -> "DevMedil":
660 if method == "mle":
661 return self.fit_penalized_mle(dataset, lambda_reg, mu_reg)
662 elif method == "lse":
663 return self.fit_penalized_lse(dataset, lambda_reg, mu_reg)
664 else:
665 raise ValueError("Method must be either 'mle' or 'lse'")
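# --- Illustrative usage sketch (not part of the library API) ---------------
# DevMedil models the covariance as Sigma = W.T @ W + D and minimizes either
# the penalized Gaussian negative log-likelihood (method="mle") or the
# penalized least-squares objective (method="lse"), each with a nuclear-norm
# penalty rho(W) and an L1 penalty sigma(W). The data and regularization
# strengths below are assumptions chosen only for illustration.
def _example_devmedil_fit():
    rng = np.random.default_rng(0)
    data = rng.standard_normal((200, 4))
    dm = DevMedil(rng=rng).fit(data, method="lse", lambda_reg=0.1, mu_reg=0.1)
    held_out = rng.standard_normal((100, 4))
    score = dm.validation_lse(0.1, 0.1, held_out)  # penalized fit on new data
    return dm.W_hat, dm.D_hat, score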
668class DevMedilInterv(NeuroCausalFactorAnalysis):
669 def __init__(self, **kwargs):
670 super().__init__(**kwargs)
671 self.parameters = Parameters("InterVAE")
672 self.hyperparams.update(
673 {
674 "batch_size": 128,
675 "num_epochs": 200,
676 "lr": 0.005,
677 "beta": 1,
678 "num_valid": 1000,
679 "lambda": 0.01,
680 "meas_width": 1,
681 "meas_depth": 0,
682 "num_latent": 5,
683 "latent_width": 1,
684 "latent_depth": 0,
685 }
686 )
def _train_vae(self, train_loader, valid_loader):
"""Train the interventional VAE on the training data.
:param train_loader: training data loader (last column holds interv labels)
:param valid_loader: validation data loader
:return: trained model, ELBO history, and reconstruction-error history
"""
699 num_meas = self.dataset.shape[1] - 1 # one column for interv labels;
700 self.num_meas = num_meas
702 # building VAE
703 model = InterVAE(
704 num_meas,
705 self.hyperparams["meas_width"],
706 self.hyperparams["meas_depth"],
707 self.hyperparams["num_latent"],
708 self.hyperparams["latent_width"],
709 self.hyperparams["latent_depth"],
710 )
711 model = model.to(self.device)
712 optimizer = torch.optim.AdamW(
713 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5
714 )
715 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90)
716 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
717 optimizer, patience=10, factor=0.5
718 )
719 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
720 self.log(f"Number of parameters: {num_params}")
722 # training loop
723 model.train()
724 train_elbo, train_error = [], []
725 valid_elbo, valid_error = [], []
727 pbar = tqdm(
728 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch"
729 )
730 for idx in pbar:
731 self.log(f"Training on epoch {idx}...")
732 train_lb, train_er, nbatch = 0.0, 0.0, 0
734 for batch, _ in train_loader:
735 x_batch = batch[:, :-1]
736 interv_idx_batch = batch[:, -1, None]
737 batch_size = x_batch.shape[0]
738 x_batch = x_batch.to(self.device)
739 recon_batch, logcov_batch, mu_batch, logvar_batch = model(
740 x_batch, interv_idx_batch
741 )
742 causal_biadj_batch = model.decoder.mean_causal.weight
743 loss = self._elbo_gaussian(
744 x_batch,
745 recon_batch,
746 logcov_batch,
747 mu_batch,
748 logvar_batch,
749 causal_biadj_batch,
750 self.hyperparams["beta"],
751 )
752 error = self._recon_error(
753 x_batch, recon_batch, logcov_batch, weighted=False
754 )
755 optimizer.zero_grad()
756 loss.backward()
757 optimizer.step()
759 # update loss and nbatch
760 train_lb += loss.item() / batch_size
761 train_er += error.item() / batch_size
762 nbatch += 1
764 # finish training epoch
765 # scheduler.step()
766 train_lb = train_lb / nbatch
767 train_er = train_er / nbatch
768 train_elbo.append(train_lb)
769 train_error.append(train_er)
770 self.log(f"Finish training epoch {idx} with loss {train_lb}")
772 # append validation loss
773 valid_lb, valid_er = self._valid_vae(model, valid_loader)
774 valid_elbo.append(valid_lb)
775 valid_error.append(valid_er)
777 # decrease learning rate if validation plateaus
778 scheduler.step(valid_lb)
780 # update tqdm progress bar
781 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb
783 train_elbo, train_error = np.array(train_elbo), np.array(train_error)
784 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error)
785 elbo = [train_elbo, valid_elbo]
786 error = [train_error, valid_error]
788 return model, elbo, error
def _valid_vae(self, model, valid_loader):
"""Evaluate the interventional VAE on the validation set.
:param model: VAE model being trained
:param valid_loader: validation data loader
:return: validation ELBO and reconstruction error
"""
796 # set to evaluation mode
797 model.eval()
798 valid_lb, valid_er, nbatch = 0.0, 0.0, 0
800 for batch, _ in valid_loader:
801 x_batch = batch[:, :-1]
802 interv_idx_batch = batch[:, -1, None]
804 with torch.no_grad():
805 batch_size = x_batch.shape[0]
806 x_batch = x_batch.to(self.device)
807 recon_batch, logcov_batch, mu_batch, logvar_batch = model(
808 x_batch, interv_idx_batch
809 )
810 loss = self._elbo_gaussian(
811 x_batch,
812 recon_batch,
813 logcov_batch,
814 mu_batch,
815 logvar_batch,
816 None,
817 self.hyperparams["beta"],
818 )
819 error = self._recon_error(
820 x_batch, recon_batch, logcov_batch, weighted=False
821 )
823 # update loss and nbatch
824 valid_lb += loss.item() / batch_size
825 valid_er += error.item() / batch_size
826 nbatch += 1
828 # report validation loss
829 valid_lb = valid_lb / nbatch
830 valid_er = valid_er / nbatch
831 self.log(f"Finish validation with loss {valid_lb}")
833 return valid_lb, valid_er
def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, causal_biadj, beta):
"""Compute the penalized negative ELBO for the interventional VAE.
:param x: original data batch
:param x_recon: reconstruction from the decoder
:param logcov: log of the covariance of the data distribution
:param mu: mean of the fitted variational distribution
:param logvar: log of the variance of the variational distribution
:param causal_biadj: causal weight matrix to penalize (None to skip)
:param beta: weight of the KL term
:return: negative ELBO, plus a sparsity penalty when causal_biadj is given
"""
846 # KL-divergence
847 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
848 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
849 # https://arxiv.org/pdf/1312.6114.pdf
850 kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
852 # reconstruction loss
853 cov = torch.exp(logcov)
854 cov = self._apply_along_axis(torch.diag, cov, axis=0)
855 cov = cov.mean(axis=0)
857 diff = x - x_recon
858 recon_loss = torch.sum(
859 torch.det(cov)
860 + torch.diagonal(
861 torch.mm(
862 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1)
863 )
864 )
865 ).mul(-1 / 2)
867 # elbo
868 loss = -beta * kl_div + recon_loss
869 if causal_biadj is not None:
870 llambda = self.hyperparams["lambda"]
871 norm_type = 2
872 kernel_size = (
873 self.hyperparams["latent_width"],
874 self.hyperparams["latent_width"],
875 )
876 causal_biadj = causal_biadj[None, None, :, :]
877 causal_biadj = lp_pool2d(
878 causal_biadj, norm_type, kernel_size
879 ).squeeze() # penalize num edges
self.parameters.causal_biadj = causal_biadj.detach().cpu().numpy().T
881 return -loss + llambda * causal_biadj.norm(1)
882 return -loss
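# --- Illustrative sketch of the lp_pool2d edge penalty used above ----------
# Pooling a (num_meas * width, num_latent * deg) weight matrix with an L2
# pool of kernel (width, deg) collapses each measurement-latent block to a
# single magnitude, so the L1 norm of the pooled matrix acts as a soft edge
# count. The shapes below are assumptions chosen only for illustration.
def _example_edge_penalty(width=2, deg=2, num_meas=3, num_latent=3):
    weight = torch.randn(num_meas * width, num_latent * deg)
    pooled = lp_pool2d(weight[None, None, :, :], 2, (width, deg)).squeeze()
    # pooled has shape (num_meas, num_latent); its L1 norm is the penalty
    return pooled.shape, pooled.norm(1)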
885class DevMedilInterv2(NeuroCausalFactorAnalysis):
886 def __init__(self, **kwargs):
887 super().__init__(**kwargs)
888 self.parameters = Parameters("InterVAE2")
889 self.hyperparams.update(
890 {
891 "batch_size": 128,
892 "num_epochs": 200,
893 "lr": 0.005,
894 "beta": 1,
895 "num_valid": 1000,
896 "sparse_reg": 10,
897 "dag_reg": 10,
898 "meas_width": 1,
899 "meas_depth": 0,
900 "num_latent": 5,
901 "latent_width": 1,
902 "latent_depth": 0,
903 }
904 )
906 def _data_loader(self, sample):
907 sample_x = sample.astype(np.float32)
908 sample_z = np.empty(shape=(sample_x.shape[0], 0)).astype(np.float32)
909 dataset = TensorDataset(torch.tensor(sample_x), torch.tensor(sample_z))
910 data_loader = DataLoader(
911 dataset,
912 batch_sampler=self._sampler(dataset, self.hyperparams["batch_size"]),
913 )
914 return data_loader
916 class _sampler(torch.utils.data.Sampler):
def __init__(self, data: TensorDataset, batch_size: int) -> None:
918 self.data = data
919 self.batch_size = batch_size
920 sample_x = data[:][0]
921 self.labels = sample_x[:, -1]
922 self.contexts, self.inv, self.counts = torch.unique(
923 self.labels, return_inverse=True, return_counts=True
924 )
926 def __len__(self) -> int:
927 # self.chunk_sizes = [
928 # (count + self.batch_size - 1) // self.batch_size
929 # for count in self.counts
930 # ]
931 # return sum(self.chunk_sizes)
932 return (len(self.data) + self.batch_size - 1) // self.batch_size
934 def __iter__(self) -> Iterator[List[int]]:
935 context_idcs = torch.multinomial(
936 self.counts / self.counts.sum(), len(self), replacement=True
937 )
938 # contexts = self.contexts[context_idcs]
939 for context_idx in context_idcs:
940 context_data_idcs = torch.where(self.inv == context_idx)[0]
941 batch_idcs = context_data_idcs[torch.randperm(len(context_data_idcs))][
942 : self.batch_size
943 ]
944 yield batch_idcs
def _train_vae(self, train_loader, valid_loader):
"""Train the interventional VAE on the training data.
:param train_loader: training data loader (last column holds interv labels)
:param valid_loader: validation data loader
:return: trained model, ELBO history, and reconstruction-error history
"""
957 num_meas = self.dataset.shape[1] - 1 # one column for interv labels;
958 self.num_meas = num_meas
960 # building VAE
961 model = InterVAE2(
962 num_meas,
963 self.hyperparams["meas_width"],
964 self.hyperparams["meas_depth"],
965 self.hyperparams["num_latent"],
966 self.hyperparams["latent_width"],
967 self.hyperparams["latent_depth"],
968 )
969 model = model.to(self.device)
970 optimizer = torch.optim.AdamW(
971 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5
972 )
973 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90)
974 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
975 optimizer, patience=10, factor=0.5
976 )
977 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
978 self.log(f"Number of parameters: {num_params}")
980 # training loop
981 model.train()
982 train_elbo, train_error = [], []
983 valid_elbo, valid_error = [], []
985 pbar = tqdm(
986 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch"
987 )
988 for idx in pbar:
989 self.log(f"Training on epoch {idx}...")
990 train_lb, train_er, nbatch = 0.0, 0.0, 0
992 for batch, _ in train_loader:
993 x_batch = batch[:, :-1]
994 interv_idx_batch = batch[:, -1, None]
995 interv_idx = torch.unique(interv_idx_batch).to(int)
996 assert len(interv_idx) == 1
997 interv_idx = int(interv_idx)
998 batch_size = x_batch.shape[0]
999 x_batch = x_batch.to(self.device)
1000 recon_batch, logcov_batch, mu_batch, logvar_batch = model(
1001 x_batch, interv_idx_batch
1002 )
1003 causal_biadj_batch = model.decoder.mean_causal[interv_idx].weight
1004 loss = self._elbo_gaussian(
1005 x_batch,
1006 recon_batch,
1007 logcov_batch,
1008 mu_batch,
1009 logvar_batch,
1010 causal_biadj_batch,
1011 self.hyperparams["beta"],
1012 )
1013 error = self._recon_error(
1014 x_batch, recon_batch, logcov_batch, weighted=False
1015 )
1016 optimizer.zero_grad()
1017 loss.backward()
1018 optimizer.step()
1020 # update loss and nbatch
1021 train_lb += loss.item() / batch_size
1022 train_er += error.item() / batch_size
1023 nbatch += 1
1025 # finish training epoch
1026 # scheduler.step()
1027 train_lb = train_lb / nbatch
1028 train_er = train_er / nbatch
1029 train_elbo.append(train_lb)
1030 train_error.append(train_er)
1031 self.log(f"Finish training epoch {idx} with loss {train_lb}")
1033 # append validation loss
1034 valid_lb, valid_er = self._valid_vae(model, valid_loader)
1035 valid_elbo.append(valid_lb)
1036 valid_error.append(valid_er)
1038 # decrease learning rate if validation plateaus
1039 scheduler.step(valid_lb)
1041 # update tqdm progress bar
1042 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb
1044 train_elbo, train_error = np.array(train_elbo), np.array(train_error)
1045 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error)
1046 elbo = [train_elbo, valid_elbo]
1047 error = [train_error, valid_error]
# pool each context's causal weight matrix down to a per-edge magnitude
norm_type = 2
kernel_size = (
self.hyperparams["latent_width"],
self.hyperparams["latent_width"],
)
pooled = {
t: lp_pool2d(
v.weight.detach()[None, None, :, :], norm_type, kernel_size
).squeeze()
for t, v in model.decoder.mean_causal.items()
}
self.parameters.causal_biadj_dict = {k: v.cpu().numpy().T for k, v in pooled.items()}
1065 return model, elbo, error
def _valid_vae(self, model, valid_loader):
"""Evaluate the interventional VAE on the validation set.
:param model: VAE model being trained
:param valid_loader: validation data loader
:return: validation ELBO and reconstruction error
"""
1073 # set to evaluation mode
1074 model.eval()
1075 valid_lb, valid_er, nbatch = 0.0, 0.0, 0
1077 for batch, _ in valid_loader:
1078 x_batch = batch[:, :-1]
1079 interv_idx_batch = batch[:, -1, None]
1081 with torch.no_grad():
1082 batch_size = x_batch.shape[0]
1083 x_batch = x_batch.to(self.device)
1084 recon_batch, logcov_batch, mu_batch, logvar_batch = model(
1085 x_batch, interv_idx_batch
1086 )
1087 loss = self._elbo_gaussian(
1088 x_batch,
1089 recon_batch,
1090 logcov_batch,
1091 mu_batch,
1092 logvar_batch,
1093 None,
1094 self.hyperparams["beta"],
1095 )
1096 error = self._recon_error(
1097 x_batch, recon_batch, logcov_batch, weighted=False
1098 )
1100 # update loss and nbatch
1101 valid_lb += loss.item() / batch_size
1102 valid_er += error.item() / batch_size
1103 nbatch += 1
1105 # report validation loss
1106 valid_lb = valid_lb / nbatch
1107 valid_er = valid_er / nbatch
1108 self.log(f"Finish validation with loss {valid_lb}")
1110 return valid_lb, valid_er
def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, causal_biadj, beta):
"""Compute the penalized ELBO loss for the interventional VAE.
:param x: original data batch
:param x_recon: reconstruction from the decoder
:param logcov: log of the covariance of the data distribution (unused here)
:param mu: mean of the fitted variational distribution
:param logvar: log of the variance of the variational distribution
:param causal_biadj: causal weight matrix to penalize (None to skip)
:param beta: weight of the KL term
:return: reconstruction loss + KL, plus sparsity and acyclicity penalties
"""
1123 # KL-divergence
1124 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
1125 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
1126 # https://arxiv.org/pdf/1312.6114.pdf
1127 kl_div_loss = torch.mean(
1128 -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0
1129 )
1131 recon_loss = torch.nn.functional.mse_loss(x_recon, x)
1133 # elbo loss
1134 loss = beta * kl_div_loss + recon_loss
1135 if causal_biadj is not None:
1136 sparse_reg = self.hyperparams["sparse_reg"]
1137 dag_reg = self.hyperparams["dag_reg"]
1138 norm_type = 2
1139 kernel_size = (
1140 self.hyperparams["latent_width"],
1141 self.hyperparams["latent_width"],
1142 )
1143 pooled = causal_biadj[None, None, :, :]
1144 pooled = lp_pool2d(
1145 pooled, norm_type, kernel_size
1146 ).squeeze() # penalize num edges
1147 density = pooled.mean() # L1 norm / num_entries
1149 # https://dagma.readthedocs.io/en/latest/#the-log-det-acyclicity-characterization
s = torch.tensor(5.0)  # float so torch.log(s) is defined
d = len(pooled)
# zero out the diagonal before squaring, per the DAGMA characterization
off_diag_sq = torch.square(pooled - torch.diag(torch.diagonal(pooled)))
nondagness = -torch.logdet(s * torch.eye(d) - off_diag_sq) + d * torch.log(s)
1156 return loss + sparse_reg * density + dag_reg * nondagness
1157 return loss
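# --- Illustrative sketch of the acyclicity penalty used above --------------
# The DAGMA-style term h(W) = -logdet(s*I - W*W) + d*log(s) (elementwise
# square, diagonal removed) is ~0 for an acyclic weighted adjacency matrix
# and grows when directed cycles appear. The toy matrices below are
# assumptions chosen only to show that behaviour.
def _example_dag_penalty(s: float = 5.0):
    def h(adj):
        adj = torch.as_tensor(adj, dtype=torch.float32)
        d = adj.shape[0]
        off_diag_sq = torch.square(adj - torch.diag(torch.diagonal(adj)))
        s_t = torch.tensor(s)
        return -torch.logdet(s_t * torch.eye(d) - off_diag_sq) + d * torch.log(s_t)

    acyclic = [[0.0, 0.8], [0.0, 0.0]]  # 0 -> 1 only
    cyclic = [[0.0, 0.8], [0.8, 0.0]]   # 0 -> 1 and 1 -> 0
    return h(acyclic).item(), h(cyclic).item()  # ~0.0 versus > 0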
1160class BlockLinear(nn.Module):
1161 def __init__(
1162 self,
1163 context_dims,
1164 width,
1165 bias=True,
1166 device=None,
1167 dtype=None,
1168 ):
1169 factory_kwargs = {"device": device, "dtype": dtype}
1170 super().__init__()
1171 self.block_mask = torch.eye(context_dims).kron(torch.ones(width, width))
1172 num_features = context_dims * width
1173 self.weight = Parameter(
1174 torch.empty((num_features, num_features), **factory_kwargs)
1175 )
1176 if bias:
1177 self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
1178 else:
1179 self.register_parameter("bias", None)
1180 self.reset_parameters()
1182 def reset_parameters(self):
1183 # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
1184 nn.init.orthogonal_(self.weight)
1185 # nn.init.sparse_(self.weight, 2 / 3)
1186 if self.bias is not None:
1187 fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
1188 bound = 1 / torch.sqrt(torch.tensor(fan_in)) if fan_in > 0 else 0
1189 nn.init.uniform_(self.bias, -bound, bound)
1191 def forward(self, input):
1192 # masked linear layer
1193 return nn.functional.linear(input, self.weight * self.block_mask, self.bias)
1195 def extra_repr(self):
return "in_features={}, out_features={}, bias={}".format(
self.weight.shape[1], self.weight.shape[0], self.bias is not None
)
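# --- Illustrative sketch of BlockLinear's masking (not part of the API) ----
# The layer keeps a full (context_dims*width)^2 weight matrix but multiplies
# it by a block-diagonal 0/1 mask, so each context's width-dimensional slot
# only mixes with itself. The sizes below are assumptions for illustration.
def _example_block_linear():
    layer = BlockLinear(context_dims=3, width=2, bias=False)
    x = torch.randn(4, 6)  # batch of 4, features = context_dims * width
    out = layer(x)
    effective = layer.weight * layer.block_mask  # what forward() actually uses
    # entries coupling different contexts are masked to exactly zero
    cross_block = effective[0, 2:]
    return out.shape, bool(torch.all(cross_block == 0))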
1201class Intervenable(nn.Module):
1202 def __init__(
1203 self,
1204 in_features,
1205 out_features,
1206 mask=None,
1207 bias=True,
1208 device=None,
1209 dtype=None,
1210 ):
1211 self.factory_kwargs = {"device": device, "dtype": dtype}
1212 super().__init__()
1213 self.in_features = in_features
1214 self.out_features = out_features
1215 self.mask = mask
1216 self.weight = Parameter(
1217 torch.empty((out_features, in_features), **self.factory_kwargs)
1218 )
1220 if bias:
1221 self.bias = Parameter(torch.empty(out_features, **self.factory_kwargs))
1222 else:
1223 self.register_parameter("bias", None)
1224 self.reset_parameters()
1226 def reset_parameters(self):
1227 nn.init.orthogonal_(self.weight)
1228 if self.bias is not None:
1229 fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
1230 bound = 1 / torch.sqrt(torch.tensor(fan_in)) if fan_in > 0 else 0
1231 nn.init.uniform_(self.bias, -bound, bound)
1233 def forward(self, input, obs_weight, interv_idx):
1234 if interv_idx != -1:
1235 min_weight = torch.minimum(self.weight, obs_weight)
1236 interv_mask = torch.ones_like(min_weight, **self.factory_kwargs)
1237 interv_mask[interv_idx] = 0
1238 interv_mask[interv_idx, interv_idx] = 1
1239 self.weight.data = min_weight * interv_mask
1240 if self.mask is None:
1241 return nn.functional.linear(input, self.weight, self.bias)
1242 else:
1243 return nn.functional.linear(input, self.weight * self.mask, self.bias)
1245 def extra_repr(self):
1246 return "in_features={}, out_features={}, bias={}".format(
1247 self.in_features, self.out_features, self.bias is not None
1248 )
1251class VAE(nn.Module):
1252 def __init__(self, input_dims, context_dims, width, depth, hidden_dims):
1253 super().__init__()
1254 latent_dims = context_dims * width
1255 self.input_dims = input_dims
1256 self.width = width
1258 # Encoder
1259 self.encoder = nn.Sequential(
1260 nn.Linear(input_dims, hidden_dims),
1261 nn.BatchNorm1d(hidden_dims),
1262 nn.GELU(),
1263 nn.Linear(hidden_dims, hidden_dims),
1264 nn.BatchNorm1d(hidden_dims),
1265 nn.GELU(),
1266 )
1267 self.fc_mu = nn.Linear(hidden_dims, latent_dims)
1268 self.fc_var = nn.Linear(hidden_dims, latent_dims)
1270 # Our module
# build an independent (BlockLinear, GELU) pair for each of the `depth` layers
deeply_expressive = chain(
*((BlockLinear(context_dims, width), nn.GELU()) for _ in range(depth))
)
self.expressive_layer = nn.Sequential(*deeply_expressive, nn.AvgPool1d(width))
1274 self.causal_layer = nn.ModuleDict(
1275 {
1276 str(interv_idx): Intervenable(
1277 in_features=context_dims,
1278 out_features=input_dims,
1279 )
1280 for interv_idx in range(-1, context_dims)
1281 }
1282 )
1284 # Decoder
1285 self.decoder = nn.Sequential(
1286 nn.Linear(input_dims, hidden_dims),
1287 nn.BatchNorm1d(hidden_dims),
1288 nn.GELU(),
1289 nn.Linear(hidden_dims, input_dims),
1290 )
1292 def encode(self, x):
1293 h = self.encoder(x)
1294 return self.fc_mu(h), self.fc_var(h)
1296 def reparameterize(self, mu, log_var):
1297 std = torch.exp(0.5 * log_var)
1298 eps = torch.randn_like(std)
1299 return mu + eps * std
1301 def decode(self, z):
1302 epsilon = self.expressive_layer(z)
1303 obs_weight = self.causal_layer[str(-1)].weight
1304 l = self.causal_layer[str(self.batch_label)](
1305 epsilon, obs_weight, self.batch_label
1306 )
1307 return self.decoder(l)
1309 def forward(self, x, label):
1310 self.batch_label = label
1311 mu, log_var = self.encode(x)
1312 z = self.reparameterize(mu, log_var)
1313 return self.decode(z), mu, log_var
1316def _numpy_to_pytorch_dataset(np_dataset):
1317 """
1318 Converts a NumPy dataset into a PyTorch dataset.
1320 Parameters:
1321 - np_dataset: NumPy array where rows are samples and columns are features.
1323 Returns:
1324 - A PyTorch TensorDataset containing samples and labels.
1325 """
1327 # Ensure the input is a NumPy array
1328 if not isinstance(np_dataset, np.ndarray):
1329 raise ValueError("Input must be a NumPy array.")
1331 # Split the dataset into features (X) and labels (y)
1332 X = np_dataset[:, :-1] # Features are all columns except the last
1333 y = np_dataset[:, -1] # Labels are the last column
1335 # Convert NumPy arrays to PyTorch tensors
1336 X_tensor = torch.from_numpy(X).float()
1337 y_tensor = (
1338 torch.from_numpy(y).float().unsqueeze(-1)
1339 ) # Optional: unsqueeze for consistency
1341 # Create a PyTorch TensorDataset
1342 dataset = TensorDataset(X_tensor, y_tensor)
1344 return dataset
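# --- Illustrative usage (not part of the library API) ----------------------
# The helper expects the last column to hold the context / intervention label
# and returns a TensorDataset of (features, label) pairs. The toy array below
# is an assumption for illustration.
def _example_numpy_to_pytorch_dataset():
    arr = np.array([[0.1, 0.2, 0.0],
                    [0.3, 0.4, 1.0],
                    [0.5, 0.6, 1.0]])
    ds = _numpy_to_pytorch_dataset(arr)
    x0, y0 = ds[0]  # feature tensor of shape (2,), label tensor of shape (1,)
    return x0.shape, y0.shape, len(ds)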
1347class _sampler(torch.utils.data.Sampler):
1348 def __init__(self, dataset, batch_size: int) -> None:
1349 self.data = dataset
1350 self.batch_size = batch_size
1351 self.labels = dataset.tensors[1]
1352 self.contexts, self.inv, self.counts = torch.unique(
1353 self.labels, return_inverse=True, return_counts=True
1354 )
1356 def __len__(self) -> int:
1357 return (len(self.data) + self.batch_size - 1) // self.batch_size
1359 def __iter__(self):
1360 context_idcs = torch.multinomial(
1361 self.counts / self.counts.sum(), len(self), replacement=True
1362 )
1363 # contexts = self.contexts[context_idcs]
1364 for context_idx in context_idcs:
1365 context_data_idcs = torch.where(self.inv == context_idx)[0]
1366 batch_idcs = context_data_idcs[torch.randperm(len(context_data_idcs))][
1367 : self.batch_size
1368 ]
1369 yield batch_idcs
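# --- Illustrative sketch of the context-pure batch sampler -----------------
# Every batch produced by _sampler contains indices whose label (the last
# dataset column) is identical, so downstream code can treat a whole batch as
# one interventional context; contexts are drawn in proportion to their
# frequency. The toy dataset below is an assumption for illustration.
def _example_context_sampler():
    rng = np.random.default_rng(0)
    data = np.column_stack([rng.standard_normal((20, 3)),
                            rng.integers(0, 2, 20)]).astype(np.float64)
    ds = _numpy_to_pytorch_dataset(data)
    batches = list(iter(_sampler(ds, batch_size=8)))
    labels_first_batch = ds.tensors[1][batches[0]]
    # all labels within a batch are identical
    return len(batches), bool(torch.all(labels_first_batch == labels_first_batch[0]))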
1372class IvnFA(object):
1373 def __init__(self):
1374 self.hyperparams = {
1375 "batch_size": 128,
1376 "num_epochs": 100,
1377 "lr": 0.005,
1378 "beta": 1,
1379 "num_valid": 1000,
1380 "sparse_reg": 10,
1381 "width": 1,
1382 "depth": 0,
1383 "context_dims": 5,
1384 "hidden_dims": 50,
1385 }
1386 self.checkpoint_save_path = None
1387 self.final_save_path = None
1388 self.loss_dict = {
1389 "elbo_train": [],
1390 "recon_train": [],
1391 "elbo_valid": [],
1392 "recon_valid": [],
1393 }
1395 def fit(self, dataset, split_idcs=None):
1396 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1398 input_dims = dataset.shape[1] - 1
1400 # random train/val split if explicit indices not provided
1401 if split_idcs is None:
1402 train_split, valid_split = train_test_split(dataset, train_size=0.7)
1403 else:
1404 train_split = dataset[split_idcs[0]]
1405 valid_split = dataset[split_idcs[1]]
1407 # train
1408 dataset = _numpy_to_pytorch_dataset(train_split)
1409 sampler = _sampler(dataset, self.hyperparams["batch_size"])
1410 self.train_loader = DataLoader(dataset, batch_sampler=sampler)
1412 # valid
1413 dataset = _numpy_to_pytorch_dataset(valid_split)
1414 sampler = _sampler(dataset, self.hyperparams["batch_size"])
1415 self.valid_loader = DataLoader(dataset, batch_sampler=sampler)
1417 context_dims, width, depth, hidden_dims = (
1418 self.hyperparams["context_dims"],
1419 self.hyperparams["width"],
1420 self.hyperparams["depth"],
1421 self.hyperparams["hidden_dims"],
1422 )
1423 self.model = VAE(input_dims, context_dims, width, depth, hidden_dims).to(
1424 self.device
1425 )
1426 self.optimizer = torch.optim.Adam(
1427 self.model.parameters(), lr=self.hyperparams["lr"]
1428 )
1430 losses = []
1431 pbar = tqdm(
1432 range(self.hyperparams["num_epochs"]), desc="Training...", unit="epoch"
1433 )
1435 for epoch in pbar:
1436 loss = self._train()
1437 self._validate()
1438 losses.append(loss)
1439 pbar.set_postfix({"loss": f"{loss:.4f}"})
1441 # Save checkpoint after each epoch
1442 if self.checkpoint_save_path is not None:
1443 torch.save(
1444 {
1445 "epoch": epoch,
1446 "model_state_dict": self.model.state_dict(),
1447 "optimizer_state_dict": self.optimizer.state_dict(),
1448 "loss": loss,
1449 "losses": losses,
1450 },
1451 f"{self.checkpoint_save_path}.pt",
1452 )
1454 # Save model and training losses
1455 if self.final_save_path is not None:
1456 torch.save(
1457 {
1458 "model_state_dict": self.model.state_dict(),
1459 "optimizer_state_dict": self.optimizer.state_dict(),
1460 "losses": losses,
1461 },
1462 f"{self.final_save_path}.pt",
1463 )
1464 self.losses = losses
self.causal_weight_dict = {
k: v.weight.detach().cpu().numpy() for k, v in self.model.causal_layer.items()
}
1469 def _loss_function(self, recon_x, x, mu, log_var, causal_weights):
1470 MSE = nn.MSELoss(reduction="mean")
1471 mse_loss = MSE(recon_x, x)
1473 # see Appendix B from VAE paper:
1474 # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
1475 # https://arxiv.org/abs/1312.6114
1476 # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
1477 KLD = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
causal_weights = torch.abs(causal_weights)
sparsity = causal_weights.pow(2).mean()
return (
mse_loss
+ self.hyperparams["beta"] * KLD
# "llambda" is not a key of self.hyperparams; the sparsity coefficient is "sparse_reg"
+ self.hyperparams["sparse_reg"] * sparsity
), (mse_loss, KLD)
def _train(self):
self.model.train()
num_train = len(self.train_loader.dataset)
train_loss = 0.0
# accumulate per-batch MSE and KL in names distinct from the loop variables
mse_total = 0.0
kl_total = 0.0
pbar = tqdm(
self.train_loader, desc="training epoch...", unit="batch", leave=False
)
for batch_idx, (data, labels) in enumerate(pbar):
data = data.to(self.device)
label = torch.unique(labels).to(int).item()
self.optimizer.zero_grad()
recon_batch, mu, log_var = self.model(data, label)
causal_weights_batch = self.model.causal_layer[str(label)].weight
loss, (mse, kl) = self._loss_function(
recon_batch, data, mu, log_var, causal_weights_batch
)
loss.backward()
train_loss += loss.item()
mse_total += mse.item()
kl_total += kl.item()
self.optimizer.step()
mse_total = mse_total / num_train
self.loss_dict["recon_train"].append(mse_total)
elbo = mse_total + (kl_total / num_train)
self.loss_dict["elbo_train"].append(elbo)
return train_loss / num_train
def _validate(self):
self.model.eval()
# normalize by the validation set size, not the training set size
num_valid = len(self.valid_loader.dataset)
mse_total = 0.0
kl_total = 0.0
with torch.no_grad():
for batch_idx, (data, labels) in enumerate(self.valid_loader):
data = data.to(self.device)
label = torch.unique(labels).to(int).item()
recon_batch, mu, log_var = self.model(data, label)
causal_weights_batch = self.model.causal_layer[str(label)].weight
_, (mse, kl) = self._loss_function(
recon_batch, data, mu, log_var, causal_weights_batch
)
mse_total += mse.item()
kl_total += kl.item()
mse_total = mse_total / num_valid
self.loss_dict["recon_valid"].append(mse_total)
elbo = mse_total + (kl_total / num_valid)
self.loss_dict["elbo_valid"].append(elbo)
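# --- Illustrative usage sketch for IvnFA (not part of the library API) -----
# The expected input is a NumPy array whose last column holds the context /
# intervention label (-1 for observational data); fit() builds context-pure
# loaders with _sampler, trains the VAE defined above, and exposes the
# per-context causal weights afterwards. The synthetic data and shortened
# epoch budget below are assumptions made only for illustration.
def _example_ivnfa_fit():
    rng = np.random.default_rng(0)
    features = rng.standard_normal((300, 6))
    labels = rng.integers(-1, 5, 300)  # -1 = observational, 0..4 = contexts
    data = np.column_stack([features, labels])
    model = IvnFA()
    model.hyperparams["num_epochs"] = 2  # assumption: tiny budget for a demo
    model.hyperparams["hidden_dims"] = 16
    model.fit(data)
    return model.causal_weight_dict.keys()  # one weight matrix per context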