Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 28, 2024
1 parent aa63d50 commit ec963de
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
6 changes: 4 additions & 2 deletions sodym/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,10 @@ def union_with(self, other: "DimensionSet") -> "DimensionSet":
added_dims = [dim for dim in other.dim_list if dim.letter not in self.letters]
return self.expand_by(added_dims)

def difference_with(self, other: 'DimensionSet') -> 'DimensionSet':
difference_letters = [dim.letter for dim in self.dim_list if dim.letter not in other.letters]
def difference_with(self, other: "DimensionSet") -> "DimensionSet":
difference_letters = [
dim.letter for dim in self.dim_list if dim.letter not in other.letters
]
return self.get_subset(difference_letters)

@property
Expand Down
4 changes: 3 additions & 1 deletion sodym/named_dim_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def split(self, dim_letter: str) -> dict:

def get_shares_over(self, dim_letters: tuple) -> "NamedDimArray":
"""Get shares of the NamedDimArray along a tuple of dimensions, indicated by letter."""
assert all([d in self.dims.letters for d in dim_letters]), 'Dimensions to get share of must be in the object'
assert all(
[d in self.dims.letters for d in dim_letters]
), "Dimensions to get share of must be in the object"

if all([d in dim_letters for d in self.dims.letters]):
return self / self.sum_values()
Expand Down
12 changes: 6 additions & 6 deletions tests/test_named_dim_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,23 @@ def test_sum_nda_to():

def test_get_shares_over():
# example of getting shares over one dimension
shares = space_animals.get_shares_over(dim_letters=('p'))
shares = space_animals.get_shares_over(dim_letters=("p"))
assert shares.dims == space_animals.dims
wanted_values = np.einsum('pta,ta->pta', animal_values, 1/np.sum(animal_values, axis=0))
wanted_values = np.einsum("pta,ta->pta", animal_values, 1 / np.sum(animal_values, axis=0))
assert_array_almost_equal(shares.values, wanted_values)

# example of getting shares over two dimensions
shares = space_animals.get_shares_over(dim_letters=('p', 'a'))
wanted_values = np.einsum('pta,t->pta', animal_values, 1/np.sum(animal_values, axis=(0, 2)))
shares = space_animals.get_shares_over(dim_letters=("p", "a"))
wanted_values = np.einsum("pta,t->pta", animal_values, 1 / np.sum(animal_values, axis=(0, 2)))
assert_array_almost_equal(shares.values, wanted_values)

# example of getting shares over all dimensions
shares = space_animals.get_shares_over(dim_letters=('p', 't', 'a'))
shares = space_animals.get_shares_over(dim_letters=("p", "t", "a"))
assert_array_almost_equal(shares.values, animal_values / np.sum(animal_values))

# example of getting shares over a dimension that doesn't exist
with pytest.raises(AssertionError):
space_animals.get_shares_over(dim_letters=('s',))
space_animals.get_shares_over(dim_letters=("s",))


def test_maths():
Expand Down

0 comments on commit ec963de

Please sign in to comment.