Concrete GridPolars #60

Merged: 43 commits, Aug 14, 2024

Commits
c98a08d
adding implementation of GridPolars
adamamer20 Aug 1, 2024
8d933e3
Update space.py
adamamer20 Aug 1, 2024
d8cab68
fix to _df_all and _df_concat
adamamer20 Aug 4, 2024
98bfe11
adding right overload to concat
adamamer20 Aug 4, 2024
015edee
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 4, 2024
77ad379
change _df_filter to _df_get_masked_df + _df_all
adamamer20 Aug 4, 2024
9f757c8
adding custom name to _df_groupby_cumcount
adamamer20 Aug 4, 2024
7fe1933
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 5, 2024
b4f757a
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 12, 2024
3186a12
adding tests for PolarsMixin
adamamer20 Aug 12, 2024
93c7db3
fixes to PolarsMixins
adamamer20 Aug 12, 2024
f3543b9
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 12, 2024
1ec89ae
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 12, 2024
7280243
Merge branch 'main' into 53-tests-for-polarsmixin
adamamer20 Aug 12, 2024
34bd845
Merge branch '53-tests-for-polarsmixin' of https://github.com/adamame…
adamamer20 Aug 12, 2024
e3e6d91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
703b561
adding index_cols to _df_join per abstract DataFrameMixin
adamamer20 Aug 12, 2024
7746cc8
Merge branch '53-tests-for-polarsmixin' of https://github.com/adamame…
adamamer20 Aug 12, 2024
bea064a
Merge branch '53-tests-for-polarsmixin' of https://github.com/adamame…
adamamer20 Aug 12, 2024
089900d
adding greater or equal to DataFrameMixin
adamamer20 Aug 12, 2024
e3423a9
adding boolean or to DataFrameMixin
adamamer20 Aug 12, 2024
58a149d
adding less than to DataFrameMixin
adamamer20 Aug 12, 2024
7c0f3da
adding _df_and and _df_or
adamamer20 Aug 12, 2024
ca42fe5
using mixin logical operations in space
adamamer20 Aug 12, 2024
e0914ee
fixing collections mixin
adamamer20 Aug 12, 2024
7e8671c
Merge branch 'comparisons-mixin' of https://github.com/adamamer20/mes…
adamamer20 Aug 12, 2024
6187c8d
adding modulus to DataFrameMixin
adamamer20 Aug 13, 2024
5ba1bbe
adding index to DataFrameMixin
adamamer20 Aug 13, 2024
9ea50e0
fixes to method logic
adamamer20 Aug 13, 2024
a096a5d
Merge branch '53-tests-for-polarsmixin' of https://github.com/adamame…
adamamer20 Aug 13, 2024
5081d18
Merge branch 'comparisons-mixin' of https://github.com/adamamer20/mes…
adamamer20 Aug 13, 2024
c89a7ef
changing % to _df_mod in abstract space
adamamer20 Aug 13, 2024
9ab6b9d
Merge branch 'comparisons-mixin' of https://github.com/adamamer20/mes…
adamamer20 Aug 13, 2024
f715ec7
Merge branch 'comparisons-mixin' of https://github.com/adamamer20/mes…
adamamer20 Aug 13, 2024
b8f2399
Merge branch '35-concrete-gridpolars' of https://github.com/adamamer2…
adamamer20 Aug 13, 2024
4d603d3
adding reindexing DataFrameMixin
adamamer20 Aug 13, 2024
36962cc
Merge branch 'comparisons-mixin' of https://github.com/adamamer20/mes…
adamamer20 Aug 13, 2024
89a781d
Merge branch 'main' of https://github.com/adamamer20/mesa-frames into…
adamamer20 Aug 13, 2024
d442fe1
fixing GridPolars logic (right now, the logic is duplicated from Grid…
adamamer20 Aug 13, 2024
e2643be
small fixes to GridDF to add GridPolars
adamamer20 Aug 13, 2024
c7c141a
adding tests for GridPolars
adamamer20 Aug 13, 2024
54e25d9
fixes to the logic of some methods
adamamer20 Aug 13, 2024
9b4c2df
Merge branch 'main' into 35-concrete-gridpolars
adamamer20 Aug 14, 2024
2 changes: 2 additions & 0 deletions mesa_frames/__init__.py
@@ -3,11 +3,13 @@
from mesa_frames.concrete.pandas.agentset import AgentSetPandas
from mesa_frames.concrete.pandas.space import GridPandas
from mesa_frames.concrete.polars.agentset import AgentSetPolars
from mesa_frames.concrete.polars.space import GridPolars

__all__ = [
"AgentsDF",
"AgentSetPandas",
"AgentSetPolars",
"ModelDF",
"GridPandas",
"GridPolars",
]
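
With this export, the Polars-backed grid becomes part of the public API alongside GridPandas. A minimal import sketch (nothing beyond the import itself is implied by this diff):

from mesa_frames import GridPolars, ModelDF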
55 changes: 41 additions & 14 deletions mesa_frames/abstract/space.py
@@ -435,14 +435,20 @@
self, agents: IdsLike | AgentContainer | Collection[AgentContainer]
) -> Series:
if isinstance(agents, AgentSetDF):
return self._srs_constructor(agents.index, name="agent_id")
return self._srs_constructor(
self._df_index(agents, "unique_id"), name="agent_id"
)
elif isinstance(agents, AgentsDF):
return self._srs_constructor(agents._ids, name="agent_id")
elif isinstance(agents, Collection) and (isinstance(agents[0], AgentContainer)):
ids = []
for a in agents:
if isinstance(a, AgentSetDF):
ids.append(self._srs_constructor(a.index, name="agent_id"))
ids.append(
self._srs_constructor(
self._df_index(a, "unique_id"), name="agent_id"
)
)
elif isinstance(a, AgentsDF):
ids.append(self._srs_constructor(a._ids, name="agent_id"))
return self._df_concat(ids, ignore_index=True)
@@ -752,8 +758,10 @@
)

if properties:
properties = obj._df_constructor(data=properties, index=cells_df.index)
cells_df = obj._df_concat([cells_df, properties], how="horizontal")
properties = obj._df_constructor(
data=properties, index=self._df_index(cells_df, obj._pos_col_names)
)
cells_df = obj._df_join(cells_df, properties, on=obj._pos_col_names)

if "capacity" in cells_col_names:
obj._cells_capacity = obj._update_capacity_cells(cells_df)
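
The switch from a horizontal concat to _df_join matters because Polars has no row index to align on: cell properties must be matched to positions by key. A standalone Polars sketch of the same pattern (the dim_* and sugar column names are illustrative, not the library's):

import polars as pl

cells_df = pl.DataFrame({"dim_0": [0, 0, 1], "dim_1": [0, 1, 0], "capacity": [2, 2, 1]})
properties = pl.DataFrame({"dim_0": [0, 0, 1], "dim_1": [0, 1, 0], "sugar": [5, 3, 1]})

# Align on the position columns instead of relying on row order.
cells_df = cells_df.join(properties, on=["dim_0", "dim_1"])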
@@ -1146,11 +1154,12 @@
self._agents = self._df_constructor(
columns=["agent_id"] + self._pos_col_names,
index_cols="agent_id",
dtypes={col: int for col in self._pos_col_names},
dtypes={col: int for col in ["agent_id"] + self._pos_col_names},
)
self._cells = self._df_constructor(
columns=self._pos_col_names + ["capacity"],
index_cols=self._pos_col_names,
dtypes={col: int for col in self._pos_col_names + ["capacity"]},
)
self._offsets = self._compute_offsets(neighborhood_type)
self._cells_capacity = self._generate_empty_grid(dimensions, capacity)
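
Pinning every column to an integer dtype up front keeps later joins on these columns from failing on dtype mismatches, since an otherwise empty frame would default to a null dtype. A rough Polars equivalent of the two constructor calls above (pl.Int64 and the dim_* names are assumptions):

import polars as pl

pos_cols = ["dim_0", "dim_1"]
agents = pl.DataFrame(schema={c: pl.Int64 for c in ["agent_id"] + pos_cols})
cells = pl.DataFrame(schema={c: pl.Int64 for c in pos_cols + ["capacity"]})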
@@ -1216,15 +1225,18 @@
# If radius is a sequence, get the maximum radius (we will drop unnecessary neighbors later, time-efficient but memory-inefficient)
if isinstance(radius, Sequence):
radius_srs = self._srs_constructor(radius, name="radius")
radius_df = self._srs_to_df(radius_srs)
max_radius = radius_srs.max()
else:
max_radius = radius

range_srs = self._srs_range(name="radius", start=1, end=max_radius + 1)
range_df = self._srs_to_df(
self._srs_range(name="radius", start=1, end=max_radius + 1)
)

neighbors_df = self._df_join(
self._offsets,
range_srs,
range_df,
how="cross",
)

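Wrapping the radius range in a one-column DataFrame (_srs_to_df) is what makes the cross join legal: Polars joins take DataFrames, not Series. A self-contained sketch of the same product (offsets shown for a 1-D grid, names illustrative):

import polars as pl

offsets = pl.DataFrame({"dim_0": [-1, 1]})
radius = pl.int_range(1, 4, eager=True).alias("radius")  # Series: 1, 2, 3

# DataFrame.join requires a DataFrame on the right-hand side.
neighbors = offsets.join(radius.to_frame(), how="cross")  # 2 x 3 = 6 rows
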
@@ -1293,6 +1305,7 @@
neighbors_df = self._df_concat(
[neighbors_df, in_between_df], how="vertical"
)
radius_df = self._df_drop_columns(radius_df, "offset")

neighbors_df = self._df_join(
neighbors_df, pos_df, how="cross", suffix="_center"
@@ -1316,10 +1329,11 @@
# If radius is a sequence, filter unnecessary neighbors
if isinstance(radius, Sequence):
radius_df = self._df_rename_columns(
self._df_concat([pos_df, radius_srs], how="horizontal"),
self._df_concat([pos_df, radius_df], how="horizontal"),
self._pos_col_names + ["radius"],
self._center_col_names + ["max_radius"],
)

neighbors_df = self._df_join(
neighbors_df,
radius_df,
@@ -1358,15 +1372,14 @@
pos_df = self._df_with_columns(
pos_df,
data=0,
new_columns=["radius"],
new_columns="radius",
)
pos_df = self._df_concat([pos_df, center_df], how="horizontal")

neighbors_df = self._df_concat(
[pos_df, neighbors_df], how="vertical", ignore_index=True
)

neighbors_df = self._df_reset_index(neighbors_df, drop=True)
return neighbors_df

def get_cells(
@@ -1422,7 +1435,9 @@
),
name="out_of_bounds",
)
return self._df_concat(objs=[pos_df, out_of_bounds], how="horizontal")
return self._df_concat(
objs=[pos_df, self._srs_to_df(out_of_bounds)], how="horizontal"
)

def remove_agents(
self,
@@ -1626,7 +1641,11 @@
and (len(pos[0]) == len(self._dimensions))
): # We only test the first coordinate for performance
# This means that we have a collection of coordinates
return self._df_constructor(data=pos, columns=self._pos_col_names)
return self._df_constructor(
data=pos,
columns=self._pos_col_names,
dtypes={col: int for col in self._pos_col_names},
)
elif isinstance(pos, Sequence) and len(pos) == len(self._dimensions):
# This means that the sequence is already a sequence where each element is the
# sequence of coordinates for dimension i
@@ -1636,9 +1655,17 @@
step = c.step if c.step is not None else 1
stop = c.stop if c.stop is not None else self._dimensions[i]
pos[i] = self._srs_range(start=start, stop=stop, step=step)
return self._df_constructor(data=[pos], columns=self._pos_col_names)
return self._df_constructor(
data=[pos],
columns=self._pos_col_names,
dtypes={col: int for col in self._pos_col_names},
)
elif isinstance(pos, int) and len(self._dimensions) == 1:
return self._df_constructor(data=[pos], columns=self._pos_col_names)
return self._df_constructor(
data=[pos],
columns=self._pos_col_names,
dtypes={col: int for col in self._pos_col_names},
)
else:
raise ValueError("Invalid coordinates")

24 changes: 15 additions & 9 deletions mesa_frames/concrete/polars/mixin.py
@@ -61,12 +61,18 @@ def _df_combine_first(
new_df: pl.DataFrame,
index_cols: str | list[str],
) -> pl.DataFrame:
original_df = original_df.with_columns(_index=pl.int_range(0, len(original_df)))
common_cols = set(original_df.columns) & set(new_df.columns)
merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right")
merged_df = merged_df.with_columns(
pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col)
for col in common_cols
).select(pl.exclude("^.*_right$"))
merged_df = (
merged_df.with_columns(
pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col)
for col in common_cols
)
.select(pl.exclude("^.*_right$"))
.sort("_index")
.drop("_index")
)
return merged_df

@overload
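
The _index column is the substantive fix here: a full join does not preserve row order in Polars, so the method tags each original row, coalesces the overlapping columns, then sorts on the tag and drops it. A standalone sketch of the pattern with toy data (assumes a Polars version where how="full" is available):

import polars as pl

original = pl.DataFrame({"id": [1, 2, 3], "val": [10, None, 30]})
new = pl.DataFrame({"id": [2, 4], "val": [99, 40]})

common = set(original.columns) & set(new.columns)
original = original.with_columns(_index=pl.int_range(0, len(original)))
merged = (
    original.join(new, on="id", how="full", suffix="_right")
    .with_columns(
        pl.coalesce(pl.col(c), pl.col(f"{c}_right")).alias(c) for c in common
    )
    .select(pl.exclude("^.*_right$"))
    .sort("_index")  # restore the original row order
    .drop("_index")
)
# Non-null values in `original` win; its nulls are filled from `new`;
# rows present only in `new` survive the full join with a null _index.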
@@ -219,7 +225,7 @@ def _df_ge(
def _df_get_bool_mask(
self,
df: pl.DataFrame,
index_cols: str | list[str],
index_cols: str | list[str] | None = None,
mask: PolarsMask = None,
negate: bool = False,
) -> pl.Series | pl.Expr:
@@ -234,10 +240,10 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
return df[index_cols].is_in(mask)

def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series:
assert index_cols, list[str]
mask = mask[index_cols].unique()
mask = mask.with_columns(in_it=True)
return df.join(mask[index_cols + ["in_it"]], on=index_cols, how="left")[
"in_it"
].fill_null(False)
return df.join(mask, on=index_cols, how="left")["in_it"].fill_null(False)

if isinstance(mask, pl.Expr):
result = mask
@@ -269,7 +275,7 @@ def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series:
def _df_get_masked_df(
self,
df: pl.DataFrame,
index_cols: str,
index_cols: str | list[str] | None = None,
mask: PolarsMask | None = None,
columns: list[str] | None = None,
negate: bool = False,
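
For reference, the DataFrame branch of _df_get_bool_mask above reduces to a left join against the unique mask keys with a presence flag. A minimal sketch (names illustrative):

import polars as pl

df = pl.DataFrame({"agent_id": [1, 2, 3, 4]})
mask = pl.DataFrame({"agent_id": [2, 4, 4]})

flag = mask["agent_id"].unique().to_frame().with_columns(in_it=True)
bool_mask = df.join(flag, on="agent_id", how="left")["in_it"].fill_null(False)
# -> [False, True, False, True]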
170 changes: 170 additions & 0 deletions mesa_frames/concrete/polars/space.py
@@ -0,0 +1,170 @@
from collections.abc import Callable, Sequence

import numpy as np
import polars as pl
from typing import Literal

from mesa_frames.abstract.space import GridDF
from mesa_frames.concrete.polars.mixin import PolarsMixin


class GridPolars(GridDF, PolarsMixin):
_agents: pl.DataFrame
_copy_with_method: dict[str, tuple[str, list[str]]] = {
"_agents": ("clone", []),
"_cells": ("clone", []),
"_cells_capacity": ("copy", []),
"_offsets": ("clone", []),
}
_cells: pl.DataFrame
_cells_capacity: np.ndarray
_offsets: pl.DataFrame

def _empty_cell_condition(self, cap: np.ndarray) -> np.ndarray:
# Create a boolean mask of the same shape as cap
empty_mask = np.ones_like(cap, dtype=bool)

if not self._agents.is_empty():
# Get the coordinates of all agents
agent_coords = self._agents[self._pos_col_names].to_numpy()

# Mark cells containing agents as not empty
empty_mask[tuple(agent_coords.T)] = False

return empty_mask

def _generate_empty_grid(
self, dimensions: Sequence[int], capacity: int
) -> np.ndarray:
if not capacity:
capacity = np.inf
return np.full(dimensions, capacity)

def _sample_cells(
self,
n: int | None,
with_replacement: bool,
condition: Callable[[np.ndarray], np.ndarray],
respect_capacity: bool = True,
) -> pl.DataFrame:
# Get the coordinates of cells that meet the condition
coords = np.array(np.where(condition(self._cells_capacity))).T

if respect_capacity and condition != self._full_cell_condition:
capacities = self._cells_capacity[tuple(coords.T)]
else:
# If not respecting capacity or for full cells, set capacities to 1
capacities = np.ones(len(coords), dtype=int)

if n is not None:
if with_replacement:
if respect_capacity and condition != self._full_cell_condition:
assert (
n <= capacities.sum()
), "Requested sample size exceeds the total available capacity."

sampled_coords = np.empty((0, coords.shape[1]), dtype=coords.dtype)
while len(sampled_coords) < n:
remaining_samples = n - len(sampled_coords)
sampled_indices = self.random.choice(
len(coords),
size=remaining_samples,
replace=True,
)
unique_indices, counts = np.unique(
sampled_indices, return_counts=True
)

if respect_capacity and condition != self._full_cell_condition:
# Calculate valid counts for each unique index
valid_counts = np.minimum(counts, capacities[unique_indices])
# Update capacities
capacities[unique_indices] -= valid_counts
else:
valid_counts = counts

# Create array of repeated coordinates
new_coords = np.repeat(coords[unique_indices], valid_counts, axis=0)
# Extend sampled_coords
sampled_coords = np.vstack((sampled_coords, new_coords))

if respect_capacity and condition != self._full_cell_condition:
# Update coords and capacities
mask = capacities > 0
coords = coords[mask]
capacities = capacities[mask]

sampled_coords = sampled_coords[:n]
self.random.shuffle(sampled_coords)
else:
assert n <= len(
coords
), "Requested sample size exceeds the number of available cells."
sampled_indices = self.random.choice(len(coords), size=n, replace=False)
sampled_coords = coords[sampled_indices]
else:
sampled_coords = coords

# Convert the coordinates to a DataFrame
sampled_cells = pl.DataFrame(
sampled_coords, schema=self._pos_col_names, orient="row"
)
return sampled_cells

def _update_capacity_agents(
self,
agents: pl.DataFrame,
operation: Literal["movement", "removal"],
) -> np.ndarray:
# Update capacity for agents that were already on the grid
masked_df = self._df_get_masked_df(
self._agents, index_cols="agent_id", mask=agents
)

if operation == "movement":
# Increase capacity at old positions
old_positions = tuple(masked_df[self._pos_col_names].to_numpy().T)
np.add.at(self._cells_capacity, old_positions, 1)

# Decrease capacity at new positions
new_positions = tuple(agents[self._pos_col_names].to_numpy().T)
np.add.at(self._cells_capacity, new_positions, -1)
elif operation == "removal":
# Increase capacity at the positions of removed agents
positions = tuple(masked_df[self._pos_col_names].to_numpy().T)
np.add.at(self._cells_capacity, positions, 1)
return self._cells_capacity

def _update_capacity_cells(self, cells: pl.DataFrame) -> np.ndarray:
# Get the coordinates of the cells to update
coords = cells[self._pos_col_names]

# Get the current capacity of updatable cells
current_capacity = (
coords.join(self._cells, on=self._pos_col_names, how="left")
.fill_null(self._capacity)["capacity"]
.to_numpy()
)

# Calculate the number of agents currently in each cell
agents_in_cells = (
current_capacity - self._cells_capacity[tuple(zip(*coords.to_numpy()))]
)

# Update the capacity in self._cells_capacity
new_capacity = cells["capacity"].to_numpy() - agents_in_cells

# Assert that no new capacity is negative
assert np.all(
new_capacity >= 0
), "New capacity of a cell cannot be less than the number of agents in it."

self._cells_capacity[tuple(zip(*coords.to_numpy()))] = new_capacity

return self._cells_capacity

@property
def remaining_capacity(self) -> int:
if not self._capacity:
return np.inf
return self._cells_capacity.sum()
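
Taken together, GridPolars keeps cell capacities in a dense NumPy array while agents and cells live in Polars frames. A hypothetical end-to-end sketch; the constructor arguments are assumed from the abstract GridDF shown above, not confirmed by this diff:

from mesa_frames import GridPolars, ModelDF

model = ModelDF()
grid = GridPolars(model, dimensions=[10, 10], capacity=2)

# No agents placed yet, so every cell retains its full capacity.
print(grid.remaining_capacity)  # 10 * 10 * 2 = 200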