Skip to content

Commit

Permalink
adding position accessor to AgentContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 committed Aug 18, 2024
1 parent 896b1d7 commit abfbb8f
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 5 deletions.
24 changes: 22 additions & 2 deletions mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.random import Generator
from typing_extensions import Any, Self, overload

from mesa_frames.abstract.mixin import CopyMixin
from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
from mesa_frames.types_ import (
AgentMask,
BoolSeries,
Expand Down Expand Up @@ -668,8 +668,19 @@ def index(self) -> Index | dict[AgentSetDF, Index]:
"""
...

@property
@abstractmethod
def pos(self) -> DataFrame | dict[str, DataFrame]:
"""The position of the agents in the AgentContainer.
class AgentSetDF(AgentContainer):
Returns
-------
DataFrame | dict[str, DataFrame]
"""
...


class AgentSetDF(AgentContainer, DataFrameMixin):
"""The AgentSetDF class is a container for agents of the same type.
Attributes
Expand Down Expand Up @@ -1050,3 +1061,12 @@ def inactive_agents(self) -> DataFrame: ...

@property
def index(self) -> Index: ...

@property
def pos(self) -> DataFrame:
pos = self._df_constructor(self.space.agents, index_cols="agent_id")
pos = self._df_get_masked_df(df=pos, index_cols="agent_id", mask=self.index)
pos = self._df_reindex(
pos, self.index, new_index_cols="unique_id", original_index_cols="agent_id"
)
return pos
4 changes: 4 additions & 0 deletions mesa_frames/concrete/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,7 @@ def inactive_agents(self) -> dict[AgentSetDF, DataFrame]:
@property
def index(self) -> dict[AgentSetDF, Index]:
return {agentset: agentset.index for agentset in self._agentsets}

@property
def pos(self) -> dict[AgentSetDF, DataFrame]:
return {agentset: agentset.pos for agentset in self._agentsets}
4 changes: 4 additions & 0 deletions mesa_frames/concrete/pandas/agentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,7 @@ def inactive_agents(self) -> pd.DataFrame:
@property
def index(self) -> pd.Index:
return self._agents.index

@property
def pos(self) -> pd.DataFrame:
return super().pos
4 changes: 4 additions & 0 deletions mesa_frames/concrete/polars/agentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,7 @@ def inactive_agents(self) -> pl.DataFrame:
@property
def index(self) -> pl.Series:
return self._agents["unique_id"]

@property
def pos(self) -> pl.DataFrame:
return super().pos
18 changes: 16 additions & 2 deletions tests/pandas/test_agentset_pandas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from copy import copy, deepcopy

import pandas as pd
import pytest
import typeguard as tg
from numpy.random import Generator

from mesa_frames import AgentSetPandas, ModelDF
from mesa_frames import AgentSetPandas, GridPolars, ModelDF


@tg.typechecked
Expand All @@ -28,7 +29,7 @@ def fix1_AgentSetPandas() -> ExampleAgentSetPandas:
agents.add({"unique_id": [0, 1, 2, 3]})
agents["wealth"] = agents.starting_wealth
agents["age"] = [10, 20, 30, 40]

model.agents.add(agents)
return agents


Expand Down Expand Up @@ -427,3 +428,16 @@ def test_inactive_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas):

agents.select(agents["wealth"] > 2, inplace=True)
assert agents.inactive_agents.index.to_list() == [0, 1]

def test_pos(self, fix1_AgentSetPandas: ExampleAgentSetPandas):
space = GridPolars(fix1_AgentSetPandas.model, dimensions=[3, 3], capacity=2)
fix1_AgentSetPandas.model.space = space
space.place_agents(agents=[0, 1], pos=[[0, 0], [1, 1]])
pos = fix1_AgentSetPandas.pos
assert isinstance(pos, pd.DataFrame)
assert pos.index.tolist() == [0, 1, 2, 3]
assert pos.columns.tolist() == ["dim_0", "dim_1"]
assert pos["dim_0"].tolist()[:2] == [0, 1]
assert all(math.isnan(val) for val in pos["dim_0"].tolist()[2:])
assert pos["dim_1"].tolist()[:2] == [0, 1]
assert all(math.isnan(val) for val in pos["dim_1"].tolist()[2:])
14 changes: 13 additions & 1 deletion tests/polars/test_agentset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typeguard as tg
from numpy.random import Generator

from mesa_frames import AgentSetPolars, ModelDF
from mesa_frames import AgentSetPolars, GridPandas, ModelDF


@tg.typechecked
Expand All @@ -28,6 +28,7 @@ def fix1_AgentSetPolars() -> ExampleAgentSetPolars:
agents.add({"unique_id": [0, 1, 2, 3]})
agents["wealth"] = agents.starting_wealth
agents["age"] = [10, 20, 30, 40]
model.agents.add(agents)
return agents


Expand Down Expand Up @@ -426,3 +427,14 @@ def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars):

agents.select(agents.agents["wealth"] > 2, inplace=True)
assert agents.inactive_agents["unique_id"].to_list() == [0, 1]

def test_pos(self, fix1_AgentSetPolars: ExampleAgentSetPolars):
space = GridPandas(fix1_AgentSetPolars.model, dimensions=[3, 3], capacity=2)
fix1_AgentSetPolars.model.space = space
space.place_agents(agents=[0, 1], pos=[[0, 0], [1, 1]])
pos = fix1_AgentSetPolars.pos
assert isinstance(pos, pl.DataFrame)
assert pos["unique_id"].to_list() == [0, 1, 2, 3]
assert pos.columns == ["unique_id", "dim_0", "dim_1"]
assert pos["dim_0"].to_list() == [0, 1, None, None]
assert pos["dim_1"].to_list() == [0, 1, None, None]

0 comments on commit abfbb8f

Please sign in to comment.