From cdc7ca8b6747580f90be7106db2c3df177d1e8e9 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Thu, 1 Aug 2024 19:01:20 +0200 Subject: [PATCH] Split MaskLike type alias in Mask (which can be applied to any DataFrame) and AgentMask (which can be applied to the _agents DataFrame of AgentSetDF) --- mesa_frames/abstract/agents.py | 129 +++++++++++++----------- mesa_frames/abstract/mixin.py | 6 +- mesa_frames/concrete/agents.py | 46 ++++----- mesa_frames/concrete/pandas/agentset.py | 16 +-- mesa_frames/concrete/pandas/mixin.py | 6 +- mesa_frames/concrete/polars/agentset.py | 26 ++--- mesa_frames/concrete/polars/mixin.py | 6 +- mesa_frames/types_.py | 22 ++-- tests/test_agents.py | 6 +- tests/test_agentset_pandas.py | 6 +- tests/test_agentset_polars.py | 6 +- 11 files changed, 144 insertions(+), 131 deletions(-) diff --git a/mesa_frames/abstract/agents.py b/mesa_frames/abstract/agents.py index ce412c3..269a13d 100644 --- a/mesa_frames/abstract/agents.py +++ b/mesa_frames/abstract/agents.py @@ -9,7 +9,14 @@ from typing_extensions import Any, Self, overload from mesa_frames.abstract.mixin import CopyMixin -from mesa_frames.types_ import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series +from mesa_frames.types_ import ( + AgentMask, + BoolSeries, + DataFrame, + IdsLike, + Index, + Series, +) if TYPE_CHECKING: from mesa_frames.concrete.agents import AgentSetDF @@ -36,13 +43,13 @@ class AgentContainer(CopyMixin): Check if agents with the specified IDs are in the AgentContainer. do(method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any | dict[str, Any] Invoke a method on the AgentContainer. - get(attr_names: str | Collection[str] | None = None, mask: MaskLike | None = None) -> Series | DataFrame | dict[str, Series] | dict[str, DataFrame] + get(attr_names: str | Collection[str] | None = None, mask: AgentMask | None = None) -> Series | DataFrame | dict[str, Series] | dict[str, DataFrame] Retrieve the value of a specified attribute for each agent in the AgentContainer. remove(ids: IdsLike, inplace: bool = True) -> Self Removes an agent from the AgentContainer. - select(mask: MaskLike | None = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self + select(mask: AgentMask | None = None, filter_func: Callable[[Self], AgentMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self Select agents in the AgentContainer based on the given criteria. - set(attr_names: str | dict[str, Any] | Collection[str], values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self + set(attr_names: str | dict[str, Any] | Collection[str], values: Any | None = None, mask: AgentMask | None = None, inplace: bool = True) -> Self Sets the value of a specified attribute or attributes for each agent in the mask in AgentContainer. shuffle(inplace: bool = False) -> Self Shuffles the order of agents in the AgentContainer. @@ -136,7 +143,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: Literal[False] = False, inplace: bool = True, **kwargs, @@ -148,7 +155,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: Literal[True], inplace: bool = True, **kwargs, @@ -159,7 +166,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: bool = False, inplace: bool = True, **kwargs, @@ -172,7 +179,7 @@ def do( The name of the method to invoke. *args : Any Positional arguments to pass to the method - mask : MaskLike, optional + mask : AgentMask, optional The subset of agents on which to apply the method return_results : bool, optional Whether to return the result of the method, by default False @@ -198,7 +205,7 @@ def get(self, attr_names: Collection[str]) -> DataFrame | dict[str, DataFrame]: def get( self, attr_names: str | Collection[str] | None = None, - mask: MaskLike | None = None, + mask: AgentMask | None = None, ) -> Series | DataFrame | dict[str, Series] | dict[str, DataFrame]: """Retrieves the value of a specified attribute for each agent in the AgentContainer. @@ -206,8 +213,8 @@ def get( ---------- attr_names : str | Collection[str] | None The attributes to retrieve. If None, all attributes are retrieved. Defaults to None. - MaskLike : MaskLike | None - The MaskLike of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. + AgentMask : AgentMask | None + The AgentMask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. Returns ---------- @@ -237,8 +244,8 @@ def remove(self, agents, inplace: bool = True) -> Self: @abstractmethod def select( self, - mask: MaskLike | None = None, - filter_func: Callable[[Self], MaskLike] | None = None, + mask: AgentMask | None = None, + filter_func: Callable[[Self], AgentMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True, @@ -247,10 +254,10 @@ def select( Parameters ---------- - mask : MaskLike | None, optional - The MaskLike of agents to be selected, by default None - filter_func : Callable[[Self], MaskLike] | None, optional - A function which takes as input the AgentContainer and returns a MaskLike, by default None + mask : AgentMask | None, optional + The AgentMask of agents to be selected, by default None + filter_func : Callable[[Self], AgentMask] | None, optional + A function which takes as input the AgentContainer and returns a AgentMask, by default None n : int, optional The maximum number of agents to be selected, by default None negate : bool, optional @@ -271,7 +278,7 @@ def set( self, attr_names: dict[str, Any], values: None, - mask: MaskLike | None = None, + mask: AgentMask | None = None, inplace: bool = True, ) -> Self: ... @@ -281,7 +288,7 @@ def set( self, attr_names: str | Collection[str], values: Any, - mask: MaskLike | None = None, + mask: AgentMask | None = None, inplace: bool = True, ) -> Self: ... @@ -290,7 +297,7 @@ def set( self, attr_names: str | dict[str, Any] | Collection[str], values: Any | None = None, - mask: MaskLike | None = None, + mask: AgentMask | None = None, inplace: bool = True, ) -> Self: """Sets the value of a specified attribute or attributes for each agent in the mask in AgentContainer. @@ -304,8 +311,8 @@ def set( - A dictionary: keys should be attributes and values should be the values to set. Value should be None. value : Any | None The value to set the attribute to. If None, attr_names must be a dictionary. - mask : MaskLike | None - The MaskLike of agents to set the attribute for. + mask : AgentMask | None + The AgentMask of agents to set the attribute for. inplace : bool Whether to set the attribute in place. @@ -382,21 +389,21 @@ def __getitem__( key: ( str | Collection[str] - | MaskLike - | tuple[MaskLike, str] - | tuple[MaskLike, Collection[str]] + | AgentMask + | tuple[AgentMask, str] + | tuple[AgentMask, Collection[str]] ), ) -> Series | DataFrame | dict[str, Series] | dict[str, DataFrame]: """Implements the [] operator for the AgentContainer. The key can be: - An attribute or collection of attributes (eg. AgentContainer["str"], AgentContainer[["str1", "str2"]]): returns the specified column(s) of the agents in the AgentContainer. - - A MaskLike (eg. AgentContainer[MaskLike]): returns the agents in the AgentContainer that satisfy the MaskLike. - - A tuple (eg. AgentContainer[MaskLike, "str"]): returns the specified column of the agents in the AgentContainer that satisfy the MaskLike. + - A AgentMask (eg. AgentContainer[AgentMask]): returns the agents in the AgentContainer that satisfy the AgentMask. + - A tuple (eg. AgentContainer[AgentMask, "str"]): returns the specified column of the agents in the AgentContainer that satisfy the AgentMask. Parameters ---------- - key : Attributes | MaskLike | tuple[MaskLike, Attributes] + key : Attributes | AgentMask | tuple[AgentMask, Attributes] The key to retrieve. Returns @@ -433,7 +440,7 @@ def __isub__(self, other: AgentSetDF | IdsLike) -> Self: Parameters ---------- - other : MaskLike + other : AgentMask The agents to remove. Returns @@ -460,7 +467,10 @@ def __sub__(self, other: AgentSetDF | IdsLike) -> Self: def __setitem__( self, - key: str | Collection[str] | MaskLike | tuple[MaskLike, str | Collection[str]], + key: str + | Collection[str] + | AgentMask + | tuple[AgentMask, str | Collection[str]], values: Any, ) -> None: """Implement the [] operator for setting values in the AgentContainer. @@ -468,12 +478,12 @@ def __setitem__( The key can be: - A string (eg. AgentContainer["str"]): sets the specified column of the agents in the AgentContainer. - A list of strings(eg. AgentContainer[["str1", "str2"]]): sets the specified columns of the agents in the AgentContainer. - - A tuple (eg. AgentContainer[MaskLike, "str"]): sets the specified column of the agents in the AgentContainer that satisfy the MaskLike. - - A MaskLike (eg. AgentContainer[MaskLike]): sets the attributes of the agents in the AgentContainer that satisfy the MaskLike. + - A tuple (eg. AgentContainer[AgentMask, "str"]): sets the specified column of the agents in the AgentContainer that satisfy the AgentMask. + - A AgentMask (eg. AgentContainer[AgentMask]): sets the attributes of the agents in the AgentContainer that satisfy the AgentMask. Parameters ---------- - key : str | list[str] | MaskLike | tuple[MaskLike, str | list[str]] + key : str | list[str] | AgentMask | tuple[AgentMask, str | list[str]] The key to set. values : Any The values to set for the specified key. @@ -487,7 +497,7 @@ def __setitem__( ): try: self.set(attr_names=key, values=values) - except KeyError: # key=MaskLike + except KeyError: # key=AgentMask self.set(attr_names=None, mask=key, values=values) else: self.set(attr_names=None, mask=key, values=values) @@ -615,13 +625,13 @@ def active_agents(self) -> DataFrame | dict[str, DataFrame]: @abstractmethod def active_agents( self, - mask: MaskLike, + mask: AgentMask, ) -> None: """Set the active agents in the AgentContainer. Parameters ---------- - mask : MaskLike + mask : AgentMask The mask to apply. """ self.select(mask=mask, inplace=True) @@ -648,7 +658,7 @@ class AgentSetDF(AgentContainer): A list of attributes to copy with a reference only. _copy_with_method : dict[str, tuple[str, list[str]]] A dictionary of attributes to copy with a specified method and arguments. - _mask : MaskLike + _mask : AgentMask The underlying mask used for the active agents in the AgentSetDF. _model : ModelDF The model that the AgentSetDF belongs to. @@ -663,17 +673,17 @@ class AgentSetDF(AgentContainer): Check if agents with the specified IDs are in the AgentSetDF. copy(self, deep: bool = False, memo: dict | None = None) -> Self Create a copy of the AgentSetDF. - discard(self, ids: MaskLike, inplace: bool = True) -> Self + discard(self, ids: AgentMask, inplace: bool = True) -> Self Removes an agent from the AgentSetDF. Does not raise an error if the agent is not found. do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any Invoke a method on the AgentSetDF. - get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike | None = None) -> Series | DataFrame + get(self, attr_names: str | Collection[str] | None = None, mask: AgentMask | None = None) -> Series | DataFrame Retrieve the value of a specified attribute for each agent in the AgentSetDF. - remove(self, ids: MaskLike, inplace: bool = True) -> Self + remove(self, ids: AgentMask, inplace: bool = True) -> Self Removes an agent from the AgentSetDF. - select(self, mask: MaskLike | None = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self + select(self, mask: AgentMask | None = None, filter_func: Callable[[Self], AgentMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self Select agents in the AgentSetDF based on the given criteria. - set(self, attr_names: str | dict[str, Any] | Collection[str], values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self + set(self, attr_names: str | dict[str, Any] | Collection[str], values: Any | None = None, mask: AgentMask | None = None, inplace: bool = True) -> Self Sets the value of a specified attribute or attributes for each agent in the mask in AgentSetDF. shuffle(self, inplace: bool = False) -> Self Shuffles the order of agents in the AgentSetDF. @@ -687,7 +697,7 @@ class AgentSetDF(AgentContainer): Add agents to the AgentSetDF through the += operator. __getattr__(self, name: str) -> Any Retrieve an attribute of the AgentSetDF. - __getitem__(self, key: str | Collection[str] | MaskLike | tuple[MaskLike, str] | tuple[MaskLike, Collection[str]]) -> Series | DataFrame + __getitem__(self, key: str | Collection[str] | AgentMask | tuple[AgentMask, str] | tuple[AgentMask, Collection[str]]) -> Series | DataFrame Retrieve an item from the AgentSetDF. __iter__(self) -> Iterator Get an iterator for the agents in the AgentSetDF. @@ -715,7 +725,7 @@ class AgentSetDF(AgentContainer): """ _agents: DataFrame - _mask: MaskLike + _mask: AgentMask _model: ModelDF @abstractmethod @@ -768,7 +778,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: Literal[False] = False, inplace: bool = True, **kwargs, @@ -779,7 +789,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: Literal[True], inplace: bool = True, **kwargs, @@ -789,7 +799,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgentMask | None = None, return_results: bool = False, inplace: bool = True, **kwargs, @@ -826,7 +836,7 @@ def do( def get( self, attr_names: str, - mask: MaskLike | None = None, + mask: AgentMask | None = None, ) -> Series: ... @abstractmethod @@ -834,14 +844,14 @@ def get( def get( self, attr_names: Collection[str] | None = None, - mask: MaskLike | None = None, + mask: AgentMask | None = None, ) -> DataFrame: ... @abstractmethod def get( self, attr_names: str | Collection[str] | None = None, - mask: MaskLike | None = None, + mask: AgentMask | None = None, ) -> Series | DataFrame: ... @abstractmethod @@ -857,12 +867,12 @@ def _concatenate_agentsets( ) -> Self: ... @abstractmethod - def _get_bool_mask(self, mask: MaskLike) -> BoolSeries: + def _get_bool_mask(self, mask: AgentMask) -> BoolSeries: """Get the equivalent boolean mask based on the input mask Parameters ---------- - mask : MaskLike + mask : AgentMask Returns ------- @@ -871,12 +881,12 @@ def _get_bool_mask(self, mask: MaskLike) -> BoolSeries: ... @abstractmethod - def _get_masked_df(self, mask: MaskLike) -> DataFrame: + def _get_masked_df(self, mask: AgentMask) -> DataFrame: """Get the df filtered by the input mask Parameters ---------- - mask : MaskLike + mask : AgentMask Returns ------- @@ -954,11 +964,12 @@ def __getattr__(self, name: str) -> Any: ) @overload - def __getitem__(self, key: str | tuple[MaskLike, str]) -> Series | DataFrame: ... + def __getitem__(self, key: str | tuple[AgentMask, str]) -> Series | DataFrame: ... @overload def __getitem__( - self, key: MaskLike | Collection[str] | tuple[MaskLike, Collection[str]] + self, + key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], ) -> DataFrame: ... def __getitem__( @@ -966,9 +977,9 @@ def __getitem__( key: ( str | Collection[str] - | MaskLike - | tuple[MaskLike, str] - | tuple[MaskLike, Collection[str]] + | AgentMask + | tuple[AgentMask, str] + | tuple[AgentMask, Collection[str]] ), ) -> Series | DataFrame: attr = super().__getitem__(key) diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index d58b24a..910cae3 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -5,7 +5,7 @@ from typing import Literal from collections.abc import Collection, Iterator, Sequence -from mesa_frames.types_ import BoolSeries, DataFrame, MaskLike, Series +from mesa_frames.types_ import BoolSeries, DataFrame, Mask, Series class CopyMixin(ABC): @@ -181,7 +181,7 @@ def _df_get_bool_mask( self, df: DataFrame, index_col: str, - mask: MaskLike | None = None, + mask: Mask | None = None, negate: bool = False, ) -> BoolSeries: ... @@ -190,7 +190,7 @@ def _df_get_masked_df( self, df: DataFrame, index_col: str, - mask: MaskLike | None = None, + mask: Mask | None = None, columns: list[str] | None = None, negate: bool = False, ) -> DataFrame: ... diff --git a/mesa_frames/concrete/agents.py b/mesa_frames/concrete/agents.py index 3f8530e..fd8362c 100644 --- a/mesa_frames/concrete/agents.py +++ b/mesa_frames/concrete/agents.py @@ -7,11 +7,11 @@ from mesa_frames.abstract.agents import AgentContainer, AgentSetDF from mesa_frames.types_ import ( - AgnosticMask, + AgentMask, + AgnosticAgentMask, BoolSeries, DataFrame, IdsLike, - MaskLike, Series, ) @@ -58,13 +58,13 @@ class AgentsDF(AgentContainer): Remove an agent from the AgentsDF. Does not raise an error if the agent is not found. do(self, method_name: str, *args, return_results: bool = False, inplace: bool = True, **kwargs) -> Self | Any Invoke a method on the AgentsDF. - get(self, attr_names: str | Collection[str] | None = None, mask: MaskLike = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame] + get(self, attr_names: str | Collection[str] | None = None, mask: AgentMask = None) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame] Retrieve the value of a specified attribute for each agent in the AgentsDF. remove(self, ids: IdsLike, inplace: bool = True) -> Self Remove agents from the AgentsDF. - select(self, mask: MaskLike = None, filter_func: Callable[[Self], MaskLike] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self + select(self, mask: AgentMask = None, filter_func: Callable[[Self], AgentMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True) -> Self Select agents in the AgentsDF based on the given criteria. - set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: MaskLike | None = None, inplace: bool = True) -> Self + set(self, attr_names: str | Collection[str] | dict[AgentSetDF, Any] | None = None, values: Any | None = None, mask: AgentMask | None = None, inplace: bool = True) -> Self Set the value of a specified attribute or attributes for each agent in the mask in the AgentsDF. shuffle(self, inplace: bool = True) -> Self Shuffle the order of agents in the AgentsDF. @@ -157,7 +157,7 @@ def do( self, method_name: str, *args, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, return_results: Literal[False] = False, inplace: bool = True, **kwargs, @@ -168,7 +168,7 @@ def do( self, method_name: str, *args, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, return_results: Literal[True], inplace: bool = True, **kwargs, @@ -178,7 +178,7 @@ def do( self, method_name: str, *args, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, return_results: bool = False, inplace: bool = True, **kwargs, @@ -214,7 +214,7 @@ def do( def get( self, attr_names: str | Collection[str] | None = None, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, ) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]: agentsets_masks = self._get_bool_masks(mask) return { @@ -253,8 +253,8 @@ def remove( def select( self, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, - filter_func: Callable[[AgentSetDF], MaskLike] | None = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, + filter_func: Callable[[AgentSetDF], AgentMask] | None = None, n: int | None = None, inplace: bool = True, negate: bool = False, @@ -275,7 +275,7 @@ def set( self, attr_names: str | dict[AgentSetDF, Any] | Collection[str], values: Any | None = None, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -370,7 +370,7 @@ def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: def _get_bool_masks( self, - mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, ) -> dict[AgentSetDF, BoolSeries]: return_dictionary = {} if not isinstance(mask, dict): @@ -418,16 +418,16 @@ def __getattr__(self, name: str) -> dict[AgentSetDF, Any]: @overload def __getitem__( - self, key: str | tuple[dict[AgentSetDF, MaskLike], str] + self, key: str | tuple[dict[AgentSetDF, AgentMask], str] ) -> dict[str, Series]: ... @overload def __getitem__( self, key: Collection[str] - | AgnosticMask + | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, MaskLike], Collection[str]], + | tuple[dict[AgentSetDF, AgentMask], Collection[str]], ) -> dict[str, DataFrame]: ... def __getitem__( @@ -435,10 +435,10 @@ def __getitem__( key: ( str | Collection[str] - | AgnosticMask + | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, MaskLike], str] - | tuple[dict[AgentSetDF, MaskLike], Collection[str]] + | tuple[dict[AgentSetDF, AgentMask], str] + | tuple[dict[AgentSetDF, AgentMask], Collection[str]] ), ) -> dict[str, Series] | dict[str, DataFrame]: return super().__getitem__(key) @@ -494,10 +494,10 @@ def __setitem__( key: ( str | Collection[str] - | AgnosticMask + | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, MaskLike], str] - | tuple[dict[AgentSetDF, MaskLike], Collection[str]] + | tuple[dict[AgentSetDF, AgentMask], str] + | tuple[dict[AgentSetDF, AgentMask], Collection[str]] ), values: Any, ) -> None: @@ -542,7 +542,7 @@ def active_agents(self) -> dict[AgentSetDF, DataFrame]: @active_agents.setter def active_agents( - self, agents: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] + self, agents: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] ) -> None: self.select(agents, inplace=True) diff --git a/mesa_frames/concrete/pandas/agentset.py b/mesa_frames/concrete/pandas/agentset.py index 0378ae5..1b97104 100644 --- a/mesa_frames/concrete/pandas/agentset.py +++ b/mesa_frames/concrete/pandas/agentset.py @@ -8,7 +8,7 @@ from mesa_frames.abstract.agents import AgentSetDF from mesa_frames.concrete.pandas.mixin import PandasMixin from mesa_frames.concrete.polars.agentset import AgentSetPolars -from mesa_frames.types_ import PandasIdsLike, PandasMaskLike +from mesa_frames.types_ import AgentPandasMask, PandasIdsLike if TYPE_CHECKING: from mesa_frames.concrete.model import ModelDF @@ -172,7 +172,7 @@ def contains(self, agents: PandasIdsLike) -> bool | pd.Series: def get( self, attr_names: str | Collection[str] | None = None, - mask: PandasMaskLike = None, + mask: AgentPandasMask = None, ) -> pd.Index | pd.Series | pd.DataFrame: mask = self._get_bool_mask(mask) if attr_names is None: @@ -206,7 +206,7 @@ def set( self, attr_names: str | dict[str, Any] | Collection[str] | None = None, values: Any | None = None, - mask: PandasMaskLike = None, + mask: AgentPandasMask = None, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -242,8 +242,8 @@ def set( def select( self, - mask: PandasMaskLike = None, - filter_func: Callable[[Self], PandasMaskLike] | None = None, + mask: AgentPandasMask = None, + filter_func: Callable[[Self], AgentPandasMask] | None = None, n: int | None = None, negate: bool = False, inplace: bool = True, @@ -315,7 +315,7 @@ def _concatenate_agentsets( def _get_bool_mask( self, - mask: PandasMaskLike = None, + mask: AgentPandasMask = None, ) -> pd.Series: if isinstance(mask, pd.Series) and mask.dtype == bool: return mask @@ -334,7 +334,7 @@ def _get_bool_mask( def _get_masked_df( self, - mask: PandasMaskLike = None, + mask: AgentPandasMask = None, ) -> pd.DataFrame: if isinstance(mask, pd.Series) and mask.dtype == bool: return self._agents.loc[mask] @@ -428,7 +428,7 @@ def active_agents(self) -> pd.DataFrame: return self._agents.loc[self._mask] @active_agents.setter - def active_agents(self, mask: PandasMaskLike) -> None: + def active_agents(self, mask: AgentPandasMask) -> None: self.select(mask=mask, inplace=True) @property diff --git a/mesa_frames/concrete/pandas/mixin.py b/mesa_frames/concrete/pandas/mixin.py index 9e594e8..be22393 100644 --- a/mesa_frames/concrete/pandas/mixin.py +++ b/mesa_frames/concrete/pandas/mixin.py @@ -6,7 +6,7 @@ from typing_extensions import Any from mesa_frames.abstract.mixin import DataFrameMixin -from mesa_frames.types_ import PandasMaskLike +from mesa_frames.types_ import PandasMask class PandasMixin(DataFrameMixin): @@ -47,7 +47,7 @@ def _df_get_bool_mask( self, df: pd.DataFrame, index_col: str, - mask: PandasMaskLike = None, + mask: PandasMask = None, negate: bool = False, ) -> pd.Series: if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df): @@ -77,7 +77,7 @@ def _df_get_masked_df( self, df: pd.DataFrame, index_col: str, - mask: PandasMaskLike | None = None, + mask: PandasMask | None = None, columns: list[str] | None = None, negate: bool = False, ) -> pd.DataFrame: diff --git a/mesa_frames/concrete/polars/agentset.py b/mesa_frames/concrete/polars/agentset.py index a9ad914..42ed528 100644 --- a/mesa_frames/concrete/polars/agentset.py +++ b/mesa_frames/concrete/polars/agentset.py @@ -7,7 +7,7 @@ from mesa_frames.concrete.agents import AgentSetDF from mesa_frames.concrete.polars.mixin import PolarsMixin -from mesa_frames.types_ import PolarsIdsLike, PolarsMaskLike +from mesa_frames.types_ import AgentPolarsMask, PolarsIdsLike if TYPE_CHECKING: from mesa_frames.concrete.model import ModelDF @@ -188,7 +188,7 @@ def contains( def get( self, attr_names: IntoExpr | Iterable[IntoExpr] | None, - mask: PolarsMaskLike = None, + mask: AgentPolarsMask = None, ) -> pl.Series | pl.DataFrame: masked_df = self._get_masked_df(mask) attr_names = self.agents.select(attr_names).columns.copy() @@ -219,7 +219,7 @@ def set( self, attr_names: str | Collection[str] | dict[str, Any] | None = None, values: Any | None = None, - mask: PolarsMaskLike = None, + mask: AgentPolarsMask = None, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -270,7 +270,7 @@ def process_single_attr( def select( self, - mask: PolarsMaskLike = None, + mask: AgentPolarsMask = None, filter_func: Callable[[Self], pl.Series] | None = None, n: int | None = None, negate: bool = False, @@ -388,7 +388,7 @@ def _concatenate_agentsets( def _get_bool_mask( self, - mask: PolarsMaskLike = None, + mask: AgentPolarsMask = None, ) -> pl.Series | pl.Expr: def bool_mask_from_series(mask: pl.Series) -> pl.Series: if ( @@ -423,7 +423,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: def _get_masked_df( self, - mask: PolarsMaskLike = None, + mask: AgentPolarsMask = None, ) -> pl.DataFrame: if (isinstance(mask, pl.Series) and mask.dtype == pl.Boolean) or isinstance( mask, pl.Expr @@ -486,17 +486,17 @@ def __getattr__(self, key: str) -> pl.Series: @overload def __getitem__( self, - key: str | tuple[PolarsMaskLike, str], + key: str | tuple[AgentPolarsMask, str], ) -> pl.Series: ... @overload def __getitem__( self, key: ( - PolarsMaskLike + AgentPolarsMask | Collection[str] | tuple[ - PolarsMaskLike, + AgentPolarsMask, Collection[str], ] ), @@ -507,10 +507,10 @@ def __getitem__( key: ( str | Collection[str] - | PolarsMaskLike - | tuple[PolarsMaskLike, str] + | AgentPolarsMask + | tuple[AgentPolarsMask, str] | tuple[ - PolarsMaskLike, + AgentPolarsMask, Collection[str], ] ), @@ -543,7 +543,7 @@ def active_agents(self) -> pl.DataFrame: return self.agents.filter(self._mask) @active_agents.setter - def active_agents(self, mask: PolarsMaskLike) -> None: + def active_agents(self, mask: AgentPolarsMask) -> None: self.select(mask=mask, inplace=True) @property diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 3645597..c3854ee 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -5,7 +5,7 @@ from typing_extensions import Any from mesa_frames.abstract.mixin import DataFrameMixin -from mesa_frames.types_ import PolarsMaskLike +from mesa_frames.types_ import PolarsMask class PolarsMixin(DataFrameMixin): @@ -65,7 +65,7 @@ def _df_get_bool_mask( self, df: pl.DataFrame, index_col: str, - mask: PolarsMaskLike = None, + mask: PolarsMask = None, negate: bool = False, ) -> pl.Series | pl.Expr: def bool_mask_from_series(mask: pl.Series) -> pl.Series: @@ -106,7 +106,7 @@ def _df_get_masked_df( self, df: pl.DataFrame, index_col: str, - mask: PolarsMaskLike | None = None, + mask: PolarsMask | None = None, columns: list[str] | None = None, negate: bool = False, ) -> pl.DataFrame: diff --git a/mesa_frames/types_.py b/mesa_frames/types_.py index b1b4ddf..5dcfd5b 100644 --- a/mesa_frames/types_.py +++ b/mesa_frames/types_.py @@ -1,29 +1,31 @@ -from collections.abc import Collection +from collections.abc import Collection, Sequence from typing import Literal -from collections.abc import Sequence - import geopandas as gpd import geopolars as gpl import pandas as pd import polars as pl from numpy import ndarray +from typing_extensions import Any ####----- Agnostic Types -----#### -AgnosticMask = Literal["all", "active"] | None +AgnosticMask = ( + Any | Sequence[Any] | None +) # Any is a placeholder for any type if it's a single value +AgnosticAgentMask = Sequence[int] | int | Literal["all", "active"] | None AgnosticIds = int | Collection[int] ###----- Pandas Types -----### -ArrayLike = pd.api.extensions.ExtensionArray | ndarray -AnyArrayLike = ArrayLike | pd.Index | pd.Series -PandasMaskLike = AgnosticMask | pd.Series | pd.DataFrame | AnyArrayLike +PandasMask = pd.Series | pd.DataFrame | AgnosticMask +AgentPandasMask = AgnosticAgentMask | pd.Series | pd.DataFrame PandasIdsLike = AgnosticIds | pd.Series | pd.Index PandasGridCapacity = ndarray ###----- Polars Types -----### -PolarsMaskLike = AgnosticMask | pl.Expr | pl.Series | pl.DataFrame | Collection[int] +PolarsMask = pl.Expr | pl.Series | pl.DataFrame | AgnosticMask +AgentPolarsMask = AgnosticAgentMask | pl.Expr | pl.Series | pl.DataFrame | Sequence[int] PolarsIdsLike = AgnosticIds | pl.Series PolarsGridCapacity = list[pl.Expr] @@ -31,10 +33,10 @@ GeoDataFrame = gpd.GeoDataFrame | gpl.GeoDataFrame DataFrame = pd.DataFrame | pl.DataFrame Series = pd.Series | pl.Series -Series = pd.Series | pl.Series Index = pd.Index | pl.Series BoolSeries = pd.Series | pl.Series -MaskLike = AgnosticMask | PandasMaskLike | PolarsMaskLike +Mask = PandasMask | PolarsMask +AgentMask = AgentPandasMask | AgentPolarsMask IdsLike = AgnosticIds | PandasIdsLike | PolarsIdsLike diff --git a/tests/test_agents.py b/tests/test_agents.py index f1886b1..404dc47 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -6,7 +6,7 @@ from mesa_frames import AgentsDF, ModelDF from mesa_frames.abstract.agents import AgentSetDF -from mesa_frames.types_ import MaskLike +from mesa_frames.types_ import AgentMask from tests.test_agentset_pandas import ( ExampleAgentSetPandas, fix1_AgentSetPandas, @@ -554,7 +554,7 @@ def test__get_bool_masks(self, fix_AgentsDF: AgentsDF): len(agents._agentsets[1]) - 1 ) - # Test with mask = dict[AgentSetDF, MaskLike] + # Test with mask = dict[AgentSetDF, AgentMask] result = agents._get_bool_masks(mask=mask_dictionary) assert result[agents._agentsets[0]].to_list() == mask0.to_list() assert result[agents._agentsets[1]].to_list() == mask1.to_list() @@ -714,7 +714,7 @@ def test___getitem__( fix2_AgentSetPolars._agents["wealth"] > fix2_AgentSetPolars._agents["wealth"][0] ) - mask_dictionary: dict[AgentSetDF, MaskLike] = { + mask_dictionary: dict[AgentSetDF, AgentMask] = { fix1_AgentSetPandas: mask0, fix2_AgentSetPolars: mask1, } diff --git a/tests/test_agentset_pandas.py b/tests/test_agentset_pandas.py index 4093b5e..c61271e 100644 --- a/tests/test_agentset_pandas.py +++ b/tests/test_agentset_pandas.py @@ -292,13 +292,13 @@ def test__getitem__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): # Testing with a string assert agents["wealth"].tolist() == [1, 2, 3, 4] - # Test with a tuple[MaskLike, str] + # Test with a tuple[AgentMask, str] assert agents[0, "wealth"].values == 1 # Test with a list[str] assert agents[["wealth", "age"]].columns.tolist() == ["wealth", "age"] - # Testing with a tuple[MaskLike, list[str]] + # Testing with a tuple[AgentMask, list[str]] result = agents[0, ["wealth", "age"]] assert result["wealth"].values.tolist() == [1] assert result["age"].values.tolist() == [10] @@ -375,7 +375,7 @@ def test__setitem__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents[0, "wealth"] = 5 assert agents.agents.wealth.tolist() == [5, 1, 1, 1] - # Test with key=MaskLike, value=Any + # Test with key=AgentMask, value=Any agents[0] = [9, 99] assert agents.agents.loc[0, "wealth"] == 9 assert agents.agents.loc[0, "age"] == 99 diff --git a/tests/test_agentset_polars.py b/tests/test_agentset_polars.py index 97a7983..8925742 100644 --- a/tests/test_agentset_polars.py +++ b/tests/test_agentset_polars.py @@ -291,13 +291,13 @@ def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Testing with a string assert agents["wealth"].to_list() == [1, 2, 3, 4] - # Test with a tuple[MaskLike, str] + # Test with a tuple[AgentMask, str] assert agents[0, "wealth"].item() == 1 # Test with a list[str] assert agents[["wealth", "age"]].columns == ["wealth", "age"] - # Testing with a tuple[MaskLike, list[str]] + # Testing with a tuple[AgentMask, list[str]] result = agents[0, ["wealth", "age"]] assert result["wealth"].to_list() == [1] assert result["age"].to_list() == [10] @@ -374,7 +374,7 @@ def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents[0, "wealth"] = 5 assert agents.agents["wealth"].to_list() == [5, 1, 1, 1] - # Test with key=MaskLike, value=Any + # Test with key=AgentMask, value=Any agents[0] = [9, 99] assert agents.agents.item(0, "wealth") == 9 assert agents.agents.item(0, "age") == 99