diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index 98ce722..ac4c1cd 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -736,6 +736,7 @@ def set_cells( """ obj = self._get_obj(inplace) cells_col_names = obj._df_column_names(obj._cells) + if __debug__: if isinstance(cells, DataFrame) and any( k not in cells_col_names for k in obj._pos_col_names @@ -743,20 +744,22 @@ def set_cells( raise ValueError( f"The cells DataFrame must have the columns {obj._pos_col_names}" ) - if properties: - pos_df = obj._get_df_coords(cells) - properties = obj._df_constructor(data=properties, index=pos_df.index) - cells = obj._df_concat( - [pos_df, properties], how="horizontal", index_cols=obj._pos_col_names - ) + if isinstance(cells, DataFrame): + cells_df = obj._df_set_index(cells, index_name=obj._pos_col_names) else: - cells = obj._df_constructor(data=cells, index_cols=obj._pos_col_names) + cells_df = obj._df_set_index( + obj._get_df_coords(cells), index_name=obj._pos_col_names + ) + + if properties: + properties = obj._df_constructor(data=properties, index=cells_df.index) + cells_df = obj._df_concat([cells_df, properties], how="horizontal") if "capacity" in cells_col_names: - obj._cells_capacity = obj._update_capacity_cells(cells) + obj._cells_capacity = obj._update_capacity_cells(cells_df) obj._cells = obj._df_combine_first( - cells, obj._cells, index_cols=obj._pos_col_names + cells_df, obj._cells, index_cols=obj._pos_col_names ) return obj