Coverage for /opt/conda/lib/python3.13/site-packages/medil/models.py: 89%

225 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-01 15:11 +0000

1"""MeDIL causal model base class and a preconfigured NCFA class.""" 

2 

3import copy 

4import os 

5import pickle 

6import random 

7import warnings 

8from datetime import datetime 

9from pathlib import Path 

10 

11import numpy as np 

12import numpy.typing as npt 

13import torch 

14import torch.nn.functional as F 

15from numpy.random import default_rng 

16from scipy.optimize import minimize 

17from sklearn.model_selection import train_test_split 

18from torch.utils.data import DataLoader, TensorDataset 

19from tqdm import tqdm 

20 

21from .ecc_algorithms import find_heuristic_1pc 

22from .independence_testing import estimate_UDG 

23from .vae import VariationalAutoencoder 

24 

25 

class MedilCausalModel(object):
    """Base class using principle of polymorphism to establish common
    interface for derived parametric estimators.

    Subclasses are expected to override :meth:`fit` and :meth:`sample`;
    the base implementations only raise ``NotImplementedError``.
    """

    def __init__(
        self,
        biadj: "npt.NDArray | None" = None,
        udg: "npt.NDArray | None" = None,
        one_pure_child: bool = True,
        rng=None,
    ) -> None:
        """Store the structural inputs shared by all estimators.

        Parameters
        ----------
        biadj
            Biadjacency structure; when omitted, an empty array is stored
            (``biadj.size == 0``).
        udg
            Undirected dependency graph; when omitted, an empty array is
            stored (``udg.size == 0``).
        one_pure_child
            Flag stored on the instance as given.
        rng
            Random generator stored on the instance.  When omitted, a
            fresh ``default_rng(0)`` is created for this instance.
        """
        # Defaults are built per instance: a default evaluated in the
        # signature is created once and then shared by every instance, so
        # drawing from one model's generator would silently advance the
        # generator of every other default-constructed model.
        self.biadj = np.array([]) if biadj is None else biadj
        self.udg = np.array([]) if udg is None else udg
        self.one_pure_child = one_pure_child
        self.rng = default_rng(0) if rng is None else rng

    def fit(self, dataset: npt.NDArray) -> "MedilCausalModel":
        """Fit the model to ``dataset``; must be overridden by subclasses."""
        raise NotImplementedError

    def sample(self, sample_size: int) -> npt.NDArray:
        """Draw ``sample_size`` samples; must be overridden by subclasses."""
        raise NotImplementedError

48 

49 

class Parameters(object):
    """Different parameterizations of MeDIL causal Models."""

    def __init__(self, parameterization: str) -> None:
        """Create empty placeholders for the requested parameterization.

        ``"Gaussian"`` gets ``error_means``, ``error_variances`` and
        ``biadj_weights`` (each an empty array); ``"VAE"`` gets ``weights``
        (empty array) and ``vae`` (``None``).  Any other value only records
        the name.
        """
        self.parameterization = parameterization

        if parameterization == "VAE":
            self.weights = np.array([])
            self.vae = None
            return

        if parameterization == "Gaussian":
            # Attribute creation order is kept because __str__ reports
            # attributes in insertion order.
            for attr_name in ("error_means", "error_variances", "biadj_weights"):
                setattr(self, attr_name, np.array([]))

    def __str__(self) -> str:
        """Return one ``parameters.<name>: <value>`` line per attribute."""
        rendered = [
            f"parameters.{name}: {value}" for name, value in self.__dict__.items()
        ]
        return "\n".join(rendered)

68 

69 

