Skip to content

Commit

Permalink
Merge pull request #24 from kuffmode/Release-1.6
Browse files Browse the repository at this point in the history
Release 1.6
  • Loading branch information
kuffmode authored Feb 5, 2024
2 parents f2bd920 + 4014bb3 commit 015976d
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 31 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
push:
branches:
main

permissions:
contents: read

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.9'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build setuptools wheel twine
- name: Build package
run: |
rm dist/*
python setup.py sdist bdist_wheel
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
/dist/
/msapy.egg-info/
docs/examples/.ipynb_checkpoints/
docs/examples/data
docs/examples/data
celebA
*.pyc
2 changes: 1 addition & 1 deletion msapy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from msapy import msa, utils, plottings,checks
__version__ = "1.5"
__version__ = "1.6"
43 changes: 24 additions & 19 deletions msapy/msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def make_combination_space(*, permutation_space: list, pair: Optional[Tuple] = N
_check_valid_permutation_space(permutation_space)

# if we have an element that needs to be lesioned in every combination, then we store it in a set so that taking a difference becomes easier and efficient
lesioned = set(lesioned) if lesioned else set()
lesioned = {lesioned} if lesioned else set()

combination_space = OrderedSet()

Expand Down Expand Up @@ -393,12 +393,12 @@ def local_efficiency(complements, graph):
_check_get_shapley_table_args(contributions, objective_function, lazy)
_check_valid_permutation_space(permutation_space)

contributions = {tuple(): objective_function(tuple(), **objective_function_params)} if lazy else contributions
lesioned = {lesioned} if lesioned else set()
contributions = {tuple(lesioned): objective_function(tuple(lesioned), **objective_function_params)} if lazy else contributions

contribution_type, arbitrary_contrib = _get_contribution_type(contributions)
contrib_shape = arbitrary_contrib.shape if contribution_type == "nd" else []
contribution_type, intact_contributions_in_case_lazy = _get_contribution_type(contributions)
contrib_shape = intact_contributions_in_case_lazy.shape if contribution_type == "nd" else []

lesioned = set(lesioned) if lesioned else set()
sorted_elements = sorted(permutation_space[0])
permutation_space = set(permutation_space)

Expand All @@ -414,21 +414,23 @@ def local_efficiency(complements, graph):
shapley_table = 0 if (contribution_type == 'nd' and not save_permutations) else np.zeros((len(permutation_space), len(sorted_elements), *contrib_shape), dtype=float)

for i, permutation in parent_bar:
isolated_contributions = np.zeros((len(permutation), *arbitrary_contrib.shape), dtype=float) if contribution_type=="nd" else ([None] * len(permutation)) # got to be a better way!
isolated_contributions = np.zeros((len(permutation), *intact_contributions_in_case_lazy.shape), dtype=float) if contribution_type=="nd" else ([None] * len(permutation)) # got to be a better way!
child_bar = enumerate(permutation) if not (dual_progress_bars and lazy) else progress_bar(
enumerate(permutation), total=len(permutation), leave=False, parent=parent_bar)
# iterate over all elements in the permutation to calculate their isolated contributions

contributions_including = intact_contributions_in_case_lazy
for index, element in child_bar:
including = frozenset(permutation[:index + 1]) - lesioned
excluding = frozenset(permutation[:index]) - lesioned
including = frozenset(permutation[:index + 1])
excluding = frozenset(permutation[:index])

# the isolated contribution of an element is the difference of contribution with that element and without that element
if lazy:
contributions_including = objective_function(tuple(excluding), **objective_function_params)
contributions_excluding = objective_function(tuple(including), **objective_function_params)
contributions_excluding = objective_function(tuple(including.union(lesioned)), **objective_function_params)
isolated_contributions[sorted_elements.index(element)] = contributions_including - contributions_excluding
contributions_including = contributions_excluding
else:
isolated_contributions[sorted_elements.index(element)] = contributions[including] - contributions[excluding]
isolated_contributions[sorted_elements.index(element)] = contributions[including - lesioned] - contributions[excluding - lesioned]

if contribution_type == 'nd' and not save_permutations:
shapley_table += (isolated_contributions - shapley_table) / (i + 1)
Expand All @@ -441,7 +443,7 @@ def local_efficiency(complements, graph):
shapley_table = shapley_table.reshape(shapley_table.shape[0], -1).T
shapley_table = pd.DataFrame(
shapley_table, columns=sorted_elements)
return ShapleyModeND(shapley_table, arbitrary_contrib.shape)
return ShapleyModeND(shapley_table, intact_contributions_in_case_lazy.shape)

if contribution_type == "scaler":
return ShapleyTable(pd.DataFrame(shapley_table, columns=sorted_elements))
Expand Down Expand Up @@ -631,6 +633,7 @@ def interaction_2d(*,
rng: Optional[np.random.Generator] = None,
random_seed: Optional[int] = None,
n_parallel_games: int = -1,
lazy: bool = False,
) -> Tuple:
"""Performs Two dimensional MSA as explain in section 2.3 of [1].
We calculate the Shapley value of element i in the subgame of all elements without element j.
Expand Down Expand Up @@ -719,22 +722,22 @@ def local_efficiency(complements, graph):
"random_seed": random_seed,
"n_parallel_games": n_parallel_games,
"save_permutations": False,
"lazy": False}
"lazy": lazy}

# calculate the shapley values with element j lesioned
shapley_i = interface(**interface_args, lesioned=pair[1])
# get the shapley value of element i with element j leasioned
gamma_i = _get_gamma(shapley_i, pair[0]).sum()
gamma_i = _get_gamma(shapley_i, [pair[0]]).sum()

# calculate the shapley values with element i lesioned
shapley_j = interface(**interface_args, lesioned=pair[0])
# get the shapley value of element j with element i leasioned
gamma_j = _get_gamma(shapley_j, pair[1]).sum()
gamma_j = _get_gamma(shapley_j, [pair[1]]).sum()

# calculate the shapley values with element i and j together in every combination
shapley_ij = interface(**interface_args, pair=pair)
# get the sum of the shapley value of element i and j
gamma_ij = _get_gamma(shapley_ij, pair).sum()
gamma_ij = _get_gamma(shapley_ij, list(pair)).sum()

return gamma_ij, gamma_i, gamma_j

Expand All @@ -750,6 +753,7 @@ def network_interaction_2d(*,
rng: Optional[np.random.Generator] = None,
random_seed: Optional[int] = None,
n_parallel_games: int = -1,
lazy: bool = False
) -> np.ndarray:
"""Performs Two dimensional MSA as explain in section 2.3 of [1]
for every possible pair of elements and returns a symmetric matrix of
Expand Down Expand Up @@ -838,7 +842,8 @@ def local_efficiency(complements, graph):
"multiprocessing_method": multiprocessing_method,
"rng": rng,
"random_seed": random_seed,
"n_parallel_games": n_parallel_games}
"n_parallel_games": n_parallel_games,
"lazy": lazy}

interactions = np.zeros((len(elements), len(elements)))

Expand All @@ -865,9 +870,9 @@ def _get_gamma(shapley, idx):
shapley value of elements in idx
"""
if isinstance(shapley, ShapleyTable):
gamma = shapley.shapley_values[list(idx)]
gamma = shapley.shapley_values[idx]
elif isinstance(shapley, ShapleyTableND):
gamma = shapley[list(idx)]
gamma = shapley[idx]
return gamma


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
test_packages = ["pytest~=6.2.5"]

setup(name="msapy",
version="1.5",
version="1.6",
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
long_description_content_type='text/markdown',
Expand Down
19 changes: 10 additions & 9 deletions tests/test_ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
# ------------------------------#
# A function that assigns 1 to the cause and 0 to others
def simple(complements, causes):
if len(causes) != 0 and set(causes).issubset(complements):
if len(causes) != 0 and {causes}.issubset(complements):
return 0
else:
return 1


def simple_with_interaction(complements):
if ("A" not in complements) and ("B" not in complements):
if ("A115" not in complements) and ("B655" not in complements):
return sum(contrib_dict.values()) - sum(contrib_dict[k] for k in complements) + 87

return sum(contrib_dict.values()) - sum(contrib_dict[k] for k in complements)


# ------------------------------#
elements = ['a', 'b', 'c']
cause = 'a'
elements = ['A115', 'b', 'c']
cause = 'A115'
shapley_table = msa.interface(
elements=elements,
n_permutations=300,
Expand All @@ -30,7 +30,7 @@ def simple_with_interaction(complements):
objective_function_params={'causes': cause},
random_seed=111)

contrib_dict = {"A": 10, "B": 9, "C": 57, "D": -8, "E": 42}
contrib_dict = {"A115": 10, "B655": 9, "C": 57, "D": -8, "E": 42}

# ------------------------------#

Expand All @@ -44,7 +44,7 @@ def test_min():


def test_cause():
assert shapley_table['a'].mean() == 1
assert shapley_table['A115'].mean() == 1


def test_others():
Expand All @@ -57,15 +57,16 @@ def test_d_index():
shapley_vector=shapley_table.mean()) == 0


@pytest.mark.parametrize("n_parallel_games, multiprocessing_method", [(1, 'joblib'), (-1, 'joblib')])
def test_interaction_2d(n_parallel_games, multiprocessing_method):
@pytest.mark.parametrize("n_parallel_games, multiprocessing_method, lazy", [(1, 'joblib', True), (-1, 'joblib', True), (1, 'joblib', False), (-1, 'joblib', False)])
def test_interaction_2d(n_parallel_games, multiprocessing_method, lazy):
interactions = msa.network_interaction_2d(
elements=list(contrib_dict.keys()),
n_permutations=1000,
objective_function=simple_with_interaction,
n_parallel_games=n_parallel_games,
multiprocessing_method=multiprocessing_method,
random_seed=111)
random_seed=111,
lazy=lazy)

expected_interactions = np.array([[0, 87, 0, 0, 0], [87, 0, 0, 0, 0],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
Expand Down

0 comments on commit 015976d

Please sign in to comment.