Tutorial

For a basic overview, see the software demonstration [MCGW20] presented at the 10th International Conference on Probabilistic Graphical Models (PGM) and the associated script, which you can see below or download here:

 1# for making sample data
 2import numpy as np
 3from medil.examples import triangle
 4from medil.functional_MCM import gaussian_mixture_sampler
 5from medil.functional_MCM import MeDILCausalModel  # also used in step 3
 6
 7# for step 1
 8from medil.independence_testing import hypothesis_test
 9
10# for step 2
11from medil.ecc_algorithms import find_clique_min_cover as find_cm
12
13# for step 3
14from pytorch_lightning import Trainer
15from medil.functional_MCM import uniform_sampler, GAN
16
17# for visualization
18import medil.visualize as vis
19from medil.independence_testing import distance_correlation
20
21
22# make sample data
23num_latent, num_observed = triangle.MCM.shape
24
25decoder = MeDILCausalModel(biadj_mat=triangle.MCM)
26sampler = gaussian_mixture_sampler(num_latent)
27
28input_sample, output_sample = decoder.sample(sampler, num_samples=1000)
29np.save("measurement_data", output_sample)
30
31# step 1: estimate UDG
32p_vals, null_corr, dep_graph = hypothesis_test(output_sample.T, num_resamples=100)
33# dep_graph is adjacency matrix of the estimated UDG
34
35
36# step 2: learn graphical MCM
37learned_biadj_mat = find_cm(dep_graph)
38
39
40# step 3: learn functional MCM
41num_latent, num_observed = learned_biadj_mat.shape
42
43decoder = MeDILCausalModel(biadj_mat=learned_biadj_mat)
44sampler = uniform_sampler(num_latent)
45
46minMCM = GAN("measurement_data.npy", decoder, latent_sampler=sampler, batch_size=100)
47trainer = Trainer(max_epochs=100)
48trainer.fit(minMCM)
49
50
51# confirm given and learned causal structures match
52vis.show_dag(triangle.MCM)
53vis.show_dag(learned_biadj_mat)
54
55# compare plots of disttrance correlation values for given and learned MCMs
56generated_sample = decoder.sample(sampler, 1000)[1].detach().numpy()
57generated_dcor_mat = distance_correlation(generated_sample.T)
58
59vis.show_obs_dcor_mat(null_corr, print_val=True)
60vis.show_obs_dcor_mat(generated_dcor_mat, print_val=True)
61
62# get params for learned functional MCM; replace '0' with 'i' to get params for any M_i
63print(decoder.observed["0"].causal_function)