Minimal examples of how to use MSA
Background
MSA stands for "Multiperturbation Shapley value Analysis" and, as the name suggests, it combines "multiple perturbations" with the "Shapley value". Fundamentally, it uses a dataset of multi-element perturbations to estimate the Shapley value of each element with respect to a global function. I'll refer you to these papers for technical and conceptual details:
Keinan, Alon, Claus C. Hilgetag, Isaac Meilijson, and Eytan Ruppin. 2004. “Causal Localization of Neural Function: The Shapley Value Method.” Neurocomputing 58-60 (June): 215–22.
Keinan, Alon, Ben Sandbank, Claus C. Hilgetag, Isaac Meilijson, and Eytan Ruppin. 2004. “Fair Attribution of Functional Contribution in Artificial and Biological Networks.” Neural Computation 16 (9): 1887–1915.
Kötter, Rolf. 2007. “Shapley Ratings in Brain Networks.” Frontiers in Neuroinformatics 1: 1–9.
Toba, Monica N., Melissa Zavaglia, Caroline Malherbe, Tristan Moreau, Federica Rastelli, Anna Kaglik, Romain Valabrègue, Pascale Pradat-Diehl, Claus C. Hilgetag, and Antoni Valero-Cabré. 2020. “Game Theoretical Mapping of White Matter Contributions to Visuospatial Attention in Stroke Patients with Hemineglect.” Human Brain Mapping: 1–25.
Zavaglia, Melissa, and Claus C. Hilgetag. 2016. “Causal Functional Contributions and Interactions in the Attention Network of the Brain: An Objective Multi-Perturbation Analysis.” Brain Structure & Function 221 (5): 2553–68.
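Here is what is being estimated. The Shapley value of an element is its marginal contribution to the game, averaged over every order in which the elements could be added to the system. Below is a minimal plain-Python sketch (not part of msapy) that computes it exactly by brute force for a hypothetical three-element toy game. The names `exact_shapley` and `toy_game` are made up for illustration. Because it enumerates all n! orderings, this approach only works for tiny systems, which is why MSA samples permutations instead.

```python
from itertools import permutations
from math import factorial


def exact_shapley(elements, game):
    """Exact Shapley values: average marginal contribution over all n! orderings.

    `game` maps a frozenset of intact elements to a single number.
    """
    totals = {e: 0.0 for e in elements}
    for order in permutations(elements):
        coalition = frozenset()
        for e in order:
            with_e = coalition | {e}
            totals[e] += game(with_e) - game(coalition)  # marginal contribution of e
            coalition = with_e
    n_orders = factorial(len(elements))
    return {e: total / n_orders for e, total in totals.items()}


def toy_game(intact):
    """Hypothetical game: the system works (1.0) only if A is intact with B or C."""
    return 1.0 if "A" in intact and ({"B", "C"} & intact) else 0.0


values = exact_shapley(["A", "B", "C"], toy_game)
print(values)  # A gets 2/3, B and C get 1/6 each
```

A is necessary for the system to work, so it gets the largest share. B and C are interchangeable backups and split the rest. The three values sum to 1, which is the value of the intact system.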
The toolbox is designed to handle a large variety of systems. All it needs is a list of elements: node labels, node indices, or tuples (connections between nodes, for example). That said, I haven't really tested it on things that are not networks! If you're interested in explainable machine learning, for example, there's already a very versatile toolbox called SHAP that calculates the Shapley values of a model's input features. MSA is traditionally used as a brain-mapping tool (I mean, just look at those papers), and I'm also a neuroscientist, so there's a bias towards networks and a lesion-mapping way of thinking. With that out of the way, let's see how MSA works using some small networks.
Defining the network and the game
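```python
# Uncomment the line below if you don't have them.
# !pip install networkx matplotlib seaborn
# Imports n stuff
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import numpy as np
#---------
from msapy import msa, utils as ut
#---------
mpl.rcParams['pdf.fonttype'] = 42  # embed TrueType fonts so text in the PDFs stays editable
mpl.rcParams['font.size'] = 10
CM = 1 / 2.54  # one centimetre in inches, for figure sizes
SEED = 111
FIGPATH = "figures/minimal/"
os.makedirs(FIGPATH, exist_ok=True)  # plt.savefig fails if the folder doesn't exist
```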
As mentioned, all MSA needs is a list of elements. That's not quite true, though: it also needs a game. The players (elements) play the game, and the game can be anything. In a way, that's the beauty of it. Here, I will define a game called "ge", which stands for "average global efficiency" (yeah, it should be "age", but that would be weird!), a graph metric. See this page for more detail.
So you need to define a game: a function that encapsulates the procedure you care about while the system is being lesioned. For example, if you're using MSA in the classic brain-mapping sense, the game function is where you perform the "cognitive experiment". Keep a few things in mind:
- Take care of the edge cases. MSA plays the game on a large number of lesion combinations. If you're interested in the contribution (Shapley value) of each node to information flow in a network, some of these combinations might leave the network as just a bunch of isolated nodes. Define conditions for such cases, e.g., `if len(x) < y: return 0`.
- Copy your stuff before feeding it to the toolbox. Otherwise, depending on how you're lesioning the system, a lesion might leak into the next game and mess with everything.
- The game should return a single value: a global measure that describes the state of the system.
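To make the copying advice concrete, here's a hypothetical plain-Python sketch (no msapy, no networkx) with a network stored as an adjacency dict. `lesion_in_place` mutates the graph it receives, so the first lesion leaks into the second call. `lesion_on_copy` handles the empty and complete lesion sets explicitly, then deep-copies before lesioning. All names here are made up, and the "game" just counts surviving nodes to keep the example short.

```python
import copy

# Hypothetical tiny network: node -> set of neighbours.
NETWORK = {0: {1, 2}, 1: {0, 3}, 2: {0}, 3: {1}}


def remove_nodes(graph, nodes):
    """Delete `nodes` and every edge touching them, mutating `graph`."""
    for node in nodes:
        graph.pop(node, None)
        for neighbours in graph.values():
            neighbours.discard(node)


def lesion_in_place(complements, graph):
    """BAD: lesions the very object it was given, so lesions pile up across calls."""
    remove_nodes(graph, complements)
    return float(len(graph))  # one number describing the system


def lesion_on_copy(complements, graph):
    """GOOD: handles the edge cases, then lesions a deep copy."""
    if len(complements) == 0:
        return float(len(graph))  # nothing lesioned: the intact system
    if len(complements) >= len(graph):
        return 0.0  # everything lesioned: nothing left to measure
    lesioned = copy.deepcopy(graph)
    remove_nodes(lesioned, complements)
    return float(len(lesioned))


shared = copy.deepcopy(NETWORK)
print(lesion_in_place((3,), shared), lesion_in_place((2,), shared))  # 3.0 2.0: node 3 leaked
print(lesion_on_copy((3,), NETWORK), lesion_on_copy((2,), NETWORK))  # 3.0 3.0
```

A plain `dict.copy()` would not be enough here, because the neighbour sets would still be shared with the original. With networkx, `graph.copy()` gives you an independent graph structure to lesion.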
```python
def ge(complements, graph):
    """Average global efficiency of `graph` after lesioning the nodes in `complements`."""
    if len(complements) == 0:
        # grand coalition is in, nothing to lesion, so work with the intact network.
        return float(nx.global_efficiency(graph))
    elif len(complements) == len(graph):
        # everything is gone, no need to call nx.global_efficiency.
        return 0.0
    else:
        # lesion a copy of the system, then calculate global efficiency
        lesioned = graph.copy()
        lesioned.remove_nodes_from(complements)
        return float(nx.global_efficiency(lesioned))
```
Next, I'll make a graph with an intuitive topology: a balanced tree with branching factor 3 and height 3, which has 40 nodes. Intuitively, the farther a node is from the root, the smaller its Shapley value should be. The outermost branches might even end up with negative Shapley values. Removing a leaf drops many of the longest, least efficient paths from the average, which improves global efficiency.
```python
G = nx.balanced_tree(3, 3)
fig, ax = plt.subplots()
fig.set_dpi(150)
fig.set_size_inches((12*CM, 12*CM))
# draw_networkx returns None, so draw into the existing axes instead of reassigning ax
nx.draw_networkx(G, ax=ax, with_labels=True, node_color='#FFE48D', font_size=8)
plt.savefig(f"{FIGPATH}balanced_tree.pdf", dpi=300, bbox_inches='tight')
```
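Why would leaves get negative values? Global efficiency is the average of 1/d(u, v) over all node pairs, so removing a leaf drops many long, inefficient paths from that average. The sketch below is a plain-Python re-implementation of that definition. It is meant to mirror `nx.global_efficiency` but it is not networkx. It runs on a much smaller balanced tree (branching factor 2, height 2) and compares removing a leaf with removing the root. The helpers `balanced_tree`, `remove_nodes` and `global_efficiency` are hypothetical stand-ins.

```python
from collections import deque


def balanced_tree(r, h):
    """Adjacency dict of a balanced r-ary tree of height h; node 0 is the root."""
    adj = {0: set()}
    frontier, next_id = [0], 1
    for _ in range(h):
        new_frontier = []
        for parent in frontier:
            for _ in range(r):
                adj[next_id] = {parent}
                adj[parent].add(next_id)
                new_frontier.append(next_id)
                next_id += 1
        frontier = new_frontier
    return adj


def remove_nodes(adj, nodes):
    """Return a lesioned copy; the input graph is left untouched."""
    gone = set(nodes)
    return {u: nbrs - gone for u, nbrs in adj.items() if u not in gone}


def global_efficiency(adj):
    """Mean of 1/d(u, v) over all ordered node pairs; unreachable pairs add 0."""
    n = len(adj)
    if n < 2:
        return 0.0
    total = 0.0
    for source in adj:
        dist = {source: 0}  # breadth-first search from `source`
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        total += sum(1.0 / d for d in dist.values() if d > 0)
    return total / (n * (n - 1))


tree = balanced_tree(2, 2)  # root 0, children 1-2, leaves 3-6
intact = global_efficiency(tree)
no_leaf = global_efficiency(remove_nodes(tree, [6]))
no_root = global_efficiency(remove_nodes(tree, [0]))
print(f"intact {intact:.3f} | leaf 6 removed {no_leaf:.3f} | root removed {no_root:.3f}")
# intact 0.563 | leaf 6 removed 0.600 | root removed 0.333
```

Removing a leaf raises efficiency, while removing the root splits the tree into two components and lowers it. These are single lesions, not Shapley values, but they show why outer nodes can end up with small or even negative contributions.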
MSA in action
Briefly, we'll estimate Shapley values in two steps. First, we sample N random permutations of the elements. Then we derive from them which combinations of elements to keep intact and which complementary sets to lesion.
```python
node_perms = msa.make_permutation_space(elements=list(G.nodes), n_permutations=1_000, random_seed=SEED)
print(np.shape(node_perms))
print(f'Nodes: {list(G.nodes)}\n')
print(f'Permuted nodes: {node_perms[0]}')
```
(1000, 40)
Nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]

Permuted nodes: (21, 16, 18, 22, 15, 24, 38, 7, 29, 33, 11, 25, 19, 14, 37, 6, 27, 10, 0, 17, 5, 9, 32, 13, 39, 28, 1, 3, 36, 34, 26, 31, 20, 8, 4, 23, 12, 30, 2, 35)
```python
node_combs_template = msa.make_combination_space(permutation_space=node_perms)
node_compl_template = msa.make_complement_space(combination_space=node_combs_template, elements=list(G.nodes))
print(f'Number of lesion combinations: {len(node_combs_template)}')
```
Number of lesion combinations: 36088
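Conceptually, the permutation, combination and complement steps fit together as in the plain-Python sketch below. It is a simplified re-implementation, not msapy's code, and the helper names are hypothetical. Every prefix of every permutation is a coalition of intact elements whose game value is needed, and its complement is the set of elements to lesion. Prefixes shared between permutations only need to be played once, which is presumably why the number of unique combinations is smaller than the number of permutations times the number of prefixes. Whether msapy includes the empty and full coalitions exactly as this sketch does is an assumption.

```python
import random


def sample_permutations(elements, n_permutations, seed):
    """Hypothetical helper: n random orderings of the elements."""
    rng = random.Random(seed)
    return [tuple(rng.sample(elements, len(elements))) for _ in range(n_permutations)]


def prefix_coalitions(permutations):
    """Every prefix of every permutation is a coalition of intact elements to play."""
    coalitions = set()
    for perm in permutations:
        for k in range(len(perm) + 1):
            coalitions.add(frozenset(perm[:k]))  # shared prefixes collapse here
    return coalitions


def lesion_sets(coalitions, elements):
    """Map each coalition to its complement, i.e. the elements to lesion."""
    everything = frozenset(elements)
    return {coalition: everything - coalition for coalition in coalitions}


elements = list(range(6))
perms = sample_permutations(elements, n_permutations=20, seed=111)
coalitions = prefix_coalitions(perms)
lesions = lesion_sets(coalitions, elements)
print(f"{len(perms) * (len(elements) + 1)} prefixes -> {len(coalitions)} unique coalitions")
```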
Here we use `parallelized_take_contributions` to actually play the games and fill in the values. We then use this multi-site perturbation dataset to calculate the Shapley values and sort them.
```python
global_eff, _ = ut.parallelized_take_contributions(complement_space=node_compl_template,
                                                   combination_space=node_combs_template,
                                                   objective_function=ge,
                                                   objective_function_params={'graph': G})
global_eff_shapley = msa.get_shapley_table(contributions=global_eff, permutation_space=node_perms)
global_eff_shapley = ut.sorter(global_eff_shapley)
global_eff_shapley.head()
```
| | 37 | 22 | 27 | 26 | 35 | 19 | 21 | 39 | 17 | 18 | ... | 7 | 12 | 4 | 9 | 6 | 5 | 2 | 0 | 1 | 3 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.008062 | -0.005808 | 0.001170 | 0.007937 | -0.008140 | 0.006794 | 0.001562 | -0.019444 | -0.010490 | -0.013636 | ... | 0.000631 | 0.030357 | 0.026218 | 0.000000 | 0.000000 | 0.043039 | 0.112287 | 0.004040 | 0.022421 | 0.000000 |
| 1 | -0.000074 | -0.001991 | -0.001687 | 0.000252 | -0.005491 | -0.022073 | -0.166667 | 0.015152 | -0.001532 | -0.066667 | ... | -0.023810 | 0.000000 | 0.026345 | 0.000000 | 0.037694 | 0.011905 | 0.063440 | 0.159722 | 0.004579 | 1.000000 |
| 2 | -0.014828 | -0.004919 | 0.000709 | 0.000050 | -0.012646 | 0.001495 | 0.005898 | -0.022533 | 0.000900 | 0.006291 | ... | -0.047619 | 0.036062 | 0.036170 | -0.066667 | 0.027778 | 0.066667 | 0.105047 | 0.000000 | 1.000000 | 0.017593 |
| 3 | 0.000916 | 0.000000 | -0.001076 | -0.005931 | -0.000713 | 0.003239 | 0.004179 | 0.001333 | -0.005848 | 0.000000 | ... | 0.119048 | -0.019841 | -0.013889 | 0.005700 | 0.000000 | 0.020759 | 0.010975 | 0.143057 | 0.012573 | 0.030157 |
| 4 | -0.000721 | -0.000696 | -0.001950 | -0.010490 | -0.001991 | 0.001197 | 0.007597 | -0.000378 | 0.000000 | 0.000000 | ... | 0.015544 | 0.016739 | 0.021936 | 0.000000 | 0.027778 | 0.038660 | 0.077008 | 0.053453 | 0.019192 | 0.018672 |

5 rows × 40 columns
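The table above appears to have one row per sampled permutation and one column per node, and `global_eff_shapley.mean()` averages over those rows. Assuming that layout, here's a minimal plain-Python sketch of the bookkeeping, not msapy's implementation. For each permutation, it records each element's marginal contribution when that element joins the elements before it. Game values are cached so every unique coalition is played only once. The toy game is hypothetical: elements 0 and 1 only count together, element 2 adds a fixed 0.5, and element 3 does nothing.

```python
import random
from statistics import mean


def shapley_rows(permutations, game):
    """One row per permutation: each element's marginal contribution when it
    joins the elements that precede it in that permutation."""
    cache = {}  # each unique coalition is played only once

    def value(coalition):
        if coalition not in cache:
            cache[coalition] = game(coalition)
        return cache[coalition]

    rows = []
    for perm in permutations:
        row, coalition = {}, frozenset()
        for element in perm:
            grown = coalition | {element}
            row[element] = value(grown) - value(coalition)
            coalition = grown
        rows.append(row)
    return rows


def toy_game(intact):
    """Hypothetical game: 0 and 1 only matter together, 2 adds 0.5, 3 is useless."""
    return (1.0 if {0, 1} <= intact else 0.0) + (0.5 if 2 in intact else 0.0)


elements = [0, 1, 2, 3]
rng = random.Random(111)
perms = [tuple(rng.sample(elements, len(elements))) for _ in range(200)]
rows = shapley_rows(perms, toy_game)
shapley = {e: mean(row[e] for row in rows) for e in elements}  # column means
print(shapley)
```

Element 2 gets exactly 0.5 in every row and element 3 exactly 0. Elements 0 and 1 get either 0 or 1 depending on which of them comes second, so only their averages approach the exact value of 0.5. Individual rows can therefore be far from the final estimate; only the mean over rows estimates the Shapley value.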
```python
intact_global_eff = nx.global_efficiency(G)
d_global_eff = ut.distribution_of_processing(shapley_vector=global_eff_shapley.mean())
fig, ax = plt.subplots()
sns.barplot(data=global_eff_shapley, ax=ax, errorbar=('ci', 95), orient="v", err_kws={'color': 'k'})
fig.set_dpi(150)
fig.set_size_inches((21*CM, 5*CM))
plt.text(0.5, 0.04, f'Intact global efficiency: {intact_global_eff:.2f}')
plt.text(0.5, 0.03, f'Distribution of processing: {d_global_eff:.4f}')
plt.xticks(fontsize=8)
plt.title("Contribution of each node to the global efficiency")
plt.savefig(f"{FIGPATH}global_efficiency.pdf", dpi=300, bbox_inches='tight')
```
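Voilà! That's a minimal example. P.S.: the Shapley values of all elements add up to the value of the game on the intact network (the grand coalition) minus its value when everything is lesioned. Here the fully lesioned value is 0, so the sum should equal the intact value. It's a nice sanity check. Here: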
```python
print(intact_global_eff - global_eff_shapley.mean().sum())
```
-1.6653345369377348e-16
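Why does the check hold even though only 1,000 of the 40! possible orderings were sampled? Along any single permutation, the marginal contributions telescope. Write $\pi = (i_1, \dots, i_n)$ for a permutation of the $n$ elements, $v$ for the game, $N$ for the set of all elements, $\Pi$ for the set of sampled permutations, and $\hat{\phi}_i$ for the average marginal contribution of element $i$ over $\Pi$:

```latex
\begin{aligned}
\sum_{k=1}^{n} \Big[ v(\{i_1,\dots,i_k\}) - v(\{i_1,\dots,i_{k-1}\}) \Big]
  &= v(N) - v(\varnothing) \quad \text{for every } \pi \in \Pi, \\
\sum_{i \in N} \hat{\phi}_i
  = \frac{1}{|\Pi|} \sum_{\pi \in \Pi} \sum_{k=1}^{n} \Big[ v(\{i_1,\dots,i_k\}) - v(\{i_1,\dots,i_{k-1}\}) \Big]
  &= v(N) - v(\varnothing).
\end{aligned}
```

With `ge`, lesioning everything returns 0, so $v(\varnothing) = 0$. The sum therefore equals the intact global efficiency up to floating-point error, which is the value of about $-1.7 \times 10^{-16}$ printed above.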
Of course, this might not hold exactly if the game itself is stochastic, but even then the difference should be small. Anyway, here's the same graph, but with each node now colored by its contribution to global efficiency.
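```python
# Mean Shapley value of each node, reordered to follow G's node order
node_shapley = global_eff_shapley.mean().reindex(list(G.nodes))
vmax = node_shapley.abs().max()  # symmetric color range so that 0 sits in the middle
fig, ax = plt.subplots()
fig.set_dpi(150)
fig.set_size_inches((12*CM, 12*CM))
nx.draw_networkx(G, ax=ax, with_labels=True, font_size=8, node_color=node_shapley.to_numpy(),
                 cmap=plt.cm.coolwarm, vmin=-vmax, vmax=vmax)
plt.savefig(f"{FIGPATH}balanced_tree_colorcoded.pdf", dpi=300, bbox_inches='tight')
```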