Skip to content

Commit

Permalink
change _df_filter to _df_get_masked_df + _df_all
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 committed Aug 12, 2024
1 parent 2d8b2b5 commit 8daedca
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 53 deletions.
17 changes: 4 additions & 13 deletions mesa_frames/abstract/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,9 @@ def _df_add(
def _df_all(
self,
df: DataFrame,
name: str,
name: str = "all",
axis: str = "columns",
index_cols: str | list[str] | None = None,
) -> DataFrame: ...
) -> Series: ...

@abstractmethod
def _df_column_names(self, df: DataFrame) -> list[str]: ...
Expand Down Expand Up @@ -259,19 +258,11 @@ def _df_drop_duplicates(
keep: Literal["first", "last", False] = "first",
) -> DataFrame: ...

@abstractmethod
def _df_filter(
self,
df: DataFrame,
condition: BoolSeries,
all: bool = True,
) -> DataFrame: ...

@abstractmethod
def _df_get_bool_mask(
self,
df: DataFrame,
index_cols: str | list[str],
index_cols: str | list[str] | None = None,
mask: Mask | None = None,
negate: bool = False,
) -> BoolSeries: ...
Expand All @@ -280,7 +271,7 @@ def _df_get_bool_mask(
def _df_get_masked_df(
self,
df: DataFrame,
index_cols: str,
index_cols: str | list[str] | None = None,
mask: Mask | None = None,
columns: str | list[str] | None = None,
negate: bool = False,
Expand Down
10 changes: 4 additions & 6 deletions mesa_frames/abstract/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,8 +1325,8 @@ def get_neighborhood(
radius_df,
on=self._center_col_names,
)
neighbors_df = self._df_filter(
neighbors_df, neighbors_df["radius"] <= neighbors_df["max_radius"]
neighbors_df = self._df_get_masked_df(
neighbors_df, mask=neighbors_df["radius"] <= neighbors_df["max_radius"]
)
neighbors_df = self._df_drop_columns(neighbors_df, "max_radius")

Expand All @@ -1341,13 +1341,12 @@ def get_neighborhood(
neighbors_df = self._df_drop_duplicates(neighbors_df, self._pos_col_names)

# Filter out-of-bound neighbors
neighbors_df = self._df_filter(
neighbors_df = self._df_get_masked_df(
neighbors_df,
(
mask=self._df_all(
(neighbors_df[self._pos_col_names] < self._dimensions)
& (neighbors_df >= 0)
),
all=True,
)

if include_center:
Expand Down Expand Up @@ -1412,7 +1411,6 @@ def out_of_bounds(self, pos: GridCoordinate | GridCoordinates) -> DataFrame:
out_of_bounds = self._df_all(
(pos_df < 0) | (pos_df >= self._dimensions),
name="out_of_bounds",
index_cols=self._pos_col_names,
)
return self._df_concat(objs=[pos_df, out_of_bounds], how="horizontal")

Expand Down
27 changes: 7 additions & 20 deletions mesa_frames/concrete/pandas/mixin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from collections.abc import Collection, Iterator, Sequence
from collections.abc import Collection, Hashable, Iterator, Sequence
from typing import Literal

from collections.abc import Hashable

import numpy as np
import pandas as pd
from typing_extensions import Any, overload
Expand All @@ -24,11 +22,10 @@ def _df_add(
def _df_all(
self,
df: pd.DataFrame,
name: str,
name: str = "all",
axis: str = "columns",
index_cols: str | list[str] | None = None,
) -> pd.DataFrame:
return df.all(axis).to_frame(name)
) -> pd.Series:
return df.all(axis).rename(name)

def _df_column_names(self, df: pd.DataFrame) -> list[str]:
return df.columns.tolist() + df.index.names
Expand Down Expand Up @@ -116,16 +113,6 @@ def _df_contains(
return pd.Series(values).isin(df.index)
return pd.Series(values).isin(df[column])

def _df_filter(
self,
df: pd.DataFrame,
condition: pd.DataFrame,
all: bool = True,
) -> pd.DataFrame:
if all and isinstance(condition, pd.DataFrame):
return df[condition.all(axis=1)]
return df[condition]

def _df_div(
self,
df: pd.DataFrame,
Expand Down Expand Up @@ -153,7 +140,7 @@ def _df_drop_duplicates(
def _df_get_bool_mask(
self,
df: pd.DataFrame,
index_cols: str | list[str],
index_cols: str | list[str] | None = None,
mask: PandasMask = None,
negate: bool = False,
) -> pd.Series:
Expand All @@ -162,7 +149,7 @@ def _df_get_bool_mask(
isinstance(index_cols, list) and df.index.names == index_cols
):
srs = df.index
else:
elif index_cols is not None:
srs = df.set_index(index_cols).index
if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df):
mask.index = df.index
Expand Down Expand Up @@ -190,7 +177,7 @@ def _df_get_bool_mask(
def _df_get_masked_df(
self,
df: pd.DataFrame,
index_cols: str,
index_cols: str | list[str] | None = None,
mask: PandasMask | None = None,
columns: str | list[str] | None = None,
negate: bool = False,
Expand Down
20 changes: 6 additions & 14 deletions mesa_frames/concrete/polars/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,12 @@ def _df_add(
def _df_all(
self,
df: pl.DataFrame,
name: str,
axis: str = "columns",
index_cols: str | None = None,
) -> pl.DataFrame:
if axis == "index":
return df.group_by(index_cols).agg(pl.all().all().alias(index_cols))
return df.select(pl.all().all())

def _df_with_columns(
self, original_df: pl.DataFrame, new_columns: list[str], data: Any
) -> pl.DataFrame:
return original_df.with_columns(
**{col: value for col, value in zip(new_columns, data)}
)
name: str = "all",
axis: Literal["index", "columns"] = "columns",
) -> pl.Series:
if axis == "columns":
return df.select(pl.col("*").all()).to_series()
return df.with_columns(all=pl.all_horizontal())["all"]

def _df_column_names(self, df: pl.DataFrame) -> list[str]:
return df.columns
Expand Down

0 comments on commit 8daedca

Please sign in to comment.