class GaussianMCM(MedilCausalModel):
    """A linear MeDIL causal model with Gaussian random variables."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base class and create empty Gaussian
        parameter placeholders."""
        super().__init__(**kwargs)
        self.parameters = Parameters("Gaussian")

    def fit(self, dataset: npt.NDArray) -> "GaussianMCM":
        """Fit a Gaussian MCM to a dataset with constraint-based
        structure learning and least squares parameter estimation.

        The structure (``self.biadj``) is estimated only when it is empty.
        Edge weights and error variances are then chosen to minimize the
        squared difference between the sample covariance and
        ``W.T @ W + diag(error_variances)``.  A warning is issued when the
        optimizer does not report success; the last iterate is still stored.

        Returns
        -------
        GaussianMCM
            ``self``, with ``self.parameters`` filled in.
        """
        self.dataset = dataset
        if self.biadj.size == 0:
            self._compute_biadj()

        self.parameters.error_means = self.dataset.mean(0)

        cov = np.cov(self.dataset, rowvar=False)

        # Work on a boolean mask: a 0/1 integer biadj would be treated as
        # integer fancy indices, and a float biadj would make ``sum()`` a
        # float that cannot be used as a slice bound.
        edge_mask = np.asarray(self.biadj).astype(bool)
        num_weights = int(edge_mask.sum())
        num_err_vars = edge_mask.shape[1]

        def _objective(weights_and_err_vars):
            # The optimizer works on one flat vector: edge weights first,
            # then one error variance per column of biadj.
            weights = weights_and_err_vars[:num_weights]
            err_vars = weights_and_err_vars[num_weights:]

            biadj_weights = np.zeros_like(edge_mask, float)
            biadj_weights[edge_mask] = weights

            return (
                (cov - biadj_weights.T @ biadj_weights - np.diagflat(err_vars)) ** 2
            ).sum()

        result = minimize(_objective, np.ones(num_weights + num_err_vars))
        if not result.success:
            warnings.warn(f"Optimization failed: {result.message}", stacklevel=2)

        self.parameters.error_variances = result.x[num_weights:]

        self.parameters.biadj_weights = np.zeros_like(edge_mask, float)
        self.parameters.biadj_weights[edge_mask] = result.x[:num_weights]

        return self

    def _compute_biadj(self):
        """Constraint-based structure learning.

        Estimates the UDG first when it is empty, then derives ``biadj``
        from it with ``find_heuristic_1pc``.
        """
        if self.udg.size == 0:
            self._estimate_udg()
        self.biadj = find_heuristic_1pc(self.udg)

    def _estimate_udg(self):
        """Constraint-based structure learning.

        Marks a pair of variables as adjacent when
        ``log(1 - cov * corr) < -log(n) / n`` for sample size ``n``, and
        stores the symmetric boolean result in ``self.udg``.
        """
        samp_size = len(self.dataset)
        cov = np.cov(self.dataset, rowvar=False)
        corr = np.corrcoef(self.dataset, rowvar=False)
        inner_numerator = 1 - cov * corr  # should never be <= 0?
        # Clip so the logarithm below is always defined.
        inner_numerator = inner_numerator.clip(min=0.00001)
        # Diagonal and lower triangle are set to 1 (log == 0), so they can
        # never pass the negative threshold; only the upper triangle is tested.
        inner_numerator[np.tril_indices_from(inner_numerator)] = 1
        udg_triu = np.log(inner_numerator) < (-np.log(samp_size) / samp_size)
        # Boolean addition acts as a logical OR, mirroring the upper
        # triangle into the lower one.
        udg = udg_triu + udg_triu.T
        self.udg = udg

    def sample(self, sample_size: int, include_latent: bool = False) -> npt.NDArray:
        """Sample a dataset from a GaussianMCM, after structure and
        parameters have been specified or estimated.

        Latent variables are drawn as standard normals (one per row of
        ``biadj``), errors as independent normals with the stored means and
        variances, and the result is ``latent @ biadj_weights + error``.

        Returns
        -------
        The sampled measurements, or the tuple ``(sample, latent_sample)``
        when ``include_latent`` is true.
        """
        num_latent = len(self.biadj)
        latent_sample = self.rng.multivariate_normal(
            np.zeros(num_latent), np.eye(num_latent), sample_size
        )
        error_sample = self.rng.multivariate_normal(
            self.parameters.error_means,
            np.diagflat(self.parameters.error_variances),
            sample_size,
        )
        sample = latent_sample @ self.parameters.biadj_weights + error_sample

        return (sample, latent_sample) if include_latent else sample

146 

147 

148class NeuroCausalFactorAnalysis(MedilCausalModel): 

149 """A MeDIL causal model represented by a masked variational autoencoder.""" 

150 

151 def __init__( 

152 self, 

153 seed: int = 0, 

154 log_path: str = "", 

155 verbose: bool = False, 

156 **kwargs, 

157 ): 

158 super().__init__(**kwargs) 

159 

160 if log_path: 

161 Path(log_path).mkdir(exist_ok=True) 

162 

163 self.log_path = log_path 

164 self.verbose = verbose 

165 self.seed = seed 

166 

167 self.hyperparams = { 

168 "method": "xicor", 

169 "alpha": 0.05, 

170 "batch_size": 128, 

171 "num_epochs": 200, 

172 "lr": 1e-3, 

173 "beta": 1.0, 

174 "latent_width": 2, 

175 "meas_width": 2, 

176 "num_hidden_layers": 1, 

177 "encoder_hidden_dim": 64, 

178 "shuffle": True, 

179 "early_stopping": True, 

180 "patience": 20, 

181 "min_delta": 1e-4, 

182 } 

183 

184 self.parameters = Parameters("VAE") 

185 self.loss = None 

186 self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

187 

188 def log(self, entry: str) -> None: 

189 if not (self.log_path or self.verbose): 

190 return 

191 

192 time_stamped_entry = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {entry}" 

193 

194 if self.log_path: 

195 with open(os.path.join(self.log_path, "training.log"), "a") as log_file: 

196 log_file.write(time_stamped_entry + "\n") 

197 

198 if self.verbose: 

199 print(time_stamped_entry) 

200 

201 def _set_deterministic_seed(self): 

202 os.environ["PYTHONHASHSEED"] = str(self.seed) 

203 random.seed(self.seed) 

204 np.random.seed(self.seed) 

205 

206 torch.manual_seed(self.seed) 

207 if torch.cuda.is_available(): 

208 torch.cuda.manual_seed(self.seed) 

209 torch.cuda.manual_seed_all(self.seed) 

210 

211 torch.backends.cudnn.deterministic = True 

212 torch.backends.cudnn.benchmark = False 

213 torch.use_deterministic_algorithms(True) 

214 

215 def fit(self, dataset: npt.NDArray, split_idcs=None) -> "NeuroCausalFactorAnalysis": 

216 self._set_deterministic_seed() 

217 self.dataset = dataset 

218 

219 if self.biadj.size == 0: 

220 self._compute_biadj() 

221 

222 if split_idcs is None: 

223 train_split, valid_split = train_test_split( 

224 dataset, train_size=0.7, random_state=self.seed 

225 ) 

226 else: 

227 train_split = dataset[split_idcs[0]] 

228 valid_split = dataset[split_idcs[1]] 

229 

230 train_loader = self._data_loader(train_split) 

231 valid_loader = self._data_loader(valid_split) 

232 

233 model_recon, loss_recon, error_recon = self._train_vae( 

234 train_loader, valid_loader 

235 ) 

236 

237 if self.log_path: 

238 torch.save( 

239 model_recon.state_dict(), os.path.join(self.log_path, "model_recon.pt") 

240 ) 

241 with open(os.path.join(self.log_path, "loss_recon.pkl"), "wb") as handle: 

242 pickle.dump(loss_recon, handle, protocol=pickle.HIGHEST_PROTOCOL) 

243 with open(os.path.join(self.log_path, "error_recon.pkl"), "wb") as handle: 

244 pickle.dump(error_recon, handle, protocol=pickle.HIGHEST_PROTOCOL) 

245 

246 self.parameters.vae = model_recon 

247 self.loss = { 

248 "elbo_train": loss_recon[0], 

249 "elbo_valid": loss_recon[1], 

250 "recon_train": error_recon[0], 

251 "recon_valid": error_recon[1], 

252 } 

253 

254 return self 

255 

256 def _compute_biadj(self): 

257 if self.udg.size == 0: 

258 self._estimate_udg() 

259 self.biadj = find_heuristic_1pc(self.udg) 

260 

261 def _estimate_udg(self): 

262 self.udg, _ = estimate_UDG( 

263 self.dataset, 

264 method=self.hyperparams["method"], 

265 significance_level=self.hyperparams["alpha"], 

266 ) 

267 

268 def _data_loader(self, sample): 

269 sample_x = sample.astype(np.float32) 

270 dataset = TensorDataset(torch.tensor(sample_x)) 

271 return DataLoader( 

272 dataset, 

273 batch_size=self.hyperparams["batch_size"], 

274 shuffle=self.hyperparams["shuffle"], 

275 num_workers=0, 

276 ) 

277 

278 def _train_vae(self, train_loader, valid_loader): 

279 num_meas = self.dataset.shape[1] 

280 biadj = torch.tensor(self.biadj.T, dtype=torch.float32) 

281 

282 model = VariationalAutoencoder( 

283 num_latent=biadj.shape[1], 

284 num_meas=num_meas, 

285 num_hidden_layers=self.hyperparams["num_hidden_layers"], 

286 latent_width=self.hyperparams["latent_width"], 

287 meas_width=self.hyperparams["meas_width"], 

288 biadj=biadj, 

289 encoder_hidden_dim=self.hyperparams["encoder_hidden_dim"], 

290 ).to(self.device) 

291 

292 optimizer = torch.optim.AdamW(model.parameters(), lr=self.hyperparams["lr"]) 

293 

294 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 

295 self.log(f"Number of parameters: {num_params}") 

296 

297 train_elbo, train_error = [], [] 

298 valid_elbo, valid_error = [], [] 

299 

300 best_valid = float("inf") 

301 best_state = copy.deepcopy(model.state_dict()) 

302 epochs_without_improvement = 0 

303 

304 pbar = tqdm( 

305 range(self.hyperparams["num_epochs"]), desc="Training NCFA", unit="epoch" 

306 ) 

307 

308 for epoch in pbar: 

309 model.train() 

310 train_lb, train_er, nbatch = 0.0, 0.0, 0 

311 

312 for (x_batch,) in train_loader: 

313 x_batch = x_batch.to(self.device) 

314 

315 x_recon, mu, logvar = model(x_batch) 

316 loss = self._vae_loss( 

317 x_batch, x_recon, mu, logvar, beta=self.hyperparams["beta"] 

318 ) 

319 error = self._recon_error(x_batch, x_recon) 

320 

321 optimizer.zero_grad() 

322 loss.backward() 

323 optimizer.step() 

324 

325 train_lb += loss.item() / x_batch.shape[0] 

326 train_er += error.item() / x_batch.shape[0] 

327 nbatch += 1 

328 

329 train_lb, train_er = self._eval_loss(model, train_loader) 

330 train_elbo.append(train_lb) 

331 train_error.append(train_er) 

332 

333 valid_lb, valid_er = self._eval_loss(model, valid_loader) 

334 valid_elbo.append(valid_lb) 

335 valid_error.append(valid_er) 

336 

337 pbar.set_postfix({"train": train_lb, "valid": valid_lb}) 

338 

339 improved = valid_lb < (best_valid - self.hyperparams["min_delta"]) 

340 if improved: 

341 best_valid = valid_lb 

342 best_state = copy.deepcopy(model.state_dict()) 

343 epochs_without_improvement = 0 

344 else: 

345 epochs_without_improvement += 1 

346 

347 if ( 

348 self.hyperparams["early_stopping"] 

349 and epochs_without_improvement >= self.hyperparams["patience"] 

350 ): 

351 self.log(f"Early stopping at epoch {epoch}") 

352 break 

353 

354 model.load_state_dict(best_state) 

355 

356 return ( 

357 model, 

358 [np.array(train_elbo), np.array(valid_elbo)], 

359 [np.array(train_error), np.array(valid_error)], 

360 ) 

361 

362 def _eval_loss(self, model, loader): 

363 model.eval() 

364 total_loss = 0.0 

365 total_recon = 0.0 

366 n = 0 

367 

368 with torch.no_grad(): 

369 for (x_batch,) in loader: 

370 x_batch = x_batch.to(self.device) 

371 x_recon, mu, logvar = model(x_batch) 

372 loss = self._vae_loss( 

373 x_batch, x_recon, mu, logvar, beta=self.hyperparams["beta"] 

374 ) 

375 recon = self._recon_error(x_batch, x_recon) 

376 

377 bs = x_batch.shape[0] 

378 total_loss += loss.item() 

379 total_recon += recon.item() 

380 n += bs 

381 

382 return total_loss / n, total_recon / n 

383 

384 @staticmethod 

385 def _vae_loss(x, x_recon, mu, logvar, beta=1.0): 

386 recon_loss = F.mse_loss(x_recon, x, reduction="sum") 

387 kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 

388 return recon_loss + beta * kl_div 

389 

390 @staticmethod 

391 def _recon_error(x, x_recon): 

392 return torch.linalg.norm(x - x_recon, ord=2) 

393 

394 def set_full_decoder_mask(self, num_meas=None): 

395 if num_meas is None: 

396 if not hasattr(self, "dataset"): 

397 raise ValueError("Provide num_meas or set dataset first.") 

398 num_meas = self.dataset.shape[1] 

399 

400 num_latent = num_meas 

401 self.biadj = np.ones((num_latent, num_meas), dtype=float)