Skip to content

Commit

Permalink
Fix assign optimization when overwriting columns
Browse files (browse the repository at this point in the history)
  • Loading branch information
phofl committed Dec 16, 2024
1 parent 77d0f89 commit 7cea0c9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
14 changes: 11 additions & 3 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1893,21 +1893,29 @@ def _tree_repr_argument_construction(self, i, op, header):
def _node_label_args(self):
return self.operands

def _remove_common_columns(self, other):
if set(self.keys) & set(other.keys):
keys = set(self.keys)
operands = [[k, v] for k, v in zip(other.keys, other.vals) if k not in keys]
return [other.frame] + list(flatten(operands)) + self.operands[1:]
else:
return other.operands + self.operands[1:]

def _simplify_down(self):
    """Squash this node into an ``Assign`` found directly beneath it.

    Two shapes are handled: the frame is itself an ``Assign``, or the frame
    is a ``Projection`` whose own frame is an ``Assign``. In both cases the
    merged operand list comes from ``_remove_common_columns``, so a column
    assigned by both nodes is kept only once.

    Returns
    -------
    Assign, Projection or None
        The rewritten expression, or ``None`` when no rewrite applies
        (including when this node uses a column created by the child).
    """
    if isinstance(self.frame, Assign):
        if self._check_for_previously_created_column(self.frame):
            # don't squash if we are using a column that was previously created
            return None
        return Assign(*self._remove_common_columns(self.frame))

    if isinstance(self.frame, Projection) and isinstance(self.frame.frame, Assign):
        if self._check_for_previously_created_column(self.frame.frame):
            return None
        # Copy before extending so the projection's own column list is not
        # mutated; only append keys the projection does not already select.
        new_columns = self.frame.operands[1].copy()
        new_columns.extend([k for k in self.keys if k not in new_columns])
        return Projection(
            Assign(*self._remove_common_columns(self.frame.frame)), new_columns
        )

    return None

def _check_for_previously_created_column(self, child):
Expand Down
27 changes: 27 additions & 0 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,33 @@ def apply_func(x):
assert result.expr._depth() == 13.0 # this grew exponentially previously


def test_assign_overwriting_column(df, pdf):
    """Repeatedly overwriting an assigned column must match pandas."""
    # The fixture arguments are replaced by a purpose-built frame.
    pdf = pd.DataFrame(
        {"partner": ["A", np.nan, "C", "A", np.nan, "C", "A", np.nan, "C"]},
        dtype="string[pyarrow]",
    )
    df = from_pandas(pdf, npartitions=2)

    def overwrite_columns(frame):
        # Identical steps for both frames: each column is created and then
        # overwritten, which is the pattern under test.
        frame["partner1"] = ""
        frame["partner1"] = "google"

        frame["partner2"] = ""
        frame["partner2"] = np.nan
        frame["partner2"] = frame["partner2"].mask(
            cond=(frame["partner"] == "A"), other="Blackhawk"
        )

    overwrite_columns(df)
    overwrite_columns(pdf)

    # Result is discarded; presumably a check that the graph executes at
    # all before the comparison below (confirm).
    df.compute()
    assert_eq(df, pdf, check_dtype=False)


def test_dropna_merge(df, pdf):
dropped_na = df.dropna(subset=["x"])
result = dropped_na.merge(dropped_na, on="x")
Expand Down

0 comments on commit 7cea0c9

Please sign in to comment.