Skip to content

Commit

Permalink
fix: Changing Sequence to ArrayLike because np.ndarray and Series are…
Browse files Browse the repository at this point in the history
… not Sequences (But keeping the separate hinting for clarity)
  • Loading branch information
adamamer20 committed Aug 18, 2024
1 parent b31acfb commit 46405c8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
23 changes: 12 additions & 11 deletions mesa_frames/abstract/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
from mesa_frames.types_ import (
ArrayLike,
BoolSeries,
DataFrame,
DiscreteCoordinate,
Expand Down Expand Up @@ -59,7 +60,7 @@ class SpaceDF(CopyMixin, DataFrameMixin):
) -> DataFrame
Returns the distances from pos0 to pos1 or agents0 and agents1.
get_neighbors(
radius: int | float | Sequence[int] | Sequence[float],
radius: int | float | Sequence[int] | Sequence[float] | ArrayLike,
pos: Space
) -> DataFrame
Get the neighboring agents from given positions or agents according to the specified radiuses.
Expand Down Expand Up @@ -312,7 +313,7 @@ def get_distances(
@abstractmethod
def get_neighbors(
self,
radius: int | float | Sequence[int] | Sequence[float],
radius: int | float | Sequence[int] | Sequence[float] | ArrayLike,
pos: SpaceCoordinate | SpaceCoordinates | None = None,
agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None,
include_center: bool = False,
Expand All @@ -322,7 +323,7 @@ def get_neighbors(
Parameters
----------
radius : int | float | Sequence[int] | Sequence[float]
radius : int | float | Sequence[int] | Sequence[float] | ArrayLike
The radius(es) of the neighborhood
pos : SpaceCoordinate | SpaceCoordinates | None, optional
The coordinates of the cell to get the neighborhood from, by default None
Expand Down Expand Up @@ -536,7 +537,7 @@ class DiscreteSpaceDF(SpaceDF):
Move agents to available cells in the space (cells where there is at least one spot available).
sample_cells(n: int, cell_type: Literal["any", "empty", "available", "full"] = "any", with_replacement: bool = True) -> DataFrame
Sample cells from the grid according to the specified cell_type.
get_neighborhood(radius: int | float | Sequence[int] | Sequence[float], pos: DiscreteCoordinate | DiscreteCoordinates | None = None, agents: IdsLike | AgentContainer | Collection[AgentContainer] = None, include_center: bool = False) -> DataFrame
get_neighborhood(radius: int | float | Sequence[int] | Sequence[float] | ArrayLike, pos: DiscreteCoordinate | DiscreteCoordinates | None = None, agents: IdsLike | AgentContainer | Collection[AgentContainer] = None, include_center: bool = False) -> DataFrame
Get the neighborhood cells from the given positions (pos) or agents according to the specified radiuses.
get_cells(coords: DiscreteCoordinate | DiscreteCoordinates | None = None) -> DataFrame
Retrieve a dataframe of specified cells with their properties and agents.
Expand Down Expand Up @@ -774,7 +775,7 @@ def set_cells(
@abstractmethod
def get_neighborhood(
self,
radius: int | float | Sequence[int] | Sequence[float],
radius: int | float | Sequence[int] | Sequence[float] | ArrayLike,
pos: DiscreteCoordinate | DiscreteCoordinates | None = None,
agents: IdsLike | AgentContainer | Collection[AgentContainer] = None,
include_center: bool = False,
Expand All @@ -784,7 +785,7 @@ def get_neighborhood(
Parameters
----------
radius : int | float | Sequence[int] | Sequence[float]
radius : int | float | Sequence[int] | Sequence[float] | ArrayLike
The radius(es) of the neighborhoods
pos : DiscreteCoordinate | DiscreteCoordinates | None, optional
The coordinates of the cell(s) to get the neighborhood from
Expand Down Expand Up @@ -1211,15 +1212,15 @@ def get_neighbors(

def get_neighborhood(
self,
radius: int | Sequence[int],
radius: int | Sequence[int] | ArrayLike,
pos: GridCoordinate | GridCoordinates | None = None,
agents: IdsLike | AgentContainer | Collection[AgentContainer] = None,
include_center: bool = False,
) -> DataFrame:
pos_df = self._get_df_coords(pos, agents)

if __debug__:
if isinstance(radius, Sequence):
if isinstance(radius, ArrayLike):
if len(radius) != len(pos_df):
raise ValueError(
"The length of the radius sequence must be equal to the number of positions/agents"
Expand All @@ -1228,7 +1229,7 @@ def get_neighborhood(
## Create all possible neighbors by multiplying offsets by the radius and adding original pos

# If radius is a sequence, get the maximum radius (we will drop unnecessary neighbors later, time-efficient but memory-inefficient)
if isinstance(radius, Sequence):
if isinstance(radius, ArrayLike):
radius_srs = self._srs_constructor(radius, name="radius")
radius_df = self._srs_to_df(radius_srs)
max_radius = radius_srs.max()
Expand Down Expand Up @@ -1332,7 +1333,7 @@ def get_neighborhood(
)

# If radius is a sequence, filter unnecessary neighbors
if isinstance(radius, Sequence):
if isinstance(radius, ArrayLike):
radius_df = self._df_rename_columns(
self._df_concat([pos_df, radius_df], how="horizontal"),
self._pos_col_names + ["radius"],
Expand Down Expand Up @@ -1651,7 +1652,7 @@ def _get_df_coords(
columns=self._pos_col_names,
dtypes={col: int for col in self._pos_col_names},
)
elif isinstance(pos, Sequence) and len(pos) == len(self._dimensions):
elif isinstance(pos, ArrayLike) and len(pos) == len(self._dimensions):
# This means that the sequence is already a sequence where each element is the
# sequence of coordinates for dimension i
for i, c in enumerate(pos):
Expand Down
2 changes: 1 addition & 1 deletion mesa_frames/types_.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
Mask = PandasMask | PolarsMask
AgentMask = AgentPandasMask | AgentPolarsMask
IdsLike = AgnosticIds | PandasIdsLike | PolarsIdsLike

ArrayLike = ndarray | Series | Sequence

###----- Time ------###
TimeT = float | int
Expand Down

0 comments on commit 46405c8

Please sign in to comment.