"""medil.evaluate: metrics comparing learned graph structures and weight matrices against the ground truth."""

import itertools

import numpy as np
import numpy.typing as npt


def sfd(predicted_biadj, true_biadj):
    """Compare a predicted biadjacency matrix against the true one.

    Parameters
    ----------
    predicted_biadj: learned biadjacency matrix (any array-like).
    true_biadj: ground-truth biadjacency matrix (any array-like) with the
        same number of columns as ``predicted_biadj``; the number of rows
        may differ.

    Returns
    -------
    sfd: squared Frobenius distance between ``B.T @ B`` of the two matrices
        after casting them to ``int``.
    ushd: structural Hamming distance between the undirected graphs that
        ``recover_ug`` derives from the two matrices, i.e. the number of
        column pairs ``i < j`` whose adjacency differs.
    """
    predicted_biadj = np.asarray(predicted_biadj)
    true_biadj = np.asarray(true_biadj)

    # Undirected SHD: xor marks disagreeing pairs; triu(..., 1) counts each
    # unordered pair once and ignores the diagonal.
    ug = recover_ug(true_biadj)
    ug_recon = recover_ug(predicted_biadj)
    ushd = np.triu(np.logical_xor(ug, ug_recon), 1).sum()

    # Weighted comparison: for 0/1 input, entry (i, j) of B.T @ B counts the
    # rows in which columns i and j are both nonzero.
    true_int = true_biadj.astype(int)
    predicted_int = predicted_biadj.astype(int)
    wtd_ug = true_int.T @ true_int
    wtd_ug_recon = predicted_int.T @ predicted_int
    sq_frob_dist = ((wtd_ug - wtd_ug_recon) ** 2).sum()

    return sq_frob_dist, ushd
def recover_ug(biadj_mat):
    """Recover the undirected graph implied by a biadjacency matrix.

    The result is ``biadj_mat.T @ biadj_mat`` with its diagonal cleared, so
    it has one row/column per column of ``biadj_mat``. For boolean or
    nonnegative input, columns ``i`` and ``j`` are adjacent (nonzero entry)
    exactly when some row is nonzero in both of them.

    Parameters
    ----------
    biadj_mat: 2-D array-like biadjacency matrix. It is not modified.

    Returns
    -------
    ug: square array with a zero (``False``) diagonal. For boolean input it
        is a boolean adjacency matrix. For numeric input the off-diagonal
        entries are the column dot products (for 0/1 input: the number of
        shared rows), and callers should treat nonzero as "adjacent".
    """
    biadj_mat = np.asarray(biadj_mat)
    ug = biadj_mat.T @ biadj_mat
    # ``ug`` is a fresh array, so clearing the diagonal (no self-loops) in
    # place never touches the caller's input; False is written as 0 for
    # numeric dtypes.
    np.fill_diagonal(ug, False)
    return ug
def min_perm_squared_l2_dist(predicted_W: npt.NDArray, true_W: npt.NDArray):
    """Find the row permutation of ``predicted_W`` closest to ``true_W``.

    ``true_W`` is zero-padded with extra rows up to the shape of
    ``predicted_W``. Every permutation of the rows of ``predicted_W`` is then
    scored by its summed squared difference to the padded ``true_W``. The
    search is exhaustive over all ``len(predicted_W)!`` permutations, so it
    is only practical for a small number of rows.

    Parameters
    ----------
    predicted_W: array-like with at least as many rows as ``true_W``.
    true_W: array-like whose remaining dimensions match ``predicted_W``.

    Returns
    -------
    opt_perm: integer array such that ``predicted_W[opt_perm]`` is the best
        alignment. On ties, the lexicographically first permutation wins.
    min_dist: the squared L2 distance achieved by ``opt_perm``.

    Raises
    ------
    ValueError: if ``true_W`` has more rows than ``predicted_W`` or the
        remaining dimensions are incompatible.
    """
    predicted_W = np.asarray(predicted_W)
    true_W = np.asarray(true_W)
    num_rows = len(predicted_W)
    num_latents = len(true_W)
    if num_latents > num_rows:
        raise ValueError(
            f"true_W has {num_latents} rows but predicted_W has only {num_rows}"
        )

    # Pad using the common dtype so float ground-truth values are not
    # silently truncated when predicted_W has an integer or boolean dtype.
    padded_true_W = np.zeros(
        predicted_W.shape, dtype=np.result_type(predicted_W.dtype, true_W.dtype)
    )
    padded_true_W[:num_latents] = true_W

    def perm_squared_l2_dist(perm):
        # dtype=int keeps the index integral even for the empty permutation.
        index = np.array(perm, dtype=int)
        return np.sum((predicted_W[index] - padded_true_W) ** 2)

    def pair(perm):
        return perm, perm_squared_l2_dist(perm)

    # permutations() yields in lexicographic order and min() keeps the first
    # minimum, which fixes the tie-breaking rule documented above.
    perms = itertools.permutations(range(num_rows))
    opt_perm, min_dist = min(map(pair, perms), key=lambda item: item[1])
    return np.array(opt_perm, dtype=int), min_dist
def min_perm_squared_l2_dist_abs(predicted_W: npt.NDArray, true_W: npt.NDArray):
    """Find the row permutation of ``predicted_W`` whose magnitudes best match ``true_W``.

    ``true_W`` is zero-padded with extra rows up to the shape of
    ``predicted_W``. Every permutation of the rows of ``predicted_W`` is then
    scored by the summed squared difference between the absolute values of
    the permuted ``predicted_W`` and of the padded ``true_W``, so entries
    that differ only in sign count as equal. The search is exhaustive over
    all ``len(predicted_W)!`` permutations, so it is only practical for a
    small number of rows.

    Parameters
    ----------
    predicted_W: array-like with at least as many rows as ``true_W``.
    true_W: array-like whose remaining dimensions match ``predicted_W``.

    Returns
    -------
    opt_perm: integer array such that ``predicted_W[opt_perm]`` is the best
        alignment. On ties, the lexicographically first permutation wins.
    min_dist: the squared L2 distance between magnitudes achieved by
        ``opt_perm``.

    Raises
    ------
    ValueError: if ``true_W`` has more rows than ``predicted_W`` or the
        remaining dimensions are incompatible.
    """
    predicted_W = np.asarray(predicted_W)
    true_W = np.asarray(true_W)
    num_rows = len(predicted_W)
    num_latents = len(true_W)
    if num_latents > num_rows:
        raise ValueError(
            f"true_W has {num_latents} rows but predicted_W has only {num_rows}"
        )

    # Pad using the common dtype so float ground-truth values are not
    # silently truncated when predicted_W has an integer or boolean dtype.
    padded_true_W = np.zeros(
        predicted_W.shape, dtype=np.result_type(predicted_W.dtype, true_W.dtype)
    )
    padded_true_W[:num_latents] = true_W

    # abs() commutes with row permutation, so compute the magnitudes once
    # instead of once per permutation.
    abs_predicted_W = np.abs(predicted_W)
    abs_true_W = np.abs(padded_true_W)

    def perm_squared_l2_dist(perm):
        # dtype=int keeps the index integral even for the empty permutation.
        index = np.array(perm, dtype=int)
        return np.sum((abs_predicted_W[index] - abs_true_W) ** 2)

    def pair(perm):
        return perm, perm_squared_l2_dist(perm)

    # permutations() yields in lexicographic order and min() keeps the first
    # minimum, which fixes the tie-breaking rule documented above.
    perms = itertools.permutations(range(num_rows))
    opt_perm, min_dist = min(map(pair, perms), key=lambda item: item[1])
    return np.array(opt_perm, dtype=int), min_dist