Coverage for /opt/conda/lib/python3.12/site-packages/medil/evaluate.py: 100%
46 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-25 05:42 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-25 05:42 +0000
1import numpy as np
2import numpy.typing as npt
5def sfd(
6 true_biadj: npt.NDArray,
7 predicted_biadj: npt.NDArray,
8 to_return: str = "raw",
9) -> int | float | tuple[int, float]:
10 """Structural Frobenius distance sums difference of latent parents.
12 For a binary biadjacency matrix B, consider U = B'B, where U_ij
13 counts the number of parents nodes i and j have in common (so U_ii
14 is the assignment number, and diag(U) is a sufficient statistic
15 for the graph under the 1-pure-child assumption). sfd(B_1, B_2) is
16 the sum of the differences between U_ij for B_1 and B_2 (without
17 double-counting for U_ji).
19 Parameters
20 ----------
21 predicted_biadj: learned bipartite directed graph
22 true_biadj: true bipartite directed graph
24 Returns
25 -------
26 nsfd: normalized structural Frobenius distance
27 """
28 true_biadj = true_biadj.astype(int)
29 true_wtd_ug = true_biadj.T @ true_biadj
31 predicted_biadj = predicted_biadj.astype(int)
32 predicted_wtd_ug = predicted_biadj.T @ predicted_biadj
34 sfd = np.abs(np.triu(true_wtd_ug - predicted_wtd_ug)).sum()
36 if to_return == "raw":
37 return sfd
39 true_zeros = np.where(true_wtd_ug == 0)
40 true_wtd_ug[true_zeros] = -1
42 predicted_zeros = np.where(predicted_wtd_ug == 0)
43 predicted_wtd_ug[predicted_zeros] = -1
45 similarity = np.sum(true_wtd_ug * predicted_wtd_ug)
47 cosin_normalizer = np.sqrt((true_wtd_ug**2).sum()) * np.sqrt(
48 (predicted_wtd_ug**2).sum()
49 )
51 nsfd = np.arccos(similarity / cosin_normalizer) / np.pi
53 match to_return:
54 case "normalized":
55 return nsfd
56 case "both":
57 return sfd, nsfd
58 case _:
59 raise ValueError("`to_return` should be 'raw', 'normalized', or 'both'")
def shd(
    true_biadj: npt.NDArray,
    *,
    predicted_biadj: npt.NDArray | None = None,
    predicted_adj: npt.NDArray | None = None,
    to_return: str = "raw",
) -> int | float | tuple[int, float]:
    """Structural Hamming distance counts number of incorrect arrowheads/tails.

    The true graph is collapsed to an undirected graph with ``recover_ug``.
    The prediction is given either directly as an adjacency matrix
    (``predicted_adj``) or as a biadjacency matrix (``predicted_biadj``),
    which is collapsed the same way. The raw distance is the number of
    matrix entries where exactly one of the two graphs is nonzero. When
    both matrices are symmetric (as ``recover_ug`` output is), each
    disagreeing pair of nodes is therefore counted once per direction.

    Parameters
    ----------
    true_biadj: true bipartite directed graph
    predicted_biadj: learned bipartite directed graph (keyword-only)
    predicted_adj: learned mixed graph (keyword-only)
        Exactly one of ``predicted_biadj`` / ``predicted_adj`` must be
        given and non-empty; ``None`` or an empty array means "not given".
    to_return: "raw" (default), "normalized", or "both"

    Returns
    -------
    shd: raw structural Hamming distance, if ``to_return == "raw"``
    nshd: shd divided by the n**2 - n off-diagonal entries of an n-node
        graph (0.0 when n < 2, which has no off-diagonal entries), if
        ``to_return == "normalized"``
    (shd, nshd): if ``to_return == "both"``

    Raises
    ------
    ValueError
        If neither or both predictions are given, or if ``to_return`` is
        not one of the accepted values.
    """
    # None replaces the former shared ``np.array([])`` defaults; explicitly
    # passed empty arrays are still treated as "not provided".
    has_biadj = predicted_biadj is not None and len(predicted_biadj) > 0
    has_adj = predicted_adj is not None and len(predicted_adj) > 0
    if has_biadj == has_adj:
        raise ValueError(
            "Must provide `predicted_biadj` or `predicted_adj` but not both."
        )
    if to_return not in ("raw", "normalized", "both"):
        raise ValueError("`to_return` should be 'raw', 'normalized', or 'both'")

    if has_biadj:
        predicted_adj = recover_ug(predicted_biadj)
    true_ug = recover_ug(true_biadj)

    raw = np.logical_xor(true_ug, predicted_adj).sum()
    if to_return == "raw":
        return raw

    n = len(true_ug)
    num_off_diagonal = n**2 - n
    # A graph with fewer than two nodes has no off-diagonal entries; avoid 0/0.
    nshd = raw / num_off_diagonal if num_off_diagonal else 0.0

    if to_return == "normalized":
        return nshd
    return raw, nshd
def recover_ug(biadj_mat: npt.NDArray) -> npt.NDArray:
    """Collapse a directed bipartite graph into an undirected graph.

    The result is ``biadj_mat.T @ biadj_mat`` with its diagonal cleared.
    For a 0/1 integer input, each off-diagonal entry (i, j) is the number
    of rows (parents) that columns i and j share. For a boolean input, the
    entries are booleans saying whether any parent is shared.

    Parameters
    ----------
    biadj_mat: learned directed graph

    Returns
    -------
    ug: the recovered undirected graph (square, with a cleared diagonal)
    """
    shared_parents = biadj_mat.T @ biadj_mat
    # A node is not adjacent to itself, so wipe the self-overlap counts.
    np.fill_diagonal(shared_parents, False)
    return shared_parents