Skip to content

Commit

Permalink
Fix forecast first point in CatBoostPerSegmentModel (#1010)
Browse files Browse the repository at this point in the history
* fix forecast first point in CatBoostPerSegmentModel

* fix test for CatBoostPerSegmentModel

* update changelog

* fix moment with ts.to_pandas()

Co-authored-by: ext.ytarasyuk <[email protected]>
  • Loading branch information
DBcreator and ext.ytarasyuk authored Nov 25, 2022
1 parent a465a2a commit 554d4ea
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
- Fix release docs and docker images cron job ([#982](https://github.com/tinkoff-ai/etna/pull/982))
-
- Fix forecast first point with CatBoostPerSegmentModel ([#1010](https://github.com/tinkoff-ai/etna/pull/1010))
-
-
-
Expand Down
8 changes: 4 additions & 4 deletions etna/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,10 @@ def get_model(self) -> Dict[str, Any]:

@staticmethod
def _make_predictions_segment(
model: Any, segment: str, ts: TSDataset, prediction_method: Callable, **kwargs
model: Any, segment: str, df: pd.DataFrame, prediction_method: Callable, **kwargs
) -> pd.DataFrame:
"""Make predictions for one segment."""
segment_features = ts[:, segment, :]
segment_features = segment_features.droplevel("segment", axis=1)
segment_features = df[segment]
segment_features = segment_features.reset_index()
dates = segment_features["timestamp"]
dates.reset_index(drop=True, inplace=True)
Expand Down Expand Up @@ -321,9 +320,10 @@ def _make_predictions(self, ts: TSDataset, prediction_method: Callable, **kwargs
Dataset with predictions
"""
result_list = list()
df = ts.to_pandas()
for segment, model in self._get_model().items():
segment_predict = self._make_predictions_segment(
model=model, segment=segment, ts=ts, prediction_method=prediction_method, **kwargs
model=model, segment=segment, df=df, prediction_method=prediction_method, **kwargs
)

result_list.append(segment_predict)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _test_forecast_in_sample_full_no_target(ts, model, transforms):
"model, transforms",
[
(CatBoostModelMultiSegment(), [LagTransform(in_column="target", lags=[2, 3])]),
(CatBoostModelPerSegment(), [LagTransform(in_column="target", lags=[2, 3])]),
(ProphetModel(), []),
(SARIMAXModel(), []),
(AutoARIMAModel(), []),
Expand All @@ -86,7 +87,6 @@ def test_forecast_in_sample_full_no_target(self, model, transforms, example_tsds
@pytest.mark.parametrize(
"model, transforms",
[
(CatBoostModelPerSegment(), [LagTransform(in_column="target", lags=[2, 3])]),
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
],
)
Expand Down

1 comment on commit 554d4ea

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.