Working with ground-truth models
In this notebook, I'll quickly generate a simple function that represents a cognitive task. Then I'll check whether MSA can find the important units, and how well it performs with more distributed processes and a larger number of elements.
# Uncomment the line below if you don't have them.
# !pip install networkx matplotlib seaborn
# Imports
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# ---------
from msapy import msa, utils as ut
# ---------
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['font.size'] = 8
CM = 1 / 2.54
SEED = 111
RNG = np.random.default_rng(SEED)
FIGPATH = "figures/gt/"
A very localized function
The task is, well, basically just a random number generator! The idea is to draw a number from a normal distribution around some value, say 100, with a standard deviation of 10. Then, conditioned on some arbitrary element(s) being lesioned, there will be a deficit. To simulate a very localized "cognitive function", we can pass one element as the only cause. Passing more elements means the cognitive function is distributed across those elements, so naturally we expect them to have the highest contributions while the others settle around zero.
ground_truth_elements = ['a','b','c','d','e','f','g','h'] # our brain regions, for example!
print(f'total number of possible lesions: {2**len(ground_truth_elements)}')
total number of possible lesions: 256
As you can see, the number of lesions is not "that large", so we can exhaustively cover the combination space.
ground_truth_cause = ['a'] # the cognitive function is the product of just one element 'a'
def gt(complements, causes):
# default score, on average, is 100.
score = RNG.normal(loc=100, scale=10)
# checking if the regions of interest are being lesioned.
if len(causes) != 0 and set(causes).issubset(complements):
# lesioning ends with a reduction of 50 points.
return score - 50
else:
return score
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1000,
objective_function=gt,
n_parallel_games=-1, #parallelized over all CPU cores
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
To quantify how much the function is distributed in our system, we can use this nifty function called "distribution_of_processing" in the utils module. You can read more about it here:
- Aharonov, R., Segev, L., Meilijson, I., & Ruppin, E. (2003). Localization of function via lesion analysis. Neural Computation.
- Saggie-Wexler, K., Keinan, A., & Ruppin, E. (2006). Neural processing of counting in evolved spiking and McCulloch-Pitts agents. Artificial Life.
But the basic idea is to see how many units are involved in producing the function, one way or another, so negative Shapley values (interpreted as hindrance) still count. A value around zero means a localized function, which is what we expect here, while a value around one suggests a distributed process. So far I couldn't get anything really close to zero, and I assume it's because of the noise in the datasets. I think it needs some rescaling, but generally it's a nice metric, and you can still compare two values to see which process is more distributed and by how much.
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.19889855206048468
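To get some intuition for what this number captures, here's a minimal sketch of a dispersion measure in the same spirit: the normalized entropy of the absolute Shapley values. To be clear, this is just an illustration I'm writing here, not necessarily the formula distribution_of_processing actually implements.
def dispersion_sketch(shapley_vector):
    # Illustrative only: treat |Shapley values| as a distribution and
    # compute its normalized entropy. 0 ~ one element carries everything
    # (fully localized), 1 ~ all elements contribute equally (distributed).
    w = np.abs(np.asarray(shapley_vector, dtype=float))
    p = w / w.sum()
    entropy = -np.sum(p * np.log(p, where=p > 0, out=np.zeros_like(p)))
    return entropy / np.log(len(p))

print(dispersion_sketch([50, 0.2, -0.1, 0.3]))      # close to 0: localized
print(dispersion_sketch([7, 7, 7, 7, 7, 7, 7, 7]))  # exactly 1: fully distributed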
shapley_table = ut.sorter(shapley_table) # sorting based on the average contribution (Shapley values)
fig,ax = plt.subplots()
sns.barplot(data=shapley_table, ax=ax, errorbar=('ci', 95), orient= "h", err_kws={'color': 'k'})
fig.set_dpi(154)
fig.set_size_inches((4*CM,6*CM))
plt.xlabel('Shapley values')
plt.ylabel('Elements')
plt.title('Shapley values of a ground-truth dataset\nwith only 1 critical element')
plt.savefig(f"{FIGPATH}1critical.pdf",dpi=300,bbox_inches='tight')
Here's an interesting thing: pay attention to the Shapley value of 'a'; it's around 50, which is the same number we used for the performance deficit.
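As a quick sanity check (this assumes the usual Shapley efficiency property, i.e. that the values sum to the difference between the fully intact and the fully lesioned performance, which is about 50 here; I haven't verified that this is exactly how msapy normalizes things):
# If the efficiency property holds with these conventions, the Shapley
# values should add up to roughly the 50-point deficit.
print(f"Sum of Shapley values: {shapley_table.mean().sum():.2f}")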
A more distributed function
Here, we say the cognitive function relies on, let's say, 3 regions: if all of these regions are perturbed, then performance drops by 50 points. The function we made stays the same; only the causes need an adjustment.
ground_truth_cause = ['a','b','c']
# The same call here:
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt,
n_parallel_games=-1, #parallelized over all CPU cores
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
# The same pipeline here:
shapley_table = ut.sorter(shapley_table) # sorting based on the average contribution (Shapley values)
fig,ax = plt.subplots()
sns.barplot(data=shapley_table, ax=ax, errorbar=('ci', 95), orient= "h", err_kws={'color': 'k'})
fig.set_dpi(154)
fig.set_size_inches((4*CM,6*CM))
plt.xlabel('Shapley values')
plt.ylabel('Elements')
plt.title('Shapley values of a ground-truth dataset\nwith three critical elements')
plt.savefig(f"{FIGPATH}3critical.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.5432645447465022
Again, the interesting point is that the Shapley values of the critical elements roughly correspond to 50/3, which is pretty neat.
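To see that correspondence directly (assuming, as above, that the table's columns are simply the element labels):
# Each of the three critical elements should sit near 50 / 3, sharing
# the 50-point deficit, plus whatever estimation noise is left.
print(f"Expected per-element contribution: {50 / 3:.2f}")
print(shapley_table.mean()[['a', 'b', 'c']])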
A totally distributed function
Yeah, why not! Let's say every unit except one is involved. Things will get tricky here, and you'll see why!
ground_truth_cause = ['a','b','c','d','e','f','g'] #only 'h' is out.
# The same call here:
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt,
n_parallel_games=-1, #parallelized over all CPU cores
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
# The same pipeline here:
shapley_table = ut.sorter(shapley_table) # sorting based on the average contribution (Shapley values)
fig,ax = plt.subplots()
sns.barplot(data=shapley_table, ax=ax, errorbar=('ci', 95), orient= "h", err_kws={'color': 'k'})
fig.set_dpi(154)
fig.set_size_inches((4*CM,6*CM))
plt.xlabel('Shapley values')
plt.ylabel('Elements')
plt.title('Shapley values of a ground-truth dataset\nwith only one non-critical element')
plt.savefig(f"{FIGPATH}1noncritical.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.8630240119335526
You see, it still captured the critical elements, but the importance is not equally distributed. I believe that's due to two things:
- Noise in the performance.
- Noise in the estimation.
Let's get rid of the first and see if it makes a difference.
def gt_noiseless(complements, causes):
# default score, will be 100 sharp!
score = 100
# checking if the regions of interest are being lesioned.
if len(causes) != 0 and set(causes).issubset(complements):
# lesioning ends with a reduction of 50 points.
return score - 50
else:
return score
ground_truth_cause = ['a','b','c','d','e','f','g'] #only 'h' is out.
# The same call here:
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt_noiseless,
n_parallel_games=-1, #parallelized over all CPU cores
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
# The same pipeline here:
shapley_table = ut.sorter(shapley_table) # sorting based on the average contribution (Shapley values)
fig,ax = plt.subplots()
sns.barplot(data=shapley_table, ax=ax, errorbar=('ci', 95), orient= "h", err_kws={'color': 'k'})
fig.set_dpi(154)
fig.set_size_inches((4*CM,6*CM))
plt.xlabel('Shapley values')
plt.ylabel('Elements')
plt.title('Shapley values of a ground-truth dataset with only one non-critical element')
plt.savefig(f"{FIGPATH}1noncritical_noiseless.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.8551934592118597
Much better! So here's the moral, in case you missed it: the noisier the data, the less accurate the Shapley values will be. What's nice, though, is that even with noisy data, the ranking still makes sense (a quick way to quantify that is sketched right below). Now let's see if it scales to a large system. We'll use the same logic but on a system with 500 elements. Let's see if we can localize "the" element first.
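Here's that sketch: one option is a rank correlation between the two runs. This assumes you keep the Shapley means of the noisy (gt) and noiseless (gt_noiseless) runs in two hypothetical variables instead of overwriting shapley_table as we did above.
from scipy.stats import spearmanr

def ranking_agreement(noisy_means, noiseless_means):
    # noisy_means / noiseless_means: shapley_table.mean() Series from the
    # gt and gt_noiseless runs, over the same elements. A Spearman rho close
    # to 1 means the element ranking survives the noise even though the
    # individual Shapley values don't.
    rho, _ = spearmanr(noisy_means.sort_index(), noiseless_means.sort_index())
    return rho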
Scaling up to 500 elements
WARNING: This step needs about 40 GB of RAM! I need to fix this later, probably with Dask arrays or something else; let me know if you have a solution. Also, this step shows that in principle, a large number of elements doesn't make the calculation unstable (but look below for a caveat); it makes the process computationally prohibitive. I mean, my task takes a few milliseconds, but imagine a task that costs a second or two!
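A rough back-of-the-envelope for that last point (this is just an upper bound, assuming about one objective-function call per sampled coalition, which is at most n_permutations × (n_elements + 1); duplicated coalitions only need to be evaluated once, so the real number is lower):
n_elements, n_permutations = 500, 1_000
max_calls = n_permutations * (n_elements + 1)  # one marginal step per element, plus the empty set
print(f"~{max_calls:,} calls -> about {max_calls * 0.002 / 60:.0f} min at 2 ms per game, "
      f"about {max_calls / 3600:.0f} h at 1 s per game")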
ground_truth_elements = list(range(500))
len(ground_truth_elements)
500
ground_truth_cause = [100] # element number 100 is the sole producer of the function
# Everything else is the same, again.
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt,
n_parallel_games=1, # somehow running everything on one core is faster here!
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
This time we don't sort by importance, but we still need to sort the columns in ascending order to plot. In the dataframe everything is in its place, but the column positions get shuffled while calculating the Shapley values, so I might later add shapley_table.sort_index(axis=1) to the package itself; for now, we'll do it here.
We'll use scatter and line plots, since bar plots of this size would be very ugly.
shapley_table = shapley_table.sort_index(axis=1)
# this is my custom color palette, by the way; feel free to use it. It's colorblind friendly.
color = sns.blend_palette(['#006685', '#3FA5C4', '#FFE48D', '#E84653', '#BF003F'],
len(ground_truth_elements),
as_cmap= True)
plt.figure(figsize=(21*CM,5*CM),dpi = 154)
plt.plot(np.arange(len(ground_truth_elements)),
shapley_table.mean(),
c='k',linewidth=1,alpha=0.1)
plt.scatter(x = np.arange(len(ground_truth_elements)),
y = shapley_table.mean(),
c = shapley_table.mean(),
s = shapley_table.mean().abs()*10,
cmap = color)
plt.axhline(linewidth=1, color='#BF003F')
plt.xlabel('Elements (Indices)')
plt.ylabel('Shapley values')
plt.title('Shapley values of a ground-truth dataset with 1 critical element')
plt.savefig(f"{FIGPATH}scaled1critical.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.7791750779196088
That works well; what doesn't work so well is the D index! It certainly needs some adjustments to be more robust to noise. But let's see if we can capture a function that is distributed over 100 nodes.
ground_truth_cause = list(range(100,200))
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt,
n_parallel_games=1,
objective_function_params={'causes': ground_truth_cause},
rng=RNG)
shapley_table = shapley_table.sort_index(axis=1)
plt.figure(figsize=(21*CM,5*CM),dpi = 154)
plt.plot(np.arange(len(ground_truth_elements)),
shapley_table.mean(),c='k',linewidth=1,alpha=0.1)
plt.scatter(x = np.arange(len(ground_truth_elements)),
y = shapley_table.mean(),
c = shapley_table.mean(),
s = shapley_table.mean().abs()*20,
cmap = color)
plt.axhline(linewidth=1, color='#BF003F')
plt.xlabel('Elements (Indices)')
plt.ylabel('Shapley values')
plt.title('Shapley values of a ground-truth dataset with 100 critical elements')
plt.savefig(f"{FIGPATH}scaled100critical.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.9456323166894177
As you can see, it didn't work out well, and that actually makes sense: again, remember the noise, and remember that we're trying to distribute 50 points across 100 elements. So although there's a noticeable bump, the result is much noisier (the quick arithmetic below makes the point). Let's try eliminating the performance noise.
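Here's that arithmetic, just to make it concrete:
# Each of the 100 critical elements carries only 50 / 100 = 0.5 points of
# the deficit, while a single game draw has a noise standard deviation of 10,
# so the per-element signal is tiny unless it's averaged over many permutations.
per_element_signal = 50 / len(ground_truth_cause)  # 0.5 points
noise_sd = 10  # from RNG.normal(loc=100, scale=10) in gt
print(f"signal per element: {per_element_signal} point(s), noise SD per game: {noise_sd}")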
shapley_table = msa.interface(
elements=ground_truth_elements,
n_permutations=1_000,
objective_function=gt_noiseless,
n_parallel_games=1, # somehow running everything on one core is faster here!
objective_function_params={'causes': ground_truth_cause})
shapley_table = shapley_table.sort_index(axis=1)
plt.figure(figsize=(21*CM,5*CM),dpi = 154)
color = sns.blend_palette(['#FFE48D', '#E84653', '#BF003F'], # no negative Shapley values, no cold colors
len(ground_truth_elements),
as_cmap= True)
plt.plot(np.arange(len(ground_truth_elements)),
shapley_table.mean(),c='k',linewidth=1,alpha=0.1)
plt.scatter(x = np.arange(len(ground_truth_elements)),
y = shapley_table.mean(),
c = shapley_table.mean(),
s = shapley_table.mean().abs()*20,
cmap = color)
plt.axhline(linewidth=1, color='#BF003F')
plt.xlabel('Elements (Indices)')
plt.ylabel('Shapley values')
plt.title('Shapley values of a ground-truth dataset with 100 critical elements')
plt.savefig(f"{FIGPATH}scaled100critical_noiseless.pdf",dpi=300,bbox_inches='tight')
# And the D index
d = ut.distribution_of_processing(shapley_vector=shapley_table.mean())
print(f'D index is: {d}')
D index is: 0.9057781579939023
Interestingly, the number of coalitions sampled from a fixed number of permutations changes from run to run, but it roughly follows a normal distribution. Look: here I'll just use 20 elements and keep regenerating the permutation space and the resulting lesion combinations, then plot a histogram of how many coalitions each run produces.
elements = list(range(20))
lengths = []
for i in range(1000):
    permutations = msa.make_permutation_space(elements=elements, n_permutations=1000)
    combinations = msa.make_combination_space(permutation_space=permutations)
    lengths.append(len(combinations))
plt.figure(figsize=(8*CM,8*CM),dpi=154)
sns.histplot(lengths, bins=20, color="#FFE48D")
plt.xticks(rotation=90)
plt.xlabel('Number of coalitions '
'\n(N permutations = 1000)')
plt.title(f'Generated coalitions\n(from {2**20} possible coalitions)')
plt.savefig(f"{FIGPATH}distribution_of_coalitions.pdf",dpi=300,bbox_inches='tight')