Skip to content

Commit

Permalink
removed TSDataset (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
brsnw250 authored Jun 19, 2024
1 parent 7429c39 commit 1ac5b97
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
12 changes: 6 additions & 6 deletions etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _forecast_prediction_interval(
return predictions

@staticmethod
def _validate_residuals_for_interval_estimation(backtest_forecasts: TSDataset, residuals: pd.DataFrame):
def _validate_residuals_for_interval_estimation(backtest_forecasts: pd.DataFrame, residuals: pd.DataFrame):
len_backtest, num_segments = residuals.shape
min_timestamp = backtest_forecasts.index.min()
max_timestamp = backtest_forecasts.index.max()
Expand All @@ -485,11 +485,11 @@ def _add_forecast_borders(
self, ts: TSDataset, backtest_forecasts: pd.DataFrame, quantiles: Sequence[float], predictions: TSDataset
) -> None:
"""Estimate prediction intervals and add to the forecasts."""
backtest_forecasts = TSDataset(df=backtest_forecasts, freq=ts.freq)
residuals = (
backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]]
- ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"]
)
target = ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"]
if not backtest_forecasts.index.equals(target.index):
raise ValueError("Historical backtest timestamps must match with the original dataset timestamps!")

residuals = backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] - target

self._validate_residuals_for_interval_estimation(backtest_forecasts=backtest_forecasts, residuals=residuals)
sigma = np.nanstd(residuals.values, axis=0)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,23 @@ def test_forecast_prediction_interval_not_builtin_with_nans_error(example_tsds,
_ = pipeline.forecast(prediction_interval=True, quantiles=[0.025, 0.975])


@pytest.mark.filterwarnings("ignore: There are NaNs in target on time span from .* to .*")
@pytest.mark.parametrize("model", (MovingAverageModel(),))
@pytest.mark.parametrize("stride", (1, 4, 6))
def test_add_forecast_borders_overlapping_timestamps(example_tsds, model, stride):
    """Check that ``_add_forecast_borders`` raises on these historical forecasts.

    Part of the tail of the ``segment_1`` target is set to ``None`` before the
    pipeline is fitted. Historical forecasts are then built with the
    parametrized ``stride`` and passed to ``_add_forecast_borders``, which is
    expected to raise ``ValueError`` with the "Historical backtest timestamps
    must match" message.
    """
    # Blank out timestamps [-20, -1) of the target in a single segment.
    nan_timestamps = example_tsds.index[-20:-1]
    target_column = pd.IndexSlice["segment_1", "target"]
    example_tsds.df.loc[nan_timestamps, target_column] = None

    pipeline = Pipeline(model=model, transforms=[DateFlagsTransform()], horizon=5)
    pipeline.fit(example_tsds)

    backtest_forecasts = pipeline.get_historical_forecasts(ts=example_tsds, stride=stride)

    # ``predictions`` is passed as None: the call is expected to raise,
    # so this test never inspects it.
    borders_kwargs = {
        "ts": example_tsds,
        "backtest_forecasts": backtest_forecasts,
        "quantiles": [0.025, 0.975],
        "predictions": None,
    }
    with pytest.raises(ValueError, match="Historical backtest timestamps must match"):
        pipeline._add_forecast_borders(**borders_kwargs)


def test_forecast_prediction_interval_correct_values(splited_piecewise_constant_ts):
"""Test that the prediction interval for piecewise-constant dataset is correct."""
train, test = splited_piecewise_constant_ts
Expand Down

0 comments on commit 1ac5b97

Please sign in to comment.