Coverage for /opt/conda/lib/python3.12/site-packages/medil/models.py: 84%

740 statements  

coverage.py v7.8.0, created at 2025-04-25 05:42 +0000

1"""MeDIL causal model base class and a preconfigured NCFA class.""" 

2 

3import os 

4import pickle 

5import warnings 

6from datetime import datetime 

7from itertools import chain 

8from pathlib import Path 

9from typing import Iterator, List 

10 

11import numpy as np 

12import numpy.typing as npt 

13import torch 

14from numpy.random import default_rng 

15from scipy.linalg import norm 

16from scipy.optimize import minimize 

17from sklearn.model_selection import train_test_split 

18from torch import nn 

19from torch.nn.functional import lp_pool2d 

20from torch.nn.parameter import Parameter 

21from torch.utils.data import DataLoader, TensorDataset 

22from tqdm import tqdm 

23 

24from .ecc_algorithms import find_heuristic_1pc 

25from .independence_testing import estimate_UDG 

26from .interv_vae import VariationalAutoencoder as InterVAE 

27from .interv_vae2 import VariationalAutoencoder as InterVAE2 

28from .vae import VariationalAutoencoder 

29 

30 

31class MedilCausalModel(object): 

32 """Base class using the principle of polymorphism to establish a common

33 interface for derived parametric estimators.

34 """ 

35 

36 def __init__( 

37 self, 

38 biadj: npt.NDArray = np.array([]), 

39 udg: npt.NDArray = np.array([]), 

40 one_pure_child: bool = True, 

41 rng=default_rng(0), 

42 ) -> None: 

43 self.biadj = biadj 

44 self.udg = udg 

45 self.one_pure_child = one_pure_child 

46 self.rng = rng 

47 

48 def fit(self, dataset: npt.NDArray) -> "MedilCausalModel": 

49 raise NotImplementedError 

50 

51 def sample(self, sample_size: int) -> npt.NDArray: 

52 raise NotImplementedError 

53 

54 

55class Parameters(object): 

56 "Different parameterizations of MeDIL causal models."

57 

58 def __init__(self, parameterization: str) -> None: 

59 self.parameterization = parameterization 

60 

61 if parameterization == "Gaussian": 

62 self.error_means = np.array([]) 

63 self.error_variances = np.array([]) 

64 self.biadj_weights = np.array([]) 

65 elif parameterization == "VAE": 

66 self.weights = np.array([]) 

67 with warnings.catch_warnings(action="ignore"): 

68 self.vae = VariationalAutoencoder(0, 0, 0, 0) 

69 elif parameterization == "InterVAE": 

70 self.weights = np.array([]) 

71 with warnings.catch_warnings(action="ignore"): 

72 self.vae = InterVAE(0, 0, 0, 0, 0, 0) 

73 elif parameterization == "InterVAE2": 

74 self.weights = np.array([]) 

75 with warnings.catch_warnings(action="ignore"): 

76 self.vae = InterVAE2(0, 0, 0, 0, 0, 0) 

77 

78 def __str__(self) -> str: 

79 return "\n".join( 

80 f"parameters.{attr}: {val}" for attr, val in vars(self).items() 

81 ) 

82 

83 

84class GaussianMCM(MedilCausalModel): 

85 """A linear MeDIL causal model with Gaussian random variables.""" 

86 

87 def __init__(self, **kwargs): 

88 super().__init__(**kwargs) 

89 self.parameters = Parameters("Gaussian") 

90 

91 def fit(self, dataset: npt.NDArray) -> "GaussianMCM": 

92 """Fit a Gaussian MCM to a dataset with constraint-based 

93 structure learning and least squares parameter estimation.""" 

94 self.dataset = dataset 

95 if self.biadj.size == 0: 

96 self._compute_biadj() 

97 

98 self.parameters.error_means = self.dataset.mean(0) 

99 

100 cov = np.cov(self.dataset, rowvar=False) 

101 

102 num_weights = self.biadj.sum() 

103 num_err_vars = self.biadj.shape[1] 

104 

105 def _objective(weights_and_err_vars): 

106 weights = weights_and_err_vars[:num_weights] 

107 err_vars = weights_and_err_vars[num_weights:] 

108 

109 biadj_weights = np.zeros_like(self.biadj, float) 

110 biadj_weights[self.biadj] = weights 

111 

112 return ( 

113 (cov - biadj_weights.T @ biadj_weights - np.diagflat(err_vars)) ** 2 

114 ).sum() 

115 

116 result = minimize(_objective, np.ones(num_weights + num_err_vars)) 

117 if not result.success: 

118 warnings.warn(f"Optimization failed: {result.message}") 

119 

120 self.parameters.error_variances = result.x[num_weights:] 

121 

122 self.parameters.biadj_weights = np.zeros_like(self.biadj, float) 

123 self.parameters.biadj_weights[self.biadj] = result.x[:num_weights] 

124 

125 return self 

126 

127 def _compute_biadj(self): 

128 """Constraint-based structure learning.""" 

129 if self.udg.size == 0: 

130 self._estimate_udg() 

131 self.biadj = find_heuristic_1pc(self.udg) 

132 

133 def _estimate_udg(self): 

134 """Constraint-based structure learning.""" 

135 samp_size = len(self.dataset) 

136 cov = np.cov(self.dataset, rowvar=False) 

137 corr = np.corrcoef(self.dataset, rowvar=False) 

138 inner_numerator = 1 - cov * corr # should never be <= 0? 

139 inner_numerator = inner_numerator.clip(min=0.00001) 

140 inner_numerator[np.tril_indices_from(inner_numerator)] = 1 

141 udg_triu = np.log(inner_numerator) < (-np.log(samp_size) / samp_size) 

142 udg = udg_triu + udg_triu.T 

143 self.udg = udg 

144 

145 def sample(self, sample_size: int) -> npt.NDArray: 

146 """Sample a dataset from a GaussianMCM, after structure and 

147 parameters have been specified or estimated.""" 

148 num_latent = len(self.biadj) 

149 latent_sample = self.rng.multivariate_normal( 

150 np.zeros(num_latent), np.eye(num_latent), sample_size 

151 ) 

152 error_sample = self.rng.multivariate_normal( 

153 self.parameters.error_means, 

154 np.diagflat(self.parameters.error_variances), 

155 sample_size, 

156 ) 

157 sample = latent_sample @ self.parameters.biadj_weights + error_sample 

158 return sample 

159 
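# Illustrative usage sketch (not used elsewhere in this module; the matrix
# and sample size are assumptions): a GaussianMCM can be specified by hand
# via a boolean biadjacency matrix (rows = latent factors, columns =
# measurements), sampled from, and then refit to the synthetic sample.
def _example_gaussian_mcm():
    biadj = np.array([[True, True, False], [False, True, True]])
    mcm = GaussianMCM(biadj=biadj)
    mcm.parameters.biadj_weights = biadj.astype(float)
    mcm.parameters.error_means = np.zeros(3)
    mcm.parameters.error_variances = np.ones(3)
    synthetic = mcm.sample(1000)
    refit = GaussianMCM().fit(synthetic)  # structure + least-squares weights
    return refit.biadj, refit.parameters.biadj_weights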

160 

161class NeuroCausalFactorAnalysis(MedilCausalModel): 

162 """A MeDIL causal model represented by a deep generative model.""" 

163 

164 def __init__( 

165 self, 

166 seed: int = 0, 

167 log_path: str = "", 

168 verbose: bool = False, 

169 **kwargs, 

170 ): 

171 super().__init__(**kwargs) 

172 if log_path: 

173 Path(log_path).mkdir(exist_ok=True) 

174 self.log_path = log_path 

175 self.verbose = verbose 

176 self.seed = seed 

177 self.hyperparams = { 

178 "heuristic": True, 

179 "method": "xicor", 

180 "alpha": 0.05, 

181 "batch_size": 128, 

182 "num_epochs": 200, 

183 "lr": 0.005, 

184 "beta": 1, 

185 "num_valid": 1000, 

186 "mu": 0.01, 

187 "lambda": 0.01, 

188 "deg_of_free": 2, 

189 "width_per_meas": 2, 

190 "num_hidden_layers": 1, 

191 "prior_biadj": None, 

192 } 

193 self.parameters = Parameters("VAE") 

194 self.loss = None 

195 self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

196 

197 def log(self, entry: str) -> None: 

198 if not (self.log_path or self.verbose): 

199 return 

200 time_stamped_entry = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {entry}" 

201 if self.log_path: 

202 with open(os.path.join(self.log_path, "training.log"), "a") as log_file:

203 log_file.write(time_stamped_entry + "\n") 

204 if self.verbose: 

205 print(time_stamped_entry) 

206 

207 def fit(self, dataset: npt.NDArray, split_idcs=None) -> "NeuroCausalFactorAnalysis": 

208 self.dataset = dataset 

209 

210 # random train/val split if explicit indices not provided 

211 if split_idcs is None: 

212 train_split, valid_split = train_test_split( 

213 dataset, train_size=0.7, random_state=self.seed 

214 ) 

215 else: 

216 train_split = dataset[split_idcs[0]] 

217 valid_split = dataset[split_idcs[1]] 

218 

219 train_loader = self._data_loader(train_split) 

220 valid_loader = self._data_loader(valid_split) 

221 

222 np.random.seed(self.seed) 

223 

224 model_recon, loss_recon, error_recon = self._train_vae( 

225 train_loader, valid_loader 

226 ) 

227 if self.log_path: 

228 torch.save(model_recon, os.path.join(self.log_path, "model_recon.pt")) 

229 with open(os.path.join(self.log_path, "loss_recon.pkl"), "wb") as handle: 

230 pickle.dump(loss_recon, handle, protocol=pickle.HIGHEST_PROTOCOL) 

231 with open(os.path.join(self.log_path, "error_recon.pkl"), "wb") as handle: 

232 pickle.dump(error_recon, handle, protocol=pickle.HIGHEST_PROTOCOL) 

233 # self.parameters.weights = ( 

234 # model_recon.decoder.mean_linear_fulcon.weight.detach().numpy().T 

235 # ) 

236 self.parameters.vae = model_recon 

237 self.loss = { 

238 "elbo_train": loss_recon[0], 

239 "elbo_valid": loss_recon[1], 

240 "recon_train": error_recon[0], 

241 "recon_valid": error_recon[1], 

242 } 

243 return self 

244 

245 def _compute_biadj(self): 

246 if self.udg.size == 0: 

247 self._estimate_udg() 

248 self.biadj = find_heuristic_1pc(self.udg) 

249 

250 def _estimate_udg(self): 

251 self.udg, pvals = estimate_UDG( 

252 self.dataset, 

253 method=self.hyperparams["method"], 

254 significance_level=self.hyperparams["alpha"], 

255 ) 

256 

257 def _data_loader(self, sample): 

258 sample_x = sample.astype(np.float32) 

259 sample_z = np.empty(shape=(sample_x.shape[0], 0)).astype(np.float32) 

260 dataset = TensorDataset(torch.tensor(sample_x), torch.tensor(sample_z)) 

261 data_loader = DataLoader( 

262 dataset, batch_size=self.hyperparams["batch_size"], shuffle=False 

263 ) 

264 return data_loader 

265 

266 def _train_vae(self, train_loader, valid_loader): 

267 """Train the VAE on the loaded dataset

268 The observed dimension is num_meas (taken from self.dataset) and the

269 latent dimension is deg_of_free * num_meas (see self.hyperparams)

270 :param train_loader: training data loader

271 :param valid_loader: validation data loader

272 :return: trained model, ELBO history, and reconstruction-error

273 history, each as [train, valid] arrays of per-epoch

274 batch-averaged values

275 """

276 

277 num_meas = self.dataset.shape[1] 

278 self.num_meas = num_meas 

279 num_vae_latent = self.hyperparams["deg_of_free"] * num_meas 

280 num_hidden_layers = self.hyperparams["num_hidden_layers"] 

281 width_per_meas = self.hyperparams["width_per_meas"] 

282 prior_biadj = self.hyperparams["prior_biadj"] 

283 

284 # building VAE 

285 model = VariationalAutoencoder( 

286 num_vae_latent, num_meas, num_hidden_layers, width_per_meas, prior_biadj 

287 ) 

288 model = model.to(self.device) 

289 optimizer = torch.optim.AdamW( 

290 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5 

291 ) 

292 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90) 

293 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 

294 optimizer, patience=10, factor=0.5 

295 ) 

296 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 

297 self.log(f"Number of parameters: {num_params}") 

298 

299 # training loop 

300 model.train() 

301 train_elbo, train_error = [], [] 

302 valid_elbo, valid_error = [], [] 

303 

304 pbar = tqdm( 

305 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch" 

306 ) 

307 for idx in pbar: 

308 self.log(f"Training on epoch {idx}...") 

309 train_lb, train_er, nbatch = 0.0, 0.0, 0 

310 

311 for x_batch, _ in train_loader: 

312 batch_size = x_batch.shape[0] 

313 x_batch = x_batch.to(self.device) 

314 recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch) 

315 weight_batch = model.decoder.mean_linear_fulcon.weight 

316 loss = self._elbo_gaussian( 

317 x_batch, 

318 recon_batch, 

319 logcov_batch, 

320 mu_batch, 

321 logvar_batch, 

322 weight_batch, 

323 self.hyperparams["beta"], 

324 ) 

325 error = self._recon_error( 

326 x_batch, recon_batch, logcov_batch, weighted=False 

327 ) 

328 optimizer.zero_grad() 

329 loss.backward() 

330 optimizer.step() 

331 

332 # update loss and nbatch 

333 train_lb += loss.item() / batch_size 

334 train_er += error.item() / batch_size 

335 nbatch += 1 

336 

337 # finish training epoch 

338 # scheduler.step() 

339 train_lb = train_lb / nbatch 

340 train_er = train_er / nbatch 

341 train_elbo.append(train_lb) 

342 train_error.append(train_er) 

343 self.log(f"Finish training epoch {idx} with loss {train_lb}") 

344 

345 # append validation loss 

346 valid_lb, valid_er = self._valid_vae(model, valid_loader) 

347 valid_elbo.append(valid_lb) 

348 valid_error.append(valid_er) 

349 

350 # decrease learning rate if validation plateaus 

351 scheduler.step(valid_lb) 

352 

353 # update tqdm progress bar 

354 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb 

355 

356 train_elbo, train_error = np.array(train_elbo), np.array(train_error) 

357 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error) 

358 elbo = [train_elbo, valid_elbo] 

359 error = [train_error, valid_error] 

360 

361 return model, elbo, error 

362 

363 def _valid_vae(self, model, valid_loader): 

364 """Evaluate the VAE on the validation set

365 :param model: trained VAE model

366 :param valid_loader: validation data loader

367 :return: validation ELBO and reconstruction error

368 """

369 # set to evaluation mode 

370 model.eval() 

371 valid_lb, valid_er, nbatch = 0.0, 0.0, 0 

372 

373 for x_batch, _ in valid_loader: 

374 with torch.no_grad(): 

375 batch_size = x_batch.shape[0] 

376 x_batch = x_batch.to(self.device) 

377 recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch) 

378 loss = self._elbo_gaussian( 

379 x_batch, 

380 recon_batch, 

381 logcov_batch, 

382 mu_batch, 

383 logvar_batch, 

384 None, 

385 self.hyperparams["beta"], 

386 ) 

387 error = self._recon_error( 

388 x_batch, recon_batch, logcov_batch, weighted=False 

389 ) 

390 

391 # update loss and nbatch 

392 valid_lb += loss.item() / batch_size 

393 valid_er += error.item() / batch_size 

394 nbatch += 1 

395 

396 # report validation loss 

397 valid_lb = valid_lb / nbatch 

398 valid_er = valid_er / nbatch 

399 self.log(f"Finish validation with loss {valid_lb}") 

400 

401 return valid_lb, valid_er 

402 

403 def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, weight, beta): 

404 """Compute the negative ELBO (plus sparsity penalties) for the VAE

405 :param x: batch of observed data

406 :param x_recon: reconstruction in the output layer

407 :param logcov: log of covariance matrix of the data distribution

408 :param mu: mean in the fitted variational distribution

409 :param logvar: log of the variance in the variational distribution

410 :param beta: weight on the KL-divergence term

411 :return: negative ELBO; adds edge and latent penalties when weight is not None

412 """

413 

414 # KL-divergence 

415 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py 

416 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py 

417 # https://arxiv.org/pdf/1312.6114.pdf 

418 kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 

419 

420 # reconstruction loss 

421 cov = torch.exp(logcov) 

422 cov = self._apply_along_axis(torch.diag, cov, axis=0) 

423 cov = cov.mean(axis=0) 

424 

425 diff = x - x_recon 

426 recon_loss = torch.sum( 

427 torch.det(cov) 

428 + torch.diagonal( 

429 torch.mm( 

430 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1) 

431 ) 

432 ) 

433 ).mul(-1 / 2) 

434 

435 # elbo 

436 loss = -beta * kl_div + recon_loss 

437 if weight is not None: 

438 llambda, mu = self.hyperparams["lambda"], self.hyperparams["mu"] 

439 norm_type = 2 

440 kernel_size = ( 

441 self.hyperparams["width_per_meas"], 

442 self.hyperparams["deg_of_free"], 

443 ) 

444 weight = weight[None, None, :, :] 

445 mu_weight = lp_pool2d( 

446 weight, norm_type, kernel_size 

447 ).squeeze() # penalize num edges 

448 self.parameters.biadj = mu_weight.detach().numpy().T 

449 ll_kernel_size = ( 

450 self.hyperparams["width_per_meas"] * self.num_meas, 

451 self.hyperparams["deg_of_free"], 

452 ) 

453 ll_weight = lp_pool2d( 

454 weight, norm_type, ll_kernel_size 

455 ).squeeze() # penalize num latents 

456 return -loss + llambda * ll_weight.norm(1) + mu * mu_weight.norm(1) 

457 return -loss 

458 

459 def _recon_error(self, x, x_recon, logcov, weighted): 

460 """Reconstruction error given x and x_recon 

461 :param x: batch of observed data

462 :param x_recon: reconstruction in the output layer

463 :param logcov: log of covariance matrix of the data distribution

464 :param weighted: whether to use weighted reconstruction 

465 

466 Returns 

467 ------- 

468 error: reconstruction error 

469 """ 

470 

471 # reconstruction loss 

472 cov = torch.exp(logcov) 

473 cov = self._apply_along_axis(torch.diag, cov, axis=0) 

474 cov = cov.mean(axis=0) 

475 

476 diff = x - x_recon 

477 if weighted: 

478 error = torch.sum( 

479 torch.det(cov) 

480 + torch.diagonal( 

481 torch.mm( 

482 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1) 

483 ) 

484 ) 

485 ).mul(-1 / 2) 

486 else: 

487 error = torch.linalg.norm(diff, ord=2) 

488 

489 return error 

490 

491 @staticmethod 

492 def _apply_along_axis(function, x, axis=0): 

493 """Helper function to return along a particular axis 

494 Parameters 

495 ---------- 

496 function: function to be applied 

497 x: data 

498 axis: axis to apply the function 

499 

500 Returns 

501 ------- 

502 The output applied to the axis 

503 """ 

504 

505 return torch.stack( 

506 [function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis 

507 ) 

508 
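# Illustrative usage sketch (assumed sizes and a shortened epoch count, not
# used elsewhere in this module): NeuroCausalFactorAnalysis is fit directly
# on a samples-by-measurements array; hyperparameters are adjusted in place
# via the hyperparams dict, and the per-epoch curves end up in self.loss.
def _example_ncfa_fit():
    rng = default_rng(0)
    data = rng.standard_normal((500, 8))  # 500 samples, 8 measurements
    ncfa = NeuroCausalFactorAnalysis(seed=0)
    ncfa.hyperparams["num_epochs"] = 5  # shortened run, just for illustration
    ncfa.fit(data)
    return ncfa.loss["elbo_train"], ncfa.loss["elbo_valid"]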

509 

510# implement penalized mle and penalized lse with a new class 

511class DevMedil(MedilCausalModel): 

512 def __init__(self, **kwargs): 

513 super().__init__(**kwargs) 

514 self.init_W = None 

515 

516 # penalized MLE 

517 def fit_penalized_mle( 

518 self, 

519 dataset: npt.NDArray, 

520 lambda_reg: float = 0.1, 

521 mu_reg: float = 0.1, 

522 ) -> "DevMedil": 

523 num_meas = dataset.shape[1] 

524 if self.one_pure_child: 

525 num_latent = num_meas 

526 else: 

527 num_latent = (num_meas**2) // 4 

528 Sigma_hat = np.cov(dataset.T) 

529 

530 def penalized_mle_loss(W_and_D): 

531 W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas) 

532 D = np.diag(W_and_D[num_latent * num_meas :]) 

533 

534 Sigma = self.compute_sigma(W, D) 

535 Sigma_inv = np.linalg.inv(Sigma) 

536 sign, logdet = np.linalg.slogdet(Sigma_inv) 

537 

538 if sign <= 0: 

539 return np.inf 

540 

541 loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet 

542 loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W) 

543 

544 return loss 

545 

546 initial_W = ( 

547 self.rng.standard_normal((num_latent, num_meas)) 

548 if self.init_W is None 

549 else self.init_W 

550 ) 

551 initial_D = self.rng.random(num_meas) 

552 initial_W_and_D = np.hstack([initial_W.flatten(), initial_D]) 

553 

554 result = minimize(penalized_mle_loss, initial_W_and_D, method="BFGS") 

555 self.result = result 

556 self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas) 

557 self.D_hat = np.diag(result.x[num_latent * num_meas :]) 

558 self.convergence_success_mle = result.success 

559 self.convergence_message_mle = result.message 

560 

561 return self 

562 

563 def validation_mle(self, lambda_reg, mu_reg, data): 

564 W = self.W_hat 

565 D = self.D_hat 

566 Sigma_hat = np.cov(data, rowvar=False) 

567 Sigma = self.compute_sigma(W, D) 

568 Sigma_inv = np.linalg.inv(Sigma) 

569 sign, logdet = np.linalg.slogdet(Sigma_inv) 

570 

571 if sign <= 0: 

572 return np.inf 

573 

574 loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet 

575 loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W) 

576 return loss 

577 

578 # penalized LSE 

579 def fit_penalized_lse( 

580 self, 

581 dataset: npt.NDArray, 

582 lambda_reg: float = 0.1, 

583 mu_reg: float = 0.1, 

584 ) -> "DevMedil": 

585 num_meas = dataset.shape[1] 

586 if self.one_pure_child: 

587 num_latent = num_meas 

588 else: 

589 num_latent = (num_meas**2) // 4 

590 Sigma_hat = np.cov(dataset.T) 

591 

592 def penalized_lse_loss(W_and_D): 

593 W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas) 

594 D = np.diag(W_and_D[num_latent * num_meas :]) 

595 

596 loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2 

597 # nuclear norm for the first penalty term 

598 loss += lambda_reg * self.rho(W) 

599 # L1 norm for the second penalty function 

600 loss += mu_reg * self.sigma(W) 

601 

602 return loss 

603 

604 initial_W = self.rng.standard_normal((num_latent, num_meas)) 

605 initial_D = np.abs(self.rng.standard_normal(num_meas)) 

606 initial_params = np.concatenate([initial_W.flatten(), initial_D]) 

607 

608 result = minimize(penalized_lse_loss, initial_params, method="BFGS") 

609 self.result = result 

610 self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas) 

611 self.D_hat = np.diag(np.abs(result.x[num_latent * num_meas :])) 

612 self.convergence_success_lse = result.success 

613 self.convergence_message_lse = result.message 

614 

615 return self 

616 

617 def validation_lse(self, lambda_reg, mu_reg, data): 

618 W = self.W_hat 

619 D = self.D_hat 

620 Sigma_hat = np.cov(data, rowvar=False) 

621 loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2 

622 loss += lambda_reg * self.rho(W) 

623 loss += mu_reg * self.sigma(W) 

624 return loss 

625 

626 # model-implied covariance Σ = W.T @ W + D

627 def compute_sigma(self, W: npt.NDArray, D: npt.NDArray) -> npt.NDArray: 

628 Sigma = np.dot(W.T, W) + D 

629 return Sigma 

630 

631 # ρ(W), the nuclear norm of W (low-rank penalty)

632 def rho(self, W: npt.NDArray) -> float: 

633 return norm(W, "nuc") 

634 

635 # σ(W), the sum of absolute values of elements (L1 norm) 

636 def sigma(self, W: npt.NDArray) -> float: 

637 return np.sum(np.abs(W)) 

638 

639 def sample(self, sample_size: int, method: str = "mle") -> npt.NDArray: 

640 if method not in ["mle", "lse"]: 

641 raise ValueError("Method must be either 'mle' or 'lse'") 

642 

643 # both fit_penalized_mle and fit_penalized_lse store their estimates

644 # in W_hat and D_hat, so sampling checks for those attributes

645 if not hasattr(self, "W_hat") or not hasattr(self, "D_hat"):

646 raise ValueError(

647 f"{method.upper()} model must be fitted before sampling"

648 )

649 W_hat, D_hat = self.W_hat, self.D_hat

650 # D_hat is already a diagonal covariance matrix (see the fit methods)

651 

652 k, n = W_hat.shape 

653 L = self.rng.standard_normal((sample_size, k)) 

654 epsilon = self.rng.multivariate_normal(np.zeros(n), D_hat, sample_size) 

655 return np.dot(L, W_hat) + epsilon 

656 

657 def fit( 

658 self, dataset: npt.NDArray, method: str = "mle", lambda_reg=0.1, mu_reg=0.1 

659 ) -> "DevMedil": 

660 if method == "mle": 

661 return self.fit_penalized_mle(dataset, lambda_reg, mu_reg) 

662 elif method == "lse": 

663 return self.fit_penalized_lse(dataset, lambda_reg, mu_reg) 

664 else: 

665 raise ValueError("Method must be either 'mle' or 'lse'") 

666 
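# Illustrative usage sketch (assumed data and regularization strengths):
# DevMedil.fit dispatches to the penalized MLE or penalized LSE objective;
# both store the estimated loading matrix and noise covariance in W_hat and
# D_hat, together with scipy's convergence flags.
def _example_devmedil_lse():
    rng = default_rng(0)
    data = rng.standard_normal((300, 4))
    dm = DevMedil(rng=rng).fit(data, method="lse", lambda_reg=0.1, mu_reg=0.1)
    return dm.W_hat, dm.D_hat, dm.convergence_success_lse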

667 

668class DevMedilInterv(NeuroCausalFactorAnalysis): 

669 def __init__(self, **kwargs): 

670 super().__init__(**kwargs) 

671 self.parameters = Parameters("InterVAE") 

672 self.hyperparams.update( 

673 { 

674 "batch_size": 128, 

675 "num_epochs": 200, 

676 "lr": 0.005, 

677 "beta": 1, 

678 "num_valid": 1000, 

679 "lambda": 0.01, 

680 "meas_width": 1, 

681 "meas_depth": 0, 

682 "num_latent": 5, 

683 "latent_width": 1, 

684 "latent_depth": 0, 

685 } 

686 ) 

687 

688 def _train_vae(self, train_loader, valid_loader): 

689 """Train the interventional VAE on the loaded dataset

690 The observed dimension is the dataset width minus the final

691 interventional-label column; latent sizes come from self.hyperparams

692 :param train_loader: training data loader

693 :param valid_loader: validation data loader

694 :return: trained model, ELBO history, and reconstruction-error

695 history, each as [train, valid] arrays of per-epoch

696 batch-averaged values

697 """

698 

699 num_meas = self.dataset.shape[1] - 1 # one column for interv labels; 

700 self.num_meas = num_meas 

701 

702 # building VAE 

703 model = InterVAE( 

704 num_meas, 

705 self.hyperparams["meas_width"], 

706 self.hyperparams["meas_depth"], 

707 self.hyperparams["num_latent"], 

708 self.hyperparams["latent_width"], 

709 self.hyperparams["latent_depth"], 

710 ) 

711 model = model.to(self.device) 

712 optimizer = torch.optim.AdamW( 

713 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5 

714 ) 

715 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90) 

716 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 

717 optimizer, patience=10, factor=0.5 

718 ) 

719 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 

720 self.log(f"Number of parameters: {num_params}") 

721 

722 # training loop 

723 model.train() 

724 train_elbo, train_error = [], [] 

725 valid_elbo, valid_error = [], [] 

726 

727 pbar = tqdm( 

728 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch" 

729 ) 

730 for idx in pbar: 

731 self.log(f"Training on epoch {idx}...") 

732 train_lb, train_er, nbatch = 0.0, 0.0, 0 

733 

734 for batch, _ in train_loader: 

735 x_batch = batch[:, :-1] 

736 interv_idx_batch = batch[:, -1, None] 

737 batch_size = x_batch.shape[0] 

738 x_batch = x_batch.to(self.device) 

739 recon_batch, logcov_batch, mu_batch, logvar_batch = model( 

740 x_batch, interv_idx_batch 

741 ) 

742 causal_biadj_batch = model.decoder.mean_causal.weight 

743 loss = self._elbo_gaussian( 

744 x_batch, 

745 recon_batch, 

746 logcov_batch, 

747 mu_batch, 

748 logvar_batch, 

749 causal_biadj_batch, 

750 self.hyperparams["beta"], 

751 ) 

752 error = self._recon_error( 

753 x_batch, recon_batch, logcov_batch, weighted=False 

754 ) 

755 optimizer.zero_grad() 

756 loss.backward() 

757 optimizer.step() 

758 

759 # update loss and nbatch 

760 train_lb += loss.item() / batch_size 

761 train_er += error.item() / batch_size 

762 nbatch += 1 

763 

764 # finish training epoch 

765 # scheduler.step() 

766 train_lb = train_lb / nbatch 

767 train_er = train_er / nbatch 

768 train_elbo.append(train_lb) 

769 train_error.append(train_er) 

770 self.log(f"Finish training epoch {idx} with loss {train_lb}") 

771 

772 # append validation loss 

773 valid_lb, valid_er = self._valid_vae(model, valid_loader) 

774 valid_elbo.append(valid_lb) 

775 valid_error.append(valid_er) 

776 

777 # decrease learning rate if validation plateaus 

778 scheduler.step(valid_lb) 

779 

780 # update tqdm progress bar 

781 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb 

782 

783 train_elbo, train_error = np.array(train_elbo), np.array(train_error) 

784 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error) 

785 elbo = [train_elbo, valid_elbo] 

786 error = [train_error, valid_error] 

787 

788 return model, elbo, error 

789 

790 def _valid_vae(self, model, valid_loader): 

791 """Evaluate the VAE on the validation set

792 :param model: trained VAE model

793 :param valid_loader: validation data loader

794 :return: validation ELBO and reconstruction error

795 """

796 # set to evaluation mode 

797 model.eval() 

798 valid_lb, valid_er, nbatch = 0.0, 0.0, 0 

799 

800 for batch, _ in valid_loader: 

801 x_batch = batch[:, :-1] 

802 interv_idx_batch = batch[:, -1, None] 

803 

804 with torch.no_grad(): 

805 batch_size = x_batch.shape[0] 

806 x_batch = x_batch.to(self.device) 

807 recon_batch, logcov_batch, mu_batch, logvar_batch = model( 

808 x_batch, interv_idx_batch 

809 ) 

810 loss = self._elbo_gaussian( 

811 x_batch, 

812 recon_batch, 

813 logcov_batch, 

814 mu_batch, 

815 logvar_batch, 

816 None, 

817 self.hyperparams["beta"], 

818 ) 

819 error = self._recon_error( 

820 x_batch, recon_batch, logcov_batch, weighted=False 

821 ) 

822 

823 # update loss and nbatch 

824 valid_lb += loss.item() / batch_size 

825 valid_er += error.item() / batch_size 

826 nbatch += 1 

827 

828 # report validation loss 

829 valid_lb = valid_lb / nbatch 

830 valid_er = valid_er / nbatch 

831 self.log(f"Finish validation with loss {valid_lb}") 

832 

833 return valid_lb, valid_er 

834 

835 def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, causal_biadj, beta): 

836 """Compute the negative ELBO (plus a sparsity penalty) for the VAE

837 :param x: batch of observed data

838 :param x_recon: reconstruction in the output layer

839 :param logcov: log of covariance matrix of the data distribution

840 :param mu: mean in the fitted variational distribution

841 :param logvar: log of the variance in the variational distribution

842 :param beta: weight on the KL-divergence term

843 :return: negative ELBO; adds an edge penalty when causal_biadj is not None

844 """

845 

846 # KL-divergence 

847 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py 

848 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py 

849 # https://arxiv.org/pdf/1312.6114.pdf 

850 kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 

851 

852 # reconstruction loss 

853 cov = torch.exp(logcov) 

854 cov = self._apply_along_axis(torch.diag, cov, axis=0) 

855 cov = cov.mean(axis=0) 

856 

857 diff = x - x_recon 

858 recon_loss = torch.sum( 

859 torch.det(cov) 

860 + torch.diagonal( 

861 torch.mm( 

862 torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1) 

863 ) 

864 ) 

865 ).mul(-1 / 2) 

866 

867 # elbo 

868 loss = -beta * kl_div + recon_loss 

869 if causal_biadj is not None: 

870 llambda = self.hyperparams["lambda"] 

871 norm_type = 2 

872 kernel_size = ( 

873 self.hyperparams["latent_width"], 

874 self.hyperparams["latent_width"], 

875 ) 

876 causal_biadj = causal_biadj[None, None, :, :] 

877 causal_biadj = lp_pool2d( 

878 causal_biadj, norm_type, kernel_size 

879 ).squeeze() # penalize num edges 

880 self.parameters.causal_biadj = causal_biadj.detach().numpy().T 

881 return -loss + llambda * causal_biadj.norm(1) 

882 return -loss 

883 

884 

885class DevMedilInterv2(NeuroCausalFactorAnalysis): 

886 def __init__(self, **kwargs): 

887 super().__init__(**kwargs) 

888 self.parameters = Parameters("InterVAE2") 

889 self.hyperparams.update( 

890 { 

891 "batch_size": 128, 

892 "num_epochs": 200, 

893 "lr": 0.005, 

894 "beta": 1, 

895 "num_valid": 1000, 

896 "sparse_reg": 10, 

897 "dag_reg": 10, 

898 "meas_width": 1, 

899 "meas_depth": 0, 

900 "num_latent": 5, 

901 "latent_width": 1, 

902 "latent_depth": 0, 

903 } 

904 ) 

905 

906 def _data_loader(self, sample): 

907 sample_x = sample.astype(np.float32) 

908 sample_z = np.empty(shape=(sample_x.shape[0], 0)).astype(np.float32) 

909 dataset = TensorDataset(torch.tensor(sample_x), torch.tensor(sample_z)) 

910 data_loader = DataLoader( 

911 dataset, 

912 batch_sampler=self._sampler(dataset, self.hyperparams["batch_size"]), 

913 ) 

914 return data_loader 

915 

916 class _sampler(torch.utils.data.Sampler): 

917 def __init__(self, data: TensorDataset, batch_size: int) -> None:

918 self.data = data 

919 self.batch_size = batch_size 

920 sample_x = data[:][0] 

921 self.labels = sample_x[:, -1] 

922 self.contexts, self.inv, self.counts = torch.unique( 

923 self.labels, return_inverse=True, return_counts=True 

924 ) 

925 

926 def __len__(self) -> int: 

927 # self.chunk_sizes = [ 

928 # (count + self.batch_size - 1) // self.batch_size 

929 # for count in self.counts 

930 # ] 

931 # return sum(self.chunk_sizes) 

932 return (len(self.data) + self.batch_size - 1) // self.batch_size 

933 

934 def __iter__(self) -> Iterator[List[int]]: 

935 context_idcs = torch.multinomial( 

936 self.counts / self.counts.sum(), len(self), replacement=True 

937 ) 

938 # contexts = self.contexts[context_idcs] 

939 for context_idx in context_idcs: 

940 context_data_idcs = torch.where(self.inv == context_idx)[0] 

941 batch_idcs = context_data_idcs[torch.randperm(len(context_data_idcs))][ 

942 : self.batch_size 

943 ] 

944 yield batch_idcs 

945 

946 def _train_vae(self, train_loader, valid_loader): 

947 """Train the interventional VAE on the loaded dataset

948 The observed dimension is the dataset width minus the final

949 interventional-label column; latent sizes come from self.hyperparams

950 :param train_loader: training data loader

951 :param valid_loader: validation data loader

952 :return: trained model, ELBO history, and reconstruction-error

953 history, each as [train, valid] arrays; the pooled per-context

954 causal weights are stored in self.parameters.causal_biadj_dict

955 """

956 

957 num_meas = self.dataset.shape[1] - 1 # one column for interv labels; 

958 self.num_meas = num_meas 

959 

960 # building VAE 

961 model = InterVAE2( 

962 num_meas, 

963 self.hyperparams["meas_width"], 

964 self.hyperparams["meas_depth"], 

965 self.hyperparams["num_latent"], 

966 self.hyperparams["latent_width"], 

967 self.hyperparams["latent_depth"], 

968 ) 

969 model = model.to(self.device) 

970 optimizer = torch.optim.AdamW( 

971 model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5 

972 ) 

973 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90) 

974 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 

975 optimizer, patience=10, factor=0.5 

976 ) 

977 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 

978 self.log(f"Number of parameters: {num_params}") 

979 

980 # training loop 

981 model.train() 

982 train_elbo, train_error = [], [] 

983 valid_elbo, valid_error = [], [] 

984 

985 pbar = tqdm( 

986 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch" 

987 ) 

988 for idx in pbar: 

989 self.log(f"Training on epoch {idx}...") 

990 train_lb, train_er, nbatch = 0.0, 0.0, 0 

991 

992 for batch, _ in train_loader: 

993 x_batch = batch[:, :-1] 

994 interv_idx_batch = batch[:, -1, None] 

995 interv_idx = torch.unique(interv_idx_batch).to(int) 

996 assert len(interv_idx) == 1 

997 interv_idx = int(interv_idx) 

998 batch_size = x_batch.shape[0] 

999 x_batch = x_batch.to(self.device) 

1000 recon_batch, logcov_batch, mu_batch, logvar_batch = model( 

1001 x_batch, interv_idx_batch 

1002 ) 

1003 causal_biadj_batch = model.decoder.mean_causal[interv_idx].weight 

1004 loss = self._elbo_gaussian( 

1005 x_batch, 

1006 recon_batch, 

1007 logcov_batch, 

1008 mu_batch, 

1009 logvar_batch, 

1010 causal_biadj_batch, 

1011 self.hyperparams["beta"], 

1012 ) 

1013 error = self._recon_error( 

1014 x_batch, recon_batch, logcov_batch, weighted=False 

1015 ) 

1016 optimizer.zero_grad() 

1017 loss.backward() 

1018 optimizer.step() 

1019 

1020 # update loss and nbatch 

1021 train_lb += loss.item() / batch_size 

1022 train_er += error.item() / batch_size 

1023 nbatch += 1 

1024 

1025 # finish training epoch 

1026 # scheduler.step() 

1027 train_lb = train_lb / nbatch 

1028 train_er = train_er / nbatch 

1029 train_elbo.append(train_lb) 

1030 train_error.append(train_er) 

1031 self.log(f"Finish training epoch {idx} with loss {train_lb}") 

1032 

1033 # append validation loss 

1034 valid_lb, valid_er = self._valid_vae(model, valid_loader) 

1035 valid_elbo.append(valid_lb) 

1036 valid_error.append(valid_er) 

1037 

1038 # decrease learning rate if validation plateaus 

1039 scheduler.step(valid_lb) 

1040 

1041 # update tqdm progress bar 

1042 pbar.set_postfix({"loss": train_lb}) # , "validation loss": valid_lb 

1043 

1044 train_elbo, train_error = np.array(train_elbo), np.array(train_error) 

1045 valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error) 

1046 elbo = [train_elbo, valid_elbo] 

1047 error = [train_error, valid_error] 

1048 

1049 # pool each context's causal weights (L2 pooling over latent_width-sized

1050 # blocks, as in the _elbo_gaussian penalty) and store the transposed

1051 # results as per-context biadjacency estimates on self.parameters

1052 norm_type = 2 

1053 kernel_size = ( 

1054 self.hyperparams["latent_width"], 

1055 self.hyperparams["latent_width"], 

1056 ) 

1057 temp = { 

1058 t: lp_pool2d( 

1059 v.weight.detach()[None, None, :, :], norm_type, kernel_size 

1060 ).squeeze() 

1061 for t, v in model.decoder.mean_causal.items() 

1062 } 

1063 self.parameters.causal_biadj_dict = {k: v.numpy().T for k, v in temp.items()} 

1064 

1065 return model, elbo, error 

1066 

1067 def _valid_vae(self, model, valid_loader): 

1068 """Evaluate the VAE on the validation set

1069 :param model: trained VAE model

1070 :param valid_loader: validation data loader

1071 :return: validation ELBO and reconstruction error

1072 """

1073 # set to evaluation mode 

1074 model.eval() 

1075 valid_lb, valid_er, nbatch = 0.0, 0.0, 0 

1076 

1077 for batch, _ in valid_loader: 

1078 x_batch = batch[:, :-1] 

1079 interv_idx_batch = batch[:, -1, None] 

1080 

1081 with torch.no_grad(): 

1082 batch_size = x_batch.shape[0] 

1083 x_batch = x_batch.to(self.device) 

1084 recon_batch, logcov_batch, mu_batch, logvar_batch = model( 

1085 x_batch, interv_idx_batch 

1086 ) 

1087 loss = self._elbo_gaussian( 

1088 x_batch, 

1089 recon_batch, 

1090 logcov_batch, 

1091 mu_batch, 

1092 logvar_batch, 

1093 None, 

1094 self.hyperparams["beta"], 

1095 ) 

1096 error = self._recon_error( 

1097 x_batch, recon_batch, logcov_batch, weighted=False 

1098 ) 

1099 

1100 # update loss and nbatch 

1101 valid_lb += loss.item() / batch_size 

1102 valid_er += error.item() / batch_size 

1103 nbatch += 1 

1104 

1105 # report validation loss 

1106 valid_lb = valid_lb / nbatch 

1107 valid_er = valid_er / nbatch 

1108 self.log(f"Finish validation with loss {valid_lb}") 

1109 

1110 return valid_lb, valid_er 

1111 

1112 def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, causal_biadj, beta): 

1113 """Compute the VAE loss: beta * KL + reconstruction MSE, plus penalties

1114 :param x: batch of observed data

1115 :param x_recon: reconstruction in the output layer

1116 :param logcov: log of covariance matrix of the data distribution

1117 :param mu: mean in the fitted variational distribution

1118 :param logvar: log of the variance in the variational distribution

1119 :param beta: weight on the KL-divergence term

1120 :return: loss; adds sparsity and acyclicity penalties when causal_biadj is given

1121 """

1122 

1123 # KL-divergence 

1124 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py 

1125 # https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py 

1126 # https://arxiv.org/pdf/1312.6114.pdf 

1127 kl_div_loss = torch.mean( 

1128 -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0 

1129 ) 

1130 

1131 recon_loss = torch.nn.functional.mse_loss(x_recon, x) 

1132 

1133 # elbo loss 

1134 loss = beta * kl_div_loss + recon_loss 

1135 if causal_biadj is not None: 

1136 sparse_reg = self.hyperparams["sparse_reg"] 

1137 dag_reg = self.hyperparams["dag_reg"] 

1138 norm_type = 2 

1139 kernel_size = ( 

1140 self.hyperparams["latent_width"], 

1141 self.hyperparams["latent_width"], 

1142 ) 

1143 pooled = causal_biadj[None, None, :, :] 

1144 pooled = lp_pool2d( 

1145 pooled, norm_type, kernel_size 

1146 ).squeeze() # penalize num edges 

1147 density = pooled.mean() # L1 norm / num_entries 

1148 

1149 # https://dagma.readthedocs.io/en/latest/#the-log-det-acyclicity-characterization 

1150 s = torch.tensor([5]) 

1151 d = len(pooled) 

1152 nondagness = -torch.logdet(

1153 s * torch.eye(d) - torch.square(pooled - torch.diag(torch.diag(pooled)))

1154 ) + d * torch.log(s)  # DAGMA: h(W) = -logdet(sI - W∘W) + d·log(s), ignoring self-loops

1155 

1156 return loss + sparse_reg * density + dag_reg * nondagness 

1157 return loss 

1158 
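# Sketch of the log-det acyclicity term referenced above (see the DAGMA link
# in _elbo_gaussian): h(W) = -logdet(s*I - W∘W) + d*log(s) evaluates to zero
# for an acyclic weighted adjacency matrix and is positive in the presence of
# cycles (for suitable s). The two matrices below are illustrative.
def _example_dagma_acyclicity(s: float = 5.0):
    def h(W):
        d = W.shape[0]
        # W * W is the elementwise (Hadamard) square used by DAGMA
        return -torch.logdet(s * torch.eye(d) - W * W) + d * torch.log(
            torch.tensor(s)
        )

    acyclic = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # single edge 0 -> 1
    cyclic = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # two-cycle 0 <-> 1
    return h(acyclic).item(), h(cyclic).item()  # approximately 0.0, and > 0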

1159 

1160class BlockLinear(nn.Module): 

1161 def __init__( 

1162 self, 

1163 context_dims, 

1164 width, 

1165 bias=True, 

1166 device=None, 

1167 dtype=None, 

1168 ): 

1169 factory_kwargs = {"device": device, "dtype": dtype} 

1170 super().__init__() 

1171 self.block_mask = torch.eye(context_dims).kron(torch.ones(width, width)) 

1172 self.in_features = self.out_features = num_features = context_dims * width  # reported by extra_repr

1173 self.weight = Parameter( 

1174 torch.empty((num_features, num_features), **factory_kwargs) 

1175 ) 

1176 if bias: 

1177 self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) 

1178 else: 

1179 self.register_parameter("bias", None) 

1180 self.reset_parameters() 

1181 

1182 def reset_parameters(self): 

1183 # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 

1184 nn.init.orthogonal_(self.weight) 

1185 # nn.init.sparse_(self.weight, 2 / 3) 

1186 if self.bias is not None: 

1187 fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 

1188 bound = 1 / torch.sqrt(torch.tensor(fan_in)) if fan_in > 0 else 0 

1189 nn.init.uniform_(self.bias, -bound, bound) 

1190 

1191 def forward(self, input): 

1192 # masked linear layer 

1193 return nn.functional.linear(input, self.weight * self.block_mask, self.bias) 

1194 

1195 def extra_repr(self): 

1196 return "in_features={}, out_features={}, bias={}".format( 

1197 self.in_features, self.out_features, self.bias is not None 

1198 ) 

1199 
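# Small sketch of the mask structure (assumed sizes): BlockLinear constrains
# a square linear layer so that only context_dims diagonal blocks of size
# width x width can be non-zero, i.e. each context's features mix only with
# themselves.
def _example_block_mask():
    # what BlockLinear(context_dims=2, width=3) builds as its block_mask:
    # a 6x6 matrix with two 3x3 all-ones blocks on the diagonal
    return torch.eye(2).kron(torch.ones(3, 3))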

1200 

1201class Intervenable(nn.Module): 

1202 def __init__( 

1203 self, 

1204 in_features, 

1205 out_features, 

1206 mask=None, 

1207 bias=True, 

1208 device=None, 

1209 dtype=None, 

1210 ): 

1211 self.factory_kwargs = {"device": device, "dtype": dtype} 

1212 super().__init__() 

1213 self.in_features = in_features 

1214 self.out_features = out_features 

1215 self.mask = mask 

1216 self.weight = Parameter( 

1217 torch.empty((out_features, in_features), **self.factory_kwargs) 

1218 ) 

1219 

1220 if bias: 

1221 self.bias = Parameter(torch.empty(out_features, **self.factory_kwargs)) 

1222 else: 

1223 self.register_parameter("bias", None) 

1224 self.reset_parameters() 

1225 

1226 def reset_parameters(self): 

1227 nn.init.orthogonal_(self.weight) 

1228 if self.bias is not None: 

1229 fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 

1230 bound = 1 / torch.sqrt(torch.tensor(fan_in)) if fan_in > 0 else 0 

1231 nn.init.uniform_(self.bias, -bound, bound) 

1232 

1233 def forward(self, input, obs_weight, interv_idx): 

1234 if interv_idx != -1: 

1235 min_weight = torch.minimum(self.weight, obs_weight) 

1236 interv_mask = torch.ones_like(min_weight, **self.factory_kwargs) 

1237 interv_mask[interv_idx] = 0 

1238 interv_mask[interv_idx, interv_idx] = 1 

1239 self.weight.data = min_weight * interv_mask 

1240 if self.mask is None: 

1241 return nn.functional.linear(input, self.weight, self.bias) 

1242 else: 

1243 return nn.functional.linear(input, self.weight * self.mask, self.bias) 

1244 

1245 def extra_repr(self): 

1246 return "in_features={}, out_features={}, bias={}".format( 

1247 self.in_features, self.out_features, self.bias is not None 

1248 ) 

1249 

1250 

1251class VAE(nn.Module): 

1252 def __init__(self, input_dims, context_dims, width, depth, hidden_dims): 

1253 super().__init__() 

1254 latent_dims = context_dims * width 

1255 self.input_dims = input_dims 

1256 self.width = width 

1257 

1258 # Encoder 

1259 self.encoder = nn.Sequential( 

1260 nn.Linear(input_dims, hidden_dims), 

1261 nn.BatchNorm1d(hidden_dims), 

1262 nn.GELU(), 

1263 nn.Linear(hidden_dims, hidden_dims), 

1264 nn.BatchNorm1d(hidden_dims), 

1265 nn.GELU(), 

1266 ) 

1267 self.fc_mu = nn.Linear(hidden_dims, latent_dims) 

1268 self.fc_var = nn.Linear(hidden_dims, latent_dims) 

1269 

1270 # Our module 

1271 unchained = BlockLinear(context_dims, width), nn.GELU() 

1272 deeply_expressive = chain(*(unchained for _ in range(depth))) 

1273 self.expressive_layer = nn.Sequential(*deeply_expressive, nn.AvgPool1d(width)) 

1274 self.causal_layer = nn.ModuleDict( 

1275 { 

1276 str(interv_idx): Intervenable( 

1277 in_features=context_dims, 

1278 out_features=input_dims, 

1279 ) 

1280 for interv_idx in range(-1, context_dims) 

1281 } 

1282 ) 

1283 

1284 # Decoder 

1285 self.decoder = nn.Sequential( 

1286 nn.Linear(input_dims, hidden_dims), 

1287 nn.BatchNorm1d(hidden_dims), 

1288 nn.GELU(), 

1289 nn.Linear(hidden_dims, input_dims), 

1290 ) 

1291 

1292 def encode(self, x): 

1293 h = self.encoder(x) 

1294 return self.fc_mu(h), self.fc_var(h) 

1295 

1296 def reparameterize(self, mu, log_var): 

1297 std = torch.exp(0.5 * log_var) 

1298 eps = torch.randn_like(std) 

1299 return mu + eps * std 

1300 

1301 def decode(self, z): 

1302 epsilon = self.expressive_layer(z) 

1303 obs_weight = self.causal_layer[str(-1)].weight 

1304 l = self.causal_layer[str(self.batch_label)]( 

1305 epsilon, obs_weight, self.batch_label 

1306 ) 

1307 return self.decoder(l) 

1308 

1309 def forward(self, x, label): 

1310 self.batch_label = label 

1311 mu, log_var = self.encode(x) 

1312 z = self.reparameterize(mu, log_var) 

1313 return self.decode(z), mu, log_var 

1314 

1315 

1316def _numpy_to_pytorch_dataset(np_dataset): 

1317 """ 

1318 Converts a NumPy dataset into a PyTorch dataset. 

1319 

1320 Parameters: 

1321 - np_dataset: NumPy array where rows are samples and columns are features. 

1322 

1323 Returns: 

1324 - A PyTorch TensorDataset containing samples and labels. 

1325 """ 

1326 

1327 # Ensure the input is a NumPy array 

1328 if not isinstance(np_dataset, np.ndarray): 

1329 raise ValueError("Input must be a NumPy array.") 

1330 

1331 # Split the dataset into features (X) and labels (y) 

1332 X = np_dataset[:, :-1] # Features are all columns except the last 

1333 y = np_dataset[:, -1] # Labels are the last column 

1334 

1335 # Convert NumPy arrays to PyTorch tensors 

1336 X_tensor = torch.from_numpy(X).float() 

1337 y_tensor = ( 

1338 torch.from_numpy(y).float().unsqueeze(-1) 

1339 ) # Optional: unsqueeze for consistency 

1340 

1341 # Create a PyTorch TensorDataset 

1342 dataset = TensorDataset(X_tensor, y_tensor) 

1343 

1344 return dataset 

1345 

1346 

1347class _sampler(torch.utils.data.Sampler): 

1348 def __init__(self, dataset, batch_size: int) -> None: 

1349 self.data = dataset 

1350 self.batch_size = batch_size 

1351 self.labels = dataset.tensors[1] 

1352 self.contexts, self.inv, self.counts = torch.unique( 

1353 self.labels, return_inverse=True, return_counts=True 

1354 ) 

1355 

1356 def __len__(self) -> int: 

1357 return (len(self.data) + self.batch_size - 1) // self.batch_size 

1358 

1359 def __iter__(self): 

1360 context_idcs = torch.multinomial( 

1361 self.counts / self.counts.sum(), len(self), replacement=True 

1362 ) 

1363 # contexts = self.contexts[context_idcs] 

1364 for context_idx in context_idcs: 

1365 context_data_idcs = torch.where(self.inv == context_idx)[0] 

1366 batch_idcs = context_data_idcs[torch.randperm(len(context_data_idcs))][ 

1367 : self.batch_size 

1368 ] 

1369 yield batch_idcs 

1370 
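# Illustrative sketch (toy data with made-up context labels): the
# module-level _sampler yields index batches drawn from one interventional
# context at a time, so a DataLoader built with it produces batches whose
# label column is constant within each batch.
def _example_context_batches():
    rng = default_rng(0)
    features = rng.standard_normal((200, 4))
    labels = rng.integers(0, 3, size=(200, 1)).astype(float)
    dataset = _numpy_to_pytorch_dataset(np.hstack([features, labels]))
    loader = DataLoader(dataset, batch_sampler=_sampler(dataset, batch_size=32))
    return [torch.unique(y).tolist() for _, y in loader]  # one label per batch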

1371 

1372class IvnFA(object): 

1373 def __init__(self): 

1374 self.hyperparams = { 

1375 "batch_size": 128, 

1376 "num_epochs": 100, 

1377 "lr": 0.005, 

1378 "beta": 1, 

1379 "num_valid": 1000, 

1380 "sparse_reg": 10, 

1381 "width": 1, 

1382 "depth": 0, 

1383 "context_dims": 5, 

1384 "hidden_dims": 50, 

1385 } 

1386 self.checkpoint_save_path = None 

1387 self.final_save_path = None 

1388 self.loss_dict = { 

1389 "elbo_train": [], 

1390 "recon_train": [], 

1391 "elbo_valid": [], 

1392 "recon_valid": [], 

1393 } 

1394 

1395 def fit(self, dataset, split_idcs=None): 

1396 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

1397 

1398 input_dims = dataset.shape[1] - 1 

1399 

1400 # random train/val split if explicit indices not provided 

1401 if split_idcs is None: 

1402 train_split, valid_split = train_test_split(dataset, train_size=0.7) 

1403 else: 

1404 train_split = dataset[split_idcs[0]] 

1405 valid_split = dataset[split_idcs[1]] 

1406 

1407 # train 

1408 dataset = _numpy_to_pytorch_dataset(train_split) 

1409 sampler = _sampler(dataset, self.hyperparams["batch_size"]) 

1410 self.train_loader = DataLoader(dataset, batch_sampler=sampler) 

1411 

1412 # valid 

1413 dataset = _numpy_to_pytorch_dataset(valid_split) 

1414 sampler = _sampler(dataset, self.hyperparams["batch_size"]) 

1415 self.valid_loader = DataLoader(dataset, batch_sampler=sampler) 

1416 

1417 context_dims, width, depth, hidden_dims = ( 

1418 self.hyperparams["context_dims"], 

1419 self.hyperparams["width"], 

1420 self.hyperparams["depth"], 

1421 self.hyperparams["hidden_dims"], 

1422 ) 

1423 self.model = VAE(input_dims, context_dims, width, depth, hidden_dims).to( 

1424 self.device 

1425 ) 

1426 self.optimizer = torch.optim.Adam( 

1427 self.model.parameters(), lr=self.hyperparams["lr"] 

1428 ) 

1429 

1430 losses = [] 

1431 pbar = tqdm( 

1432 range(self.hyperparams["num_epochs"]), desc="Training...", unit="epoch" 

1433 ) 

1434 

1435 for epoch in pbar: 

1436 loss = self._train() 

1437 self._validate() 

1438 losses.append(loss) 

1439 pbar.set_postfix({"loss": f"{loss:.4f}"}) 

1440 

1441 # Save checkpoint after each epoch 

1442 if self.checkpoint_save_path is not None: 

1443 torch.save( 

1444 { 

1445 "epoch": epoch, 

1446 "model_state_dict": self.model.state_dict(), 

1447 "optimizer_state_dict": self.optimizer.state_dict(), 

1448 "loss": loss, 

1449 "losses": losses, 

1450 }, 

1451 f"{self.checkpoint_save_path}.pt", 

1452 ) 

1453 

1454 # Save model and training losses 

1455 if self.final_save_path is not None: 

1456 torch.save( 

1457 { 

1458 "model_state_dict": self.model.state_dict(), 

1459 "optimizer_state_dict": self.optimizer.state_dict(), 

1460 "losses": losses, 

1461 }, 

1462 f"{self.final_save_path}.pt", 

1463 ) 

1464 self.losses = losses 

1465 self.causal_weight_dict = { 

1466 k: v.weight.detach().numpy() for k, v in self.model.causal_layer.items() 

1467 } 

1468 

1469 def _loss_function(self, recon_x, x, mu, log_var, causal_weights): 

1470 MSE = nn.MSELoss(reduction="mean") 

1471 mse_loss = MSE(recon_x, x) 

1472 

1473 # see Appendix B from VAE paper: 

1474 # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 

1475 # https://arxiv.org/abs/1312.6114 

1476 # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 

1477 KLD = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp()) 

1478 

1479 causal_weights = torch.abs(causal_weights) 

1480 sparse_penalty = causal_weights.pow(2).mean()  # sparsity penalty on causal weights

1481

1482 return (

1483 mse_loss

1484 + self.hyperparams["beta"] * KLD

1485 + self.hyperparams["sparse_reg"] * sparse_penalty

1486 ), (mse_loss, KLD)

1487 

1488 def _train(self): 

1489 self.model.train() 

1490 num_train = len(self.train_loader.dataset)

1491 train_loss = 0.0

1492 mse_total = 0.0

1493 kl_total = 0.0

1494 pbar = tqdm( 

1495 self.train_loader, desc="training epoch...", unit="batch", leave=False 

1496 ) 

1497 for batch_idx, (data, labels) in enumerate(pbar): 

1498 data = data.to(self.device) 

1499 label = torch.unique(labels).to(int).item() 

1500 self.optimizer.zero_grad() 

1501 recon_batch, mu, log_var = self.model(data, label) 

1502 causal_weights_batch = self.model.causal_layer[str(label)].weight 

1503 loss, (mse, kl) = self._loss_function( 

1504 recon_batch, data, mu, log_var, causal_weights_batch 

1505 ) 

1506 loss.backward() 

1507 train_loss += loss.item() 

1508 mse_total += mse.item()

1509 kl_total += kl.item()

1510 self.optimizer.step() 

1511 mse_total = mse_total / num_train

1512 self.loss_dict["recon_train"].append(mse_total)

1513 elbo = mse_total + (kl_total / num_train)

1514 self.loss_dict["elbo_train"].append(elbo)

1515 return train_loss / num_train

1516 

1517 def _validate(self): 

1518 self.model.eval() 

1519 num_valid = len(self.valid_loader.dataset)

1520 mse_total = 0.0

1521 kl_total = 0.0

1522 for batch_idx, (data, labels) in enumerate(self.valid_loader): 

1523 data = data.to(self.device) 

1524 label = torch.unique(labels).to(int).item() 

1525 recon_batch, mu, log_var = self.model(data, label) 

1526 causal_weights_batch = self.model.causal_layer[str(label)].weight 

1527 _, (mse, kl) = self._loss_function( 

1528 recon_batch, data, mu, log_var, causal_weights_batch 

1529 ) 

1530 mse_total += mse.item()

1531 kl_total += kl.item()

1532 mse_total = mse_total / num_valid

1533 self.loss_dict["recon_valid"].append(mse_total)

1534 elbo = mse_total + (kl_total / num_valid)

1535 self.loss_dict["elbo_valid"].append(elbo)
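# Illustrative end-to-end sketch (synthetic data, shortened training, and an
# assumed three-context label column): IvnFA expects the last column of the
# dataset to hold the interventional-context label and exposes the learned
# per-context causal weights in causal_weight_dict after fitting.
def _example_ivnfa_fit():
    rng = default_rng(0)
    features = rng.standard_normal((400, 6))
    labels = rng.integers(0, 3, size=(400, 1)).astype(float)
    fa = IvnFA()
    fa.hyperparams["num_epochs"] = 2  # shortened run, just for illustration
    fa.fit(np.hstack([features, labels]))
    return sorted(fa.causal_weight_dict.keys())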