From 1ac5b97318663926686a4d8584f174dafe1f3449 Mon Sep 17 00:00:00 2001 From: Maxim Zherelo <60392282+brsnw250@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:22:41 +0300 Subject: [PATCH] removed `TSDataset` (#397) --- etna/pipeline/base.py | 12 ++++++------ tests/test_pipeline/test_pipeline.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py index 8066b8d05..888434cdf 100644 --- a/etna/pipeline/base.py +++ b/etna/pipeline/base.py @@ -463,7 +463,7 @@ def _forecast_prediction_interval( return predictions @staticmethod - def _validate_residuals_for_interval_estimation(backtest_forecasts: TSDataset, residuals: pd.DataFrame): + def _validate_residuals_for_interval_estimation(backtest_forecasts: pd.DataFrame, residuals: pd.DataFrame): len_backtest, num_segments = residuals.shape min_timestamp = backtest_forecasts.index.min() max_timestamp = backtest_forecasts.index.max() @@ -485,11 +485,11 @@ def _add_forecast_borders( self, ts: TSDataset, backtest_forecasts: pd.DataFrame, quantiles: Sequence[float], predictions: TSDataset ) -> None: """Estimate prediction intervals and add to the forecasts.""" - backtest_forecasts = TSDataset(df=backtest_forecasts, freq=ts.freq) - residuals = ( - backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] - - ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"] - ) + target = ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"] + if not backtest_forecasts.index.equals(target.index): + raise ValueError("Historical backtest timestamps must match with the original dataset timestamps!") + + residuals = backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] - target self._validate_residuals_for_interval_estimation(backtest_forecasts=backtest_forecasts, residuals=residuals) sigma = np.nanstd(residuals.values, axis=0) diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 200268c49..b8dc0e917 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -336,6 +336,23 @@ def test_forecast_prediction_interval_not_builtin_with_nans_error(example_tsds, _ = pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.975]) +@pytest.mark.filterwarnings("ignore: There are NaNs in target on time span from .* to .*") +@pytest.mark.parametrize("model", (MovingAverageModel(),)) +@pytest.mark.parametrize("stride", (1, 4, 6)) +def test_add_forecast_borders_overlapping_timestamps(example_tsds, model, stride): + example_tsds.df.loc[example_tsds.index[-20:-1], pd.IndexSlice["segment_1", "target"]] = None + + pipeline = Pipeline(model=model, transforms=[DateFlagsTransform()], horizon=5) + pipeline.fit(example_tsds) + + forecasts = pipeline.get_historical_forecasts(ts=example_tsds, stride=stride) + + with pytest.raises(ValueError, match="Historical backtest timestamps must match"): + pipeline._add_forecast_borders( + ts=example_tsds, backtest_forecasts=forecasts, quantiles=[0.025, 0.975], predictions=None + ) + + def test_forecast_prediction_interval_correct_values(splited_piecewise_constant_ts): """Test that the prediction interval for piecewise-constant dataset is correct.""" train, test = splited_piecewise_constant_ts