Skip to content

Commit

Permalink
Split MaskLike type alias in Mask (which can be applied to any DataFr…
Browse files Browse the repository at this point in the history
…ame) and AgentMask (which can be applied to the _agents DataFrame of AgentSetDF)
  • Loading branch information
adamamer20 committed Aug 1, 2024
1 parent 2b1f50d commit cdc7ca8
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 131 deletions.
129 changes: 70 additions & 59 deletions mesa_frames/abstract/agents.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions mesa_frames/abstract/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Literal
from collections.abc import Collection, Iterator, Sequence

from mesa_frames.types_ import BoolSeries, DataFrame, MaskLike, Series
from mesa_frames.types_ import BoolSeries, DataFrame, Mask, Series


class CopyMixin(ABC):
Expand Down Expand Up @@ -181,7 +181,7 @@ def _df_get_bool_mask(
self,
df: DataFrame,
index_col: str,
mask: MaskLike | None = None,
mask: Mask | None = None,
negate: bool = False,
) -> BoolSeries: ...

Expand All @@ -190,7 +190,7 @@ def _df_get_masked_df(
self,
df: DataFrame,
index_col: str,
mask: MaskLike | None = None,
mask: Mask | None = None,
columns: list[str] | None = None,
negate: bool = False,
) -> DataFrame: ...
Expand Down
46 changes: 23 additions & 23 deletions mesa_frames/concrete/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
from mesa_frames.types_ import (
AgnosticMask,
AgentMask,
AgnosticAgentMask,
BoolSeries,
DataFrame,
IdsLike,
MaskLike,
Series,
)

Expand Down Expand Up @@ -58,13 +58,13 @@ class AgentsDF(AgentContainer):
Remove an agent from the AgentsDF. Does not raise an error if the agent is not found.
do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any
Invoke a method on the AgentsDF.
get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]
get(self, attr_names: str | Collection[str] | None = None, mask: AgentMask = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]
Retrieve the value of a specified attribute for each agent in the AgentsDF.
remove(self, ids: IdsLike, inplace: bool = True) -> Self
Remove agents from the AgentsDF.
select(self, mask: MaskLike = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
select(self, mask: AgentMask = None, filter_func: Callable[[Self], AgentMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self
Select agents in the AgentsDF based on the given criteria.
set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self
set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: AgentMask | None = None, inplace: bool = True) -> Self
Set the value of a specified attribute or attributes for each agent in the mask in the AgentsDF.
shuffle(self, inplace: bool = True) -> Self
Shuffle the order of agents in the AgentsDF.
Expand Down Expand Up @@ -157,7 +157,7 @@ def do(
self,
method_name: str,
*args,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
return_results: Literal[False] = False,
inplace: bool = True,
**kwargs,
Expand All @@ -168,7 +168,7 @@ def do(
self,
method_name: str,
*args,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
return_results: Literal[True],
inplace: bool = True,
**kwargs,
Expand All @@ -178,7 +178,7 @@ def do(
self,
method_name: str,
*args,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
return_results: bool = False,
inplace: bool = True,
**kwargs,
Expand Down Expand Up @@ -214,7 +214,7 @@ def do(
def get(
self,
attr_names: str | Collection[str] | None = None,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]:
agentsets_masks = self._get_bool_masks(mask)
return {
Expand Down Expand Up @@ -253,8 +253,8 @@ def remove(

def select(
self,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
filter_func: Callable[[AgentSetDF], MaskLike] | None = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
filter_func: Callable[[AgentSetDF], AgentMask] | None = None,
n: int | None = None,
inplace: bool = True,
negate: bool = False,
Expand All @@ -275,7 +275,7 @@ def set(
self,
attr_names: str | dict[AgentSetDF, Any] | Collection[str],
values: Any | None = None,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
Expand Down Expand Up @@ -370,7 +370,7 @@ def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series:

def _get_bool_masks(
self,
mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None,
mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None,
) -> dict[AgentSetDF, BoolSeries]:
return_dictionary = {}
if not isinstance(mask, dict):
Expand Down Expand Up @@ -418,27 +418,27 @@ def __getattr__(self, name: str) -> dict[AgentSetDF, Any]:

@overload
def __getitem__(
self, key: str | tuple[dict[AgentSetDF, MaskLike], str]
self, key: str | tuple[dict[AgentSetDF, AgentMask], str]
) -> dict[str, Series]: ...

@overload
def __getitem__(
self,
key: Collection[str]
| AgnosticMask
| AgnosticAgentMask
| IdsLike
| tuple[dict[AgentSetDF, MaskLike], Collection[str]],
| tuple[dict[AgentSetDF, AgentMask], Collection[str]],
) -> dict[str, DataFrame]: ...

def __getitem__(
self,
key: (
str
| Collection[str]
| AgnosticMask
| AgnosticAgentMask
| IdsLike
| tuple[dict[AgentSetDF, MaskLike], str]
| tuple[dict[AgentSetDF, MaskLike], Collection[str]]
| tuple[dict[AgentSetDF, AgentMask], str]
| tuple[dict[AgentSetDF, AgentMask], Collection[str]]
),
) -> dict[str, Series] | dict[str, DataFrame]:
return super().__getitem__(key)
Expand Down Expand Up @@ -494,10 +494,10 @@ def __setitem__(
key: (
str
| Collection[str]
| AgnosticMask
| AgnosticAgentMask
| IdsLike
| tuple[dict[AgentSetDF, MaskLike], str]
| tuple[dict[AgentSetDF, MaskLike], Collection[str]]
| tuple[dict[AgentSetDF, AgentMask], str]
| tuple[dict[AgentSetDF, AgentMask], Collection[str]]
),
values: Any,
) -> None:
Expand Down Expand Up @@ -542,7 +542,7 @@ def active_agents(self) -> dict[AgentSetDF, DataFrame]:

@active_agents.setter
def active_agents(
self, agents: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike]
self, agents: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask]
) -> None:
self.select(agents, inplace=True)

Expand Down
16 changes: 8 additions & 8 deletions mesa_frames/concrete/pandas/agentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mesa_frames.abstract.agents import AgentSetDF
from mesa_frames.concrete.pandas.mixin import PandasMixin
from mesa_frames.concrete.polars.agentset import AgentSetPolars
from mesa_frames.types_ import PandasIdsLike, PandasMaskLike
from mesa_frames.types_ import AgentPandasMask, PandasIdsLike

if TYPE_CHECKING:
from mesa_frames.concrete.model import ModelDF
Expand Down Expand Up @@ -172,7 +172,7 @@ def contains(self, agents: PandasIdsLike) -> bool | pd.Series:
def get(
self,
attr_names: str | Collection[str] | None = None,
mask: PandasMaskLike = None,
mask: AgentPandasMask = None,
) -> pd.Index | pd.Series | pd.DataFrame:
mask = self._get_bool_mask(mask)
if attr_names is None:
Expand Down Expand Up @@ -206,7 +206,7 @@ def set(
self,
attr_names: str | dict[str, Any] | Collection[str] | None = None,
values: Any | None = None,
mask: PandasMaskLike = None,
mask: AgentPandasMask = None,
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
Expand Down Expand Up @@ -242,8 +242,8 @@ def set(

def select(
self,
mask: PandasMaskLike = None,
filter_func: Callable[[Self], PandasMaskLike] | None = None,
mask: AgentPandasMask = None,
filter_func: Callable[[Self], AgentPandasMask] | None = None,
n: int | None = None,
negate: bool = False,
inplace: bool = True,
Expand Down Expand Up @@ -315,7 +315,7 @@ def _concatenate_agentsets(

def _get_bool_mask(
self,
mask: PandasMaskLike = None,
mask: AgentPandasMask = None,
) -> pd.Series:
if isinstance(mask, pd.Series) and mask.dtype == bool:
return mask
Expand All @@ -334,7 +334,7 @@ def _get_bool_mask(

def _get_masked_df(
self,
mask: PandasMaskLike = None,
mask: AgentPandasMask = None,
) -> pd.DataFrame:
if isinstance(mask, pd.Series) and mask.dtype == bool:
return self._agents.loc[mask]
Expand Down Expand Up @@ -428,7 +428,7 @@ def active_agents(self) -> pd.DataFrame:
return self._agents.loc[self._mask]

@active_agents.setter
def active_agents(self, mask: PandasMaskLike) -> None:
def active_agents(self, mask: AgentPandasMask) -> None:
self.select(mask=mask, inplace=True)

@property
Expand Down
6 changes: 3 additions & 3 deletions mesa_frames/concrete/pandas/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Any

from mesa_frames.abstract.mixin import DataFrameMixin
from mesa_frames.types_ import PandasMaskLike
from mesa_frames.types_ import PandasMask


class PandasMixin(DataFrameMixin):
Expand Down Expand Up @@ -47,7 +47,7 @@ def _df_get_bool_mask(
self,
df: pd.DataFrame,
index_col: str,
mask: PandasMaskLike = None,
mask: PandasMask = None,
negate: bool = False,
) -> pd.Series:
if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df):
Expand Down Expand Up @@ -77,7 +77,7 @@ def _df_get_masked_df(
self,
df: pd.DataFrame,
index_col: str,
mask: PandasMaskLike | None = None,
mask: PandasMask | None = None,
columns: list[str] | None = None,
negate: bool = False,
) -> pd.DataFrame:
Expand Down
26 changes: 13 additions & 13 deletions mesa_frames/concrete/polars/agentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.concrete.polars.mixin import PolarsMixin
from mesa_frames.types_ import PolarsIdsLike, PolarsMaskLike
from mesa_frames.types_ import AgentPolarsMask, PolarsIdsLike

if TYPE_CHECKING:
from mesa_frames.concrete.model import ModelDF
Expand Down Expand Up @@ -188,7 +188,7 @@ def contains(
def get(
self,
attr_names: IntoExpr | Iterable[IntoExpr] | None,
mask: PolarsMaskLike = None,
mask: AgentPolarsMask = None,
) -> pl.Series | pl.DataFrame:
masked_df = self._get_masked_df(mask)
attr_names = self.agents.select(attr_names).columns.copy()
Expand Down Expand Up @@ -219,7 +219,7 @@ def set(
self,
attr_names: str | Collection[str] | dict[str, Any] | None = None,
values: Any | None = None,
mask: PolarsMaskLike = None,
mask: AgentPolarsMask = None,
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
Expand Down Expand Up @@ -270,7 +270,7 @@ def process_single_attr(

def select(
self,
mask: PolarsMaskLike = None,
mask: AgentPolarsMask = None,
filter_func: Callable[[Self], pl.Series] | None = None,
n: int | None = None,
negate: bool = False,
Expand Down Expand Up @@ -388,7 +388,7 @@ def _concatenate_agentsets(

def _get_bool_mask(
self,
mask: PolarsMaskLike = None,
mask: AgentPolarsMask = None,
) -> pl.Series | pl.Expr:
def bool_mask_from_series(mask: pl.Series) -> pl.Series:
if (
Expand Down Expand Up @@ -423,7 +423,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:

def _get_masked_df(
self,
mask: PolarsMaskLike = None,
mask: AgentPolarsMask = None,
) -> pl.DataFrame:
if (isinstance(mask, pl.Series) and mask.dtype == pl.Boolean) or isinstance(
mask, pl.Expr
Expand Down Expand Up @@ -486,17 +486,17 @@ def __getattr__(self, key: str) -> pl.Series:
@overload
def __getitem__(
self,
key: str | tuple[PolarsMaskLike, str],
key: str | tuple[AgentPolarsMask, str],
) -> pl.Series: ...

@overload
def __getitem__(
self,
key: (
PolarsMaskLike
AgentPolarsMask
| Collection[str]
| tuple[
PolarsMaskLike,
AgentPolarsMask,
Collection[str],
]
),
Expand All @@ -507,10 +507,10 @@ def __getitem__(
key: (
str
| Collection[str]
| PolarsMaskLike
| tuple[PolarsMaskLike, str]
| AgentPolarsMask
| tuple[AgentPolarsMask, str]
| tuple[
PolarsMaskLike,
AgentPolarsMask,
Collection[str],
]
),
Expand Down Expand Up @@ -543,7 +543,7 @@ def active_agents(self) -> pl.DataFrame:
return self.agents.filter(self._mask)

@active_agents.setter
def active_agents(self, mask: PolarsMaskLike) -> None:
def active_agents(self, mask: AgentPolarsMask) -> None:
self.select(mask=mask, inplace=True)

@property
Expand Down
6 changes: 3 additions & 3 deletions mesa_frames/concrete/polars/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Any

from mesa_frames.abstract.mixin import DataFrameMixin
from mesa_frames.types_ import PolarsMaskLike
from mesa_frames.types_ import PolarsMask


class PolarsMixin(DataFrameMixin):
Expand Down Expand Up @@ -65,7 +65,7 @@ def _df_get_bool_mask(
self,
df: pl.DataFrame,
index_col: str,
mask: PolarsMaskLike = None,
mask: PolarsMask = None,
negate: bool = False,
) -> pl.Series | pl.Expr:
def bool_mask_from_series(mask: pl.Series) -> pl.Series:
Expand Down Expand Up @@ -106,7 +106,7 @@ def _df_get_masked_df(
self,
df: pl.DataFrame,
index_col: str,
mask: PolarsMaskLike | None = None,
mask: PolarsMask | None = None,
columns: list[str] | None = None,
negate: bool = False,
) -> pl.DataFrame:
Expand Down
Loading

0 comments on commit cdc7ca8

Please sign in to comment.