Coverage for /opt/conda/lib/python3.12/site-packages/medil/evaluate.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-25 05:42 +0000

1import numpy as np 

2import numpy.typing as npt 

3 

4 

5def sfd( 

6 true_biadj: npt.NDArray, 

7 predicted_biadj: npt.NDArray, 

8 to_return: str = "raw", 

9) -> int | float | tuple[int, float]: 

10 """Structural Frobenius distance sums difference of latent parents. 

11 

12 For a binary biadjacency matrix B, consider U = B'B, where U_ij 

13 counts the number of parents nodes i and j have in common (so U_ii 

14 is the assignment number, and diag(U) is a sufficient statistic 

15 for the graph under the 1-pure-child assumption). sfd(B_1, B_2) is 

16 the sum of the differences between U_ij for B_1 and B_2 (without 

17 double-counting for U_ji). 

18 

19 Parameters 

20 ---------- 

21 predicted_biadj: learned bipartite directed graph 

22 true_biadj: true bipartite directed graph 

23 

24 Returns 

25 ------- 

26 nsfd: normalized structural Frobenius distance 

27 """ 

28 true_biadj = true_biadj.astype(int) 

29 true_wtd_ug = true_biadj.T @ true_biadj 

30 

31 predicted_biadj = predicted_biadj.astype(int) 

32 predicted_wtd_ug = predicted_biadj.T @ predicted_biadj 

33 

34 sfd = np.abs(np.triu(true_wtd_ug - predicted_wtd_ug)).sum() 

35 

36 if to_return == "raw": 

37 return sfd 

38 

39 true_zeros = np.where(true_wtd_ug == 0) 

40 true_wtd_ug[true_zeros] = -1 

41 

42 predicted_zeros = np.where(predicted_wtd_ug == 0) 

43 predicted_wtd_ug[predicted_zeros] = -1 

44 

45 similarity = np.sum(true_wtd_ug * predicted_wtd_ug) 

46 

47 cosin_normalizer = np.sqrt((true_wtd_ug**2).sum()) * np.sqrt( 

48 (predicted_wtd_ug**2).sum() 

49 ) 

50 

51 nsfd = np.arccos(similarity / cosin_normalizer) / np.pi 

52 

53 match to_return: 

54 case "normalized": 

55 return nsfd 

56 case "both": 

57 return sfd, nsfd 

58 case _: 

59 raise ValueError("`to_return` should be 'raw', 'normalized', or 'both'") 

60 

61 

def shd(
    true_biadj: npt.NDArray,
    *,
    predicted_biadj: npt.NDArray | None = None,
    predicted_adj: npt.NDArray | None = None,
    to_return: str = "raw",
) -> int | float | tuple[int, float]:
    """Structural Hamming distance counts mismatched adjacencies.

    The true graph is projected onto its columns with ``recover_ug`` and
    compared entry by entry with the predicted adjacency matrix. Every
    ordered pair (i, j) where exactly one of the two matrices is nonzero
    counts once. For symmetric matrices, such as those produced by
    ``recover_ug``, each mismatched undirected edge therefore counts twice.

    Parameters
    ----------
    true_biadj: true bipartite directed graph, as a biadjacency matrix
    predicted_biadj: learned bipartite directed graph, as a biadjacency
        matrix; converted with ``recover_ug`` before comparison
    predicted_adj: learned graph, as an adjacency matrix with the same shape
        as ``recover_ug(true_biadj)``; nonzero entries are treated as edges
    to_return: "raw" for the mismatch count, "normalized" for that count
        divided by the number of off-diagonal entries n**2 - n, or "both"
        for the tuple (raw, normalized)

    Exactly one of ``predicted_biadj`` and ``predicted_adj`` must be given;
    ``None`` or an empty array counts as "not given".

    Returns
    -------
    shd, nshd, or (shd, nshd), depending on ``to_return``. With fewer than
    two columns there are no off-diagonal entries, and nshd is 0.0 instead
    of 0/0.

    Raises
    ------
    ValueError: if both or neither predicted graph is given, or if
        ``to_return`` is not "raw", "normalized", or "both"
    """
    # None (the default) and empty arrays both mean "not provided"; the
    # len-based check keeps callers that pass np.array([]) working.
    has_biadj = predicted_biadj is not None and len(predicted_biadj) > 0
    has_adj = predicted_adj is not None and len(predicted_adj) > 0
    if has_biadj == has_adj:
        raise ValueError(
            "Must provide `predicted_biadj` or `predicted_adj` but not both."
        )
    if to_return not in ("raw", "normalized", "both"):
        raise ValueError("`to_return` should be 'raw', 'normalized', or 'both'")
    if has_biadj:
        predicted_adj = recover_ug(predicted_biadj)

    ug = recover_ug(true_biadj)

    # logical_xor treats any nonzero weight as an edge.
    raw = np.logical_xor(ug, predicted_adj).sum()

    if to_return == "raw":
        return raw

    n = len(ug)
    n_off_diagonal = n**2 - n
    # Guard against 0/0 when there are fewer than two columns.
    nshd = raw / n_off_diagonal if n_off_diagonal else 0.0

    if to_return == "normalized":
        return nshd
    return raw, nshd

105 

106 

def recover_ug(biadj_mat: npt.NDArray) -> npt.NDArray:
    """Recover the undirected graph from the directed bipartite graph.

    Parameters
    ----------
    biadj_mat: biadjacency matrix of a bipartite directed graph

    Returns
    -------
    ug: square matrix indexed by the columns of ``biadj_mat``. Entry (i, j)
        for i != j equals ``(biadj_mat.T @ biadj_mat)[i, j]``, which for a
        0/1 integer matrix is the number of rows (parents) shared by columns
        i and j. The diagonal is zeroed, and the dtype follows the matrix
        product (bool input gives a bool result).
    """
    shared_parents = biadj_mat.T @ biadj_mat
    # No self-loops: clear the diagonal in place.
    shared_parents[np.diag_indices_from(shared_parents)] = False
    return shared_parents