Skip to content

Commit

Permalink
Fix ResampleWithDistributionTransform with holidays transform (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
egoriyaa authored Sep 13, 2023
1 parent 15018ed commit adfd18c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update `CONTRIBUTING.md` with scenarios of documentation updates and release instruction ([#77](https://github.com/etna-team/etna/pull/77))

### Fixed
-
- Fix `ResampleWithDistributionTransform` working with categorical columns ([#82](https://github.com/etna-team/etna/pull/82))
-
- Fix links from tinkoff-ai/etna to etna-team/etna ([#47](https://github.com/etna-team/etna/pull/47))
-
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/missing_values/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
:
result dataframe
"""
df = df.apply(pd.to_numeric)
df["fold"] = self._get_folds(df)
df = df.reset_index().merge(self.distribution, on="fold").set_index("timestamp").sort_index()
df[self.out_column] = df[self.in_column].ffill() * df["distribution"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from etna.transforms.missing_values import ResampleWithDistributionTransform
from etna.transforms import HolidayTransform
from etna.transforms import ResampleWithDistributionTransform
from tests.test_transforms.utils import assert_transformation_equals_loaded_original


Expand Down Expand Up @@ -132,3 +133,10 @@ def test_get_regressors_info_not_fitted():
def test_params_to_tune():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
assert len(transform.params_to_tune()) == 0


def test_working_with_categorical_columns(example_tsds):
holiday = HolidayTransform(out_column="holiday_regressor")
resample = ResampleWithDistributionTransform(distribution_column="target", in_column="holiday_regressor")
holiday.fit_transform(example_tsds)
resample.fit_transform(example_tsds)

0 comments on commit adfd18c

Please sign in to comment.