Estimate causal influences

Estimates the causal contribution (Shapley values) of each node on the rest of the network. Essentially, this function performs MSA iteratively on each node and tracks the changes in the objective_function of the target node. For example, suppose we have a chain A -> B -> C and want to know how much A and B contribute to C. We first need to define a metric for C (the objective_function), say the average activity of C. MSA then performs a multi-site lesioning analysis of A and B, so for each of them we end up with a number indicating its contribution to the average activity of C.

VERY IMPORTANT NOTES:

1. The resulting causal contribution matrix does not necessarily reflect the connectome. In the example above there is no actual connection A -> C, but there might be one in the causal contribution matrix, since A causally influences C via B.
2. Think twice (even three times) about your objective function. With everything else the same, you will get different causal contribution matrices depending on what you track and how accurately it captures the effect of lesions. Also, don't forget the edge cases: there will be weird behaviors in your system, for example, what does it do if every node is perturbed?
3. The metric you track should preferably be non-negative and bounded (at least practically!).
4. Obviously, this takes N times longer than a normal MSA, with N being the number of nodes. So make your process as fast as it can be, for example by using Numba, but you don't need to implement any parallelization yourself since it is already implemented here. Going below 1000 permutations might be an option depending on your specific case, but based on experience it's not a good idea.
5. Shapley values sum up (or come close) to the value of the intact coalition. So, for example, if the mean activity of node C here is 50, then causal_contribution_matrix.sum(axis=0) should be 50 or close to it (see the sketch after these notes). If not, it means:
    1. the number of permutations is not enough,
    2. there is randomness somewhere in the process, or
    3. your objective function is not suitable.
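
To make this concrete, below is a minimal sketch of the A -> B -> C example. The weight matrix, the toy propagation rule, and the mean_activity objective function are hypothetical stand-ins (and the import path is assumed); the only library call is estimate_causal_influences itself, with the parameters documented below. The injected 'index' keyword is taken to be the integer position of the target node.

import numpy as np
from msapy import msa  # assumed import path

# Hypothetical 3-node chain A -> B -> C, encoded as a weight matrix.
nodes = ["A", "B", "C"]
weights = np.array([[0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0]])

def mean_activity(complements, weights, index):
    # Toy objective: lesion every node in `complements`, push a unit input
    # through the chain, and return the resulting activity of the target node.
    lesioned = weights.copy()
    for name in complements:
        i = nodes.index(name)
        lesioned[i, :] = 0.0  # cut the lesioned node's outgoing connections
        lesioned[:, i] = 0.0  # and its incoming ones
    activity = np.ones(len(nodes))
    for _ in range(len(nodes)):  # crude feed-forward propagation, enough for a chain
        activity = activity + lesioned.T @ activity
    return float(activity[index])

causal_influences = msa.estimate_causal_influences(
    elements=nodes,
    objective_function=mean_activity,
    objective_function_params={"weights": weights},  # {'index': index} is added for you
    n_permutations=1000,
)

print(causal_influences)
# Per note 5, the contributions toward each target should add up (approximately)
# to that target's value in the intact network.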

Parameters:

elements (list, required):
    List of the players (elements). Can be strings (names), integers (indices), or tuples.

objective_function (Callable, required):
    The game (in-silico experiment). It should take the complement set and return a single numeric value, either int or float. This function simply calls it as: objective_function(complement, **objective_function_params)

    An example using networkx, with some tips:

    def lesion_me_senpai(complements, network, index):
        # note "index": your function should be able to track the effects on the target,
        # and the keyword for that is "index"

        if len(complements) == len(network) - 1:  # -1 since the target node is active
            return 0

        lesioned_network = deepcopy(network)
        for target in complements:
            lesioned_network[target] = 0  # setting all connections of the target to 0

        activity = network.run(lesioned_network)  # or really, whatever you want!
        return float(activity[index].mean())

    (You sometimes need to specify what should happen in edge cases, such as an all-lesioned network.)

objective_function_params (Optional[Dict], default None):
    Kwargs for the objective_function. A {'index': index} pair is added to this dictionary during the process so your function can track the lesion effect on the target.

target_elements (Optional[list], default None):
    List of elements for which to calculate the causal influence.

multiprocessing_method (str, default 'joblib'):
    Two parallelization methods are implemented so far, 'joblib' and 'ray', with 'joblib' as the default. If you use ray, you need to decorate your objective function with the @ray.remote decorator. See the ray documentation for details.

n_cores (int, default -1):
    Number of parallel games. The default of -1 means all cores, which can make the system freeze for a short period; if that happens, try -2, which leaves one core free. Or simply specify the number of cores you want to use.

n_permutations (int, default 1000):
    Number of permutations per node. This hasn't been checked systematically yet, but based on random explorations, something around 1000 is enough.

permutation_seed (Optional[int], default None):
    Sets the random seed of the sampling process. The default is None, so every call results in a different ordering.

parallelize_over_games (bool, default False):
    Whether to parallelize over games or over elements. Parallelizing over the elements is generally faster. Defaults to False.
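
For quick reference, here is a hedged sketch of a call that sets these parameters explicitly, reusing the hypothetical mean_activity, nodes, and weights from the sketch above:

causal_influences = msa.estimate_causal_influences(
    elements=nodes,
    objective_function=mean_activity,
    objective_function_params={"weights": weights},  # 'index' is injected automatically
    multiprocessing_method="joblib",   # default backend; 'ray' requires @ray.remote on the objective function
    n_cores=-2,                        # leave one core free so the system stays responsive
    n_permutations=1000,               # per-node permutations; fewer trades accuracy for speed
    permutation_seed=42,               # fix the sampling order for reproducible results
    parallelize_over_games=False,      # parallelize over elements (generally faster)
)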

Returns:

causal_influences (pd.DataFrame)

Source code in msapy/msa.py
@typechecked
def estimate_causal_influences(elements: list,
                               objective_function: Callable,
                               objective_function_params: Optional[dict] = None,
                               target_elements: Optional[list] = None,
                               multiprocessing_method: str = 'joblib',
                               n_cores: int = -1,
                               n_permutations: int = 1000,
                               permutation_seed: Optional[int] = None,
                               parallelize_over_games=False,
                               lazy=True
                               ) -> pd.DataFrame:
    """
    Estimates the causal contribution (Shapley values) of each node on the rest of the network. Essentially, this function
    performs MSA iteratively on each node and tracks the changes in the objective_function of the target node.
    For example, suppose we have a chain A -> B -> C and we want to know how much A and B contribute to C. We first need to
    define a metric for C (the objective_function), say the average activity of C. MSA then performs a
    multi-site lesioning analysis of A and B, so for each of them we end up with a number indicating its contribution to
    the average activity of C.

    VERY IMPORTANT NOTES:

        1. The resulting causal contribution matrix does not necessarily reflect the connectome. In the example above
        there is no actual connection A -> C, but there might be one in the causal contribution matrix, since A causally
        influences C via B.
        2. Think twice (even three times) about your objective function. With everything else the same, you will get
        different causal contribution matrices depending on what you track and how accurately it captures the effect of
        lesions. Also, don't forget the edge cases: there will be weird behaviors in your system, for example, what does
        it do if every node is perturbed?
        3. The metric you track should preferably be non-negative and bounded (at least practically!).
        4. Obviously, this takes N times longer than a normal MSA, with N being the number of nodes. So make your
        process as fast as it can be, for example by using Numba, but you don't need to implement any parallelization
        yourself since it is already implemented here. Going below 1000 permutations might be an option depending on
        your specific case, but based on experience it's not a good idea.
        5. Shapley values sum up (or come close) to the value of the intact coalition. So, for example, if the
        mean activity of node C here is 50, then causal_contribution_matrix.sum(axis=0) should be 50 or close to it.
        If not, it means:
            1. the number of permutations is not enough,
            2. there is randomness somewhere in the process, or
            3. your objective function is not suitable.


    Args:
        elements (list):
            List of the players (elements). Can be strings (names), integers (indices), or tuples.

        objective_function (Callable):
            The game (in-silico experiment). It should take the complement set and return a single numeric value,
            either int or float.
            This function simply calls it as: objective_function(complement, **objective_function_params)

            An example using networkx with some tips:

            def lesion_me_senpai(complements, network, index):
                # note "index", your function should be able to track the effects on the target and the keyword for
                  that is "index"

                if len(complements) == len(A)-1:  # -1 since the target node is active
                    return 0

                lesioned_network = deepcopy(network)
                for target in complements:
                    lesioned_network[target] = 0  # setting all connections of the targets to 0

                activity = network.run(lesioned_network) # or really, whatever you want!
                return float(activity[index].mean())

            (you sometimes need to specify what should happen during edge-cases like an all-lesioned network)


        objective_function_params (Optional[Dict]):
            Kwargs for the objective_function. A dictionary pair of {'index': index} will be added to this during
            the process so your function can track the lesion effect.

        target_elements (Optional[list]): list of elements that you want to calculate the causal influence of.

        multiprocessing_method (str = 'joblib'):
            So far, two parallelization methods are implemented, 'joblib' and 'ray', and the default is 'joblib'.
            If using ray, though, you need to decorate your objective function with the @ray.remote decorator. See the
            ray documentation for details.

        n_cores (int = -1):
            Number of parallel games. The default is -1, which means all cores; this can make the system
            freeze for a short period. If that happens, try -2, which leaves one core free. Or simply
            specify the number of cores you want to use.

        n_permutations (int = 1000):
            Number of permutations per node.
            This hasn't been checked systematically yet, but based on random explorations,
            something around 1000 is enough.

        permutation_seed (Optional[int] = None):
            Sets the random seed of the sampling process. The default is None, so if nothing is given, every call
            results in a different ordering.

        parallelize_over_games (bool = False): Whether to parallelize over games or parallelize over elements. Parallelizing
            over the elements is generally faster. Defaults to False

    Returns:
        causal_influences (pd.DataFrame)

    """
    target_elements = target_elements if target_elements else elements
    objective_function_params = objective_function_params if objective_function_params else {}

    if parallelize_over_games:
        # run causal_influence_single_element for all target elements.
        mbar = master_bar(enumerate(target_elements),
                          total=len(target_elements))
        results = [causal_influence_single_element(elements, objective_function,
                                                   objective_function_params, n_permutations,
                                                   n_cores, multiprocessing_method,
                                                   permutation_seed, index, element, lazy, mbar) for index, element in mbar]

    elif multiprocessing_method == 'ray':
        if importlib.util.find_spec("ray") is None:
            raise ImportError(
                "The ray package is required to run this algorithm. Install and use at your own risk.")

        import ray
        if n_cores <= 0:
            warnings.warn("A zero or a negative n_cores was passed and ray doesn't like so "
                          "to fix that ray.init() will get no arguments, "
                          "which means use all cores as n_cores = -1 does for joblib.", stacklevel=2)
            ray.init()
        else:
            ray.init(num_cpus=n_cores)

        result_ids = [ray.remote(causal_influence_single_element).remote(elements, objective_function,
                                                                         objective_function_params, n_permutations,
                                                                         1, 'joblib',
                                                                         permutation_seed, index, element, lazy, None) for index, element in enumerate(target_elements)]

        for _ in tqdm(ut.ray_iterator(result_ids), total=len(result_ids)):
            pass

        results = ray.get(result_ids)
        ray.shutdown()

    else:
        with tqdm_joblib(desc="Doing Nodes: ", total=len(target_elements)) as pb:
            results = (Parallel(n_jobs=n_cores)(delayed(causal_influence_single_element)(elements, objective_function,
                                                                                         objective_function_params, n_permutations,
                                                                                         1, 'joblib',
                                                                                         permutation_seed, index, element, lazy) for index, element in enumerate(target_elements)))

    _, contribution_type = results[0]
    shapley_values = [r[0] for r in results]

    causal_influences = pd.DataFrame(
        shapley_values, columns=elements) if contribution_type == "scaler" else pd.concat(shapley_values, keys=elements)

    if contribution_type == "scaler":
        return causal_influences
    return causal_influences[causal_influences.index.levels[0]]