"""MeDIL causal model base class and a preconfigured NCFA class."""
from datetime import datetime
import os
from pathlib import Path
import pickle
import warnings
import numpy as np
from numpy.random import default_rng
import numpy.typing as npt
from scipy.linalg import norm
from scipy.optimize import minimize
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler as sc
import torch
from torch.utils.data import DataLoader, TensorDataset
from .ecc_algorithms import find_heuristic_1pc
from .independence_testing import estimate_UDG
from .vae import VariationalAutoencoder
[docs]
class MedilCausalModel(object):
"""Base class using principle of polymorphism to establish common
interface for derived parametric estimators.
"""
[docs]
def __init__(
self,
biadj: npt.NDArray = np.array([]),
udg: npt.NDArray = np.array([]),
one_pure_child: bool = True,
rng=default_rng(0),
) -> None:
self.biadj = biadj
self.udg = udg
self.one_pure_child = one_pure_child
self.rng = rng
[docs]
def fit(self, dataset: npt.NDArray) -> "MedilCausalModel":
raise NotImplementedError
[docs]
def sample(self, sample_size: int) -> npt.NDArray:
raise NotImplementedError
[docs]
class Parameters(object):
"Different parameterizations of MeDIL causal Models."
[docs]
def __init__(self, parameterization: str) -> None:
self.parameterization = parameterization
if parameterization == "Gaussian":
self.error_means = np.array([])
self.error_variances = np.array([])
self.biadj_weights = np.array([])
elif parameterization == "VAE":
raise NotImplementedError
def __str__(self) -> str:
return "\n".join(
f"parameters.{attr}: {val}" for attr, val in vars(self).items()
)
[docs]
class GaussianMCM(MedilCausalModel):
"""A linear MeDIL causal model with Gaussian random variables."""
[docs]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.parameters = Parameters("Gaussian")
[docs]
def fit(self, dataset: npt.NDArray) -> "GaussianMCM":
"""Fit a Gaussian MCM to a dataset with constraint-based
structure learning and least squares parameter estimation."""
self.dataset = dataset
if self.biadj.size == 0:
self._compute_biadj()
self.parameters.error_means = self.dataset.mean(0)
cov = np.cov(self.dataset, rowvar=False)
num_weights = self.biadj.sum()
num_err_vars = self.biadj.shape[1]
def _objective(weights_and_err_vars):
weights = weights_and_err_vars[:num_weights]
err_vars = weights_and_err_vars[num_weights:]
biadj_weights = np.zeros_like(self.biadj, float)
biadj_weights[self.biadj] = weights
return (
(cov - biadj_weights.T @ biadj_weights - np.diagflat(err_vars)) ** 2
).sum()
result = minimize(_objective, np.ones(num_weights + num_err_vars))
if not result.success:
warnings.warn(f"Optimization failed: {result.message}")
self.parameters.error_variances = result.x[num_weights:]
self.parameters.biadj_weights = np.zeros_like(self.biadj, float)
self.parameters.biadj_weights[self.biadj] = result.x[:num_weights]
return self
def _compute_biadj(self):
"""Constraint-based structure learning."""
if self.udg.size == 0:
self._estimate_udg()
self.biadj = find_heuristic_1pc(self.udg)
def _estimate_udg(self):
"""Constraint-based structure learning."""
samp_size = len(self.dataset)
cov = np.cov(self.dataset, rowvar=False)
corr = np.corrcoef(self.dataset, rowvar=False)
inner_numerator = 1 - cov * corr # should never be <= 0?
inner_numerator = inner_numerator.clip(min=0.00001)
inner_numerator[np.tril_indices_from(inner_numerator)] = 1
udg_triu = np.log(inner_numerator) < (-np.log(samp_size) / samp_size)
udg = udg_triu + udg_triu.T
self.udg = udg
[docs]
def sample(self, sample_size: int) -> npt.NDArray:
"""Sample a dataset from a GaussianMCM, after structure and
parameters have been specified or estimated."""
num_latent = len(self.biadj)
latent_sample = self.rng.multivariate_normal(
np.zeros(num_latent), np.eye(num_latent), sample_size
)
error_sample = self.rng.multivariate_normal(
self.parameters.error_means,
np.diagflat(self.parameters.error_variances),
sample_size,
)
sample = latent_sample @ self.parameters.biadj_weights + error_sample
return sample
[docs]
class NeuroCausalFactorAnalysis(MedilCausalModel):
"""A MeDIL causal model represented by a deep generative model."""
[docs]
def __init__(
self,
seed: int = 0,
dof: int = 0,
path: str = "trained_ncfa/",
verbose: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Path(path).mkdir(exist_ok=True)
self.path = path
self.verbose = verbose
self.seed = seed
self.hyperparams = {
"heuristic": True,
"method": "xicor",
"alpha": 0.05,
"dof": dof,
"batch_size": 128,
"num_epochs": 200,
"lr": 0.005,
"beta": 1,
"num_valid": 1000,
}
self.parameters = Parameters("vae")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
[docs]
def log(self, entry: str) -> None:
time_stamped_entry = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {entry}"
with open(f"{self.path}training.log", "a") as log_file:
log_file.write(time_stamped_entry + "\n")
if self.verbose:
print(time_stamped_entry)
[docs]
def fit(self, dataset: npt.NDArray) -> "NeuroCausalFactorAnalysis":
self.dataset = dataset
self.doffed = self.assign_dof()
standardized = sc().fit_transform(dataset)
train_split, valid_split = train_test_split(
standardized, train_size=0.7, random_state=self.seed
)
train_loader = self._data_loader(train_split)
valid_loader = self._data_loader(valid_split)
np.random.seed(self.seed)
model_recon, loss_recon, error_recon = self._train_vae(
train_loader, valid_loader
)
torch.save(model_recon, os.path.join(self.path, "model_recon.pt"))
with open(os.path.join(self.path, "loss_recon.pkl"), "wb") as handle:
pickle.dump(loss_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(os.path.join(self.path, "error_recon.pkl"), "wb") as handle:
pickle.dump(error_recon, handle, protocol=pickle.HIGHEST_PROTOCOL)
return self
[docs]
def assign_dof(self) -> npt.NDArray:
"""Assign degrees of freedom (latent variables) of VAE to
latent factors from causal structure learning.
"""
if self.biadj.size == 0:
self._compute_biadj()
num_cliques, num_meas = self.biadj.shape
if self.hyperparams["dof"] == 0:
# then default to 3x num_meas overcomplete
self.dof = num_meas * 3
elif self.dof < num_cliques:
warnings.warn(
f"Input `deg_of_freedom={self.dof}` is less than the {num_cliques} required for the estimated causal structure. `deg_of_freedom` increased to {num_cliques} to compensate."
)
self.dof = num_cliques
latents_per_clique = np.ones(num_cliques, int) * (self.dof // num_cliques)
for _ in range(2):
remainder = self.dof - latents_per_clique.sum()
latents_per_clique[np.argsort(latents_per_clique)[0:remainder]] += 1
redundant_biadj_mat = np.repeat(self.biadj, latents_per_clique, axis=0)
return redundant_biadj_mat
def _compute_biadj(self):
if self.udg.size == 0:
self._estimate_udg()
self.biadj = find_heuristic_1pc(self.udg)
def _estimate_udg(self):
self.udg, pvals = estimate_UDG(
self.dataset,
method=self.hyperparams["method"],
significance_level=self.hyperparams["alpha"],
)
def _data_loader(self, sample):
sample_x = sample.astype(np.float32)
sample_z = np.empty(shape=(sample_x.shape[0], 0)).astype(np.float32)
dataset = TensorDataset(torch.tensor(sample_x), torch.tensor(sample_z))
data_loader = DataLoader(
dataset, batch_size=self.hyperparams["batch_size"], shuffle=False
)
return data_loader
def _train_vae(self, train_loader, valid_loader):
"""Training VAE with the specified image dataset
:param m: dimension of the latent variable
:param n: dimension of the observed variable
:param train_loader: training image dataset loader
:param valid_loader: validation image dataset loader
:param biadj_mat: the adjacency matrix of the directed graph
:param seed: random seed for the experiments
:return: trained model and training loss history
"""
m, n = self.doffed.shape
# building VAE
mask = self.doffed.T.astype("float32")
mask = torch.tensor(mask).to(self.device)
model = VariationalAutoencoder(m, n, mask)
model = model.to(self.device)
optimizer = torch.optim.AdamW(
model.parameters(), lr=self.hyperparams["lr"], weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.90)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.log(f"Number of parameters: {num_params}")
# training loop
model.train()
train_elbo, train_error = [], []
valid_elbo, valid_error = [], []
for idx in range(self.hyperparams["num_epochs"]):
self.log(f"Training on epoch {idx}...")
train_lb, train_er, nbatch = 0.0, 0.0, 0
for x_batch, _ in train_loader:
batch_size = x_batch.shape[0]
x_batch = x_batch.to(self.device)
recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch)
loss = self._elbo_gaussian(
x_batch,
recon_batch,
logcov_batch,
mu_batch,
logvar_batch,
self.hyperparams["beta"],
)
error = self._recon_error(
x_batch, recon_batch, logcov_batch, weighted=False
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# update loss and nbatch
train_lb += loss.item() / batch_size
train_er += error.item() / batch_size
nbatch += 1
# finish training epoch
scheduler.step()
train_lb = train_lb / nbatch
train_er = train_er / nbatch
train_elbo.append(train_lb)
train_error.append(train_er)
self.log(f"Finish training epoch {idx} with loss {train_lb}")
# append validation loss
valid_lb, valid_er = self._valid_vae(model, valid_loader)
valid_elbo.append(valid_lb)
valid_error.append(valid_er)
train_elbo, train_error = np.array(train_elbo), np.array(train_error)
valid_elbo, valid_error = np.array(valid_elbo), np.array(valid_error)
elbo = [train_elbo, valid_elbo]
error = [train_error, valid_error]
return model, elbo, error
def _valid_vae(self, model, valid_loader):
"""Training VAE with the specified image dataset
:param model: trained VAE model
:param valid_loader: validation image dataset loader
:return: validation loss
"""
# set to evaluation mode
model.eval()
valid_lb, valid_er, nbatch = 0.0, 0.0, 0
for x_batch, _ in valid_loader:
with torch.no_grad():
batch_size = x_batch.shape[0]
x_batch = x_batch.to(self.device)
recon_batch, logcov_batch, mu_batch, logvar_batch = model(x_batch)
loss = self._elbo_gaussian(
x_batch,
recon_batch,
logcov_batch,
mu_batch,
logvar_batch,
self.hyperparams["beta"],
)
error = self._recon_error(
x_batch, recon_batch, logcov_batch, weighted=False
)
# update loss and nbatch
valid_lb += loss.item() / batch_size
valid_er += error.item() / batch_size
nbatch += 1
# report validation loss
valid_lb = valid_lb / nbatch
valid_er = valid_er / nbatch
self.log(f"Finish validation with loss {valid_lb}")
return valid_lb, valid_er
def _elbo_gaussian(self, x, x_recon, logcov, mu, logvar, beta):
"""Calculating loss for variational autoencoder
:param x: original image
:param x_recon: reconstruction in the output layer
:param logcov: log of covariance matrix of the data distribution
:param mu: mean in the fitted variational distribution
:param logvar: log of the variance in the variational distribution
:param beta: beta
:return: reconstruction loss + KL
"""
# KL-divergence
# https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
# https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
# https://arxiv.org/pdf/1312.6114.pdf
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# reconstruction loss
cov = torch.exp(logcov)
cov = self._apply_along_axis(torch.diag, cov, axis=0)
cov = cov.mean(axis=0)
diff = x - x_recon
recon_loss = torch.sum(
torch.det(cov)
+ torch.diagonal(
torch.mm(
torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1)
)
)
).mul(-1 / 2)
# elbo
loss = -beta * kl_div + recon_loss
return -loss
def _recon_error(self, x, x_recon, logcov, weighted):
"""Reconstruction error given x and x_recon
:param x: original image
:param x_recon: reconstruction in the output layer
:param logcov: covariance matrix of the data distribution
:param weighted: whether to use weighted reconstruction
Returns
-------
error: reconstruction error
"""
# reconstruction loss
cov = torch.exp(logcov)
cov = self._apply_along_axis(torch.diag, cov, axis=0)
cov = cov.mean(axis=0)
diff = x - x_recon
if weighted:
error = torch.sum(
torch.det(cov)
+ torch.diagonal(
torch.mm(
torch.mm(diff, torch.inverse(cov)), torch.transpose(diff, 0, 1)
)
)
).mul(-1 / 2)
else:
error = torch.linalg.norm(diff, ord=2)
return error
@staticmethod
def _apply_along_axis(function, x, axis=0):
"""Helper function to return along a particular axis
Parameters
----------
function: function to be applied
x: data
axis: axis to apply the function
Returns
-------
The output applied to the axis
"""
return torch.stack(
[function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis
)
# implement penalized mle and penalized lse with a new class
[docs]
class DevMedil(MedilCausalModel):
[docs]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.init_W = None
# penalized MLE
[docs]
def fit_penalized_mle(
self,
dataset: npt.NDArray,
lambda_reg: float = 0.1,
mu_reg: float = 0.1,
) -> "DevMedil":
num_meas = dataset.shape[1]
if self.one_pure_child:
num_latent = num_meas
else:
num_latent = (num_meas**2) // 4
Sigma_hat = np.cov(dataset.T)
def penalized_mle_loss(W_and_D):
W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas)
D = np.diag(W_and_D[num_latent * num_meas :])
Sigma = self.compute_sigma(W, D)
Sigma_inv = np.linalg.inv(Sigma)
sign, logdet = np.linalg.slogdet(Sigma_inv)
if sign <= 0:
return np.inf
loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet
loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W)
return loss
initial_W = (
self.rng.standard_normal((num_latent, num_meas))
if self.init_W is None
else self.init_W
)
initial_D = self.rng.random(num_meas)
initial_W_and_D = np.hstack([initial_W.flatten(), initial_D])
result = minimize(penalized_mle_loss, initial_W_and_D, method="BFGS")
self.result = result
self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas)
self.D_hat = np.diag(result.x[num_latent * num_meas :])
self.convergence_success_mle = result.success
self.convergence_message_mle = result.message
return self
[docs]
def validation_mle(self, lambda_reg, mu_reg, data):
W = self.W_hat
D = self.D_hat
Sigma_hat = np.cov(data, rowvar=False)
Sigma = self.compute_sigma(W, D)
Sigma_inv = np.linalg.inv(Sigma)
sign, logdet = np.linalg.slogdet(Sigma_inv)
if sign <= 0:
return np.inf
loss = np.trace(np.dot(Sigma_hat, Sigma_inv)) - sign * logdet
loss += lambda_reg * self.rho(W) + mu_reg * self.sigma(W)
return loss
# penalized LSE
[docs]
def fit_penalized_lse(
self,
dataset: npt.NDArray,
lambda_reg: float = 0.1,
mu_reg: float = 0.1,
) -> "DevMedil":
num_meas = dataset.shape[1]
if self.one_pure_child:
num_latent = num_meas
else:
num_latent = (num_meas**2) // 4
Sigma_hat = np.cov(dataset.T)
def penalized_lse_loss(W_and_D):
W = W_and_D[: num_latent * num_meas].reshape(num_latent, num_meas)
D = np.diag(W_and_D[num_latent * num_meas :])
loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2
# nuclear norm for the first penalty term
loss += lambda_reg * self.rho(W)
# L1 norm for the second penalty function
loss += mu_reg * self.sigma(W)
return loss
initial_W = self.rng.standard_normal((num_latent, num_meas))
initial_D = np.abs(self.rng.standard_normal(num_meas))
initial_params = np.concatenate([initial_W.flatten(), initial_D])
result = minimize(penalized_lse_loss, initial_params, method="BFGS")
self.result = result
self.W_hat = result.x[: num_latent * num_meas].reshape(num_latent, num_meas)
self.D_hat = np.diag(np.abs(result.x[num_latent * num_meas :]))
self.convergence_success_lse = result.success
self.convergence_message_lse = result.message
return self
[docs]
def validation_lse(self, lambda_reg, mu_reg, data):
W = self.W_hat
D = self.D_hat
Sigma_hat = np.cov(data, rowvar=False)
loss = norm(Sigma_hat - W.T @ W - D, "fro") ** 2
loss += lambda_reg * self.rho(W)
loss += mu_reg * self.sigma(W)
return loss
# compute sigma
[docs]
def compute_sigma(self, W: npt.NDArray, D: npt.NDArray) -> npt.NDArray:
Sigma = np.dot(W.T, W) + D
return Sigma
# ρ(W)
[docs]
def rho(self, W: npt.NDArray) -> float:
return norm(W, "nuc")
# σ(W), the sum of absolute values of elements (L1 norm)
[docs]
def sigma(self, W: npt.NDArray) -> float:
return np.sum(np.abs(W))
[docs]
def sample(self, sample_size: int, method: str = "mle") -> npt.NDArray:
if method not in ["mle", "lse"]:
raise ValueError("Method must be either 'mle' or 'lse'")
if method == "mle":
if not hasattr(self, "W_hat_mle") or not hasattr(self, "D_hat_mle"):
raise ValueError("MLE model must be fitted before sampling")
W_hat, D_hat = self.W_hat_mle, self.D_hat_mle
else:
if not hasattr(self, "W_hat_lse") or not hasattr(self, "D_hat_lse"):
raise ValueError("LSE model must be fitted before sampling")
W_hat, D_hat = self.W_hat_lse, self.D_hat_lse
k, n = W_hat.shape
L = self.rng.standard_normal((sample_size, k))
epsilon = self.rng.multivariate_normal(np.zeros(n), D_hat, sample_size)
return np.dot(L, W_hat) + epsilon
[docs]
def fit(
self, dataset: npt.NDArray, method: str = "mle", lambda_reg=0.1, mu_reg=0.1
) -> "DevMedil":
if method == "mle":
return self.fit_penalized_mle(dataset, lambda_reg, mu_reg)
elif method == "lse":
return self.fit_penalized_lse(dataset, lambda_reg, mu_reg)
else:
raise ValueError("Method must be either 'mle' or 'lse'")