Skip to content

Commit

Permalink
Implement SavePredictionIntervalsMixin (#87)
Browse files Browse the repository at this point in the history
* added save/load mixin

* removed placeholders from base class

* moved fixtures

* added tests

* added tests

* updated changelog

* review fixes
  • Loading branch information
brsnw250 authored Sep 23, 2023
1 parent a8fdd3c commit 4d87da9
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 22 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Base class `BasePredictionIntervals` for prediction intervals into experimental module. ([#86](https://github.com/etna-team/etna/pull/86))
- `SavePredictionIntervalsMixin` for the `BasePredictionIntervals` ([#87](https://github.com/etna-team/etna/pull/87))
- Base class `BasePredictionIntervals` for prediction intervals into experimental module ([#86](https://github.com/etna-team/etna/pull/86))
- Add `fit_params` parameter to `etna.models.sarimax.SARIMAXModel` ([#69](https://github.com/etna-team/etna/pull/69))
- Add `quickstart` notebook, add `mechanics_of_forecasting` notebook ([#1343](https://github.com/tinkoff-ai/etna/pull/1343))
- Add gallery of tutorials divided by level ([#46](https://github.com/etna-team/etna/pull/46))
Expand Down
13 changes: 2 additions & 11 deletions etna/experimental/prediction_intervals/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pathlib
from abc import abstractmethod
from typing import Dict
from typing import Optional
Expand All @@ -8,10 +7,11 @@

from etna.datasets import TSDataset
from etna.distributions import BaseDistribution
from etna.experimental.prediction_intervals.mixins import SavePredictionIntervalsMixin
from etna.pipeline.base import BasePipeline


class BasePredictionIntervals(BasePipeline):
class BasePredictionIntervals(SavePredictionIntervalsMixin, BasePipeline):
"""Base class for prediction intervals methods.
This class implements a wrapper interface for pipelines and ensembles that provides the ability to
Expand Down Expand Up @@ -113,15 +113,6 @@ def _forecast(self, ts: TSDataset, return_components: bool) -> TSDataset:
"""Make point forecasts using base pipeline or ensemble."""
return self.pipeline._forecast(ts=ts, return_components=return_components)

def save(self, path: pathlib.Path):
"""Implement in SavePredictionIntervalsMixin."""
pass

@classmethod
def load(cls, path: pathlib.Path):
"""Implement in SavePredictionIntervalsMixin."""
pass

def forecast(
self,
ts: Optional[TSDataset] = None,
Expand Down
92 changes: 92 additions & 0 deletions etna/experimental/prediction_intervals/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pathlib
import tempfile
import zipfile
from typing import Optional

from typing_extensions import Self

from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.pipeline import BasePipeline


class SavePredictionIntervalsMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` abstract class for prediction intervals with pipelines inside.
It saves object to the zip archive with 3 entities:
* metadata.json: contains library version and class name.
* object.pkl: pickled without pipeline and ts.
* pipeline.zip: pipeline archive, saved with its own method.
"""

def save(self, path: pathlib.Path):
"""Save the object.
Parameters
----------
path:
Path to save object to.
"""
self.pipeline: BasePipeline

pipeline = self.pipeline

try:
# extract pipeline to save it with its own method later
delattr(self, "pipeline")

# save the remaining part
super().save(path=path)

finally:
self.pipeline = pipeline

with zipfile.ZipFile(path, "a") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)

# save pipeline separately and add to the archive
pipeline_save_path = temp_dir / "pipeline.zip"
pipeline.save(path=pipeline_save_path)

archive.write(pipeline_save_path, "pipeline.zip")

@classmethod
def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
"""Load an object.
Warning
-------
This method uses :py:mod:`dill` module which is not secure.
It is possible to construct malicious data which will execute arbitrary code during loading.
Never load data that could have come from an untrusted source, or that could have been tampered with.
Parameters
----------
path:
Path to load object from.
ts:
TSDataset to set into loaded pipeline.
Returns
-------
:
Loaded object.
"""
obj = super().load(path=path)

with zipfile.ZipFile(path, "r") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)

archive.extractall(temp_dir)

# load pipeline and add to the object
pipeline_path = temp_dir / "pipeline.zip"
obj.pipeline = load(path=pipeline_path, ts=ts)

return obj
14 changes: 14 additions & 0 deletions tests/test_experimental/test_prediction_intervals/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline_with_transforms


@pytest.fixture()
def naive_pipeline():
return get_naive_pipeline(horizon=5)


@pytest.fixture()
def naive_pipeline_with_transforms():
return get_naive_pipeline_with_transforms(horizon=5)
34 changes: 24 additions & 10 deletions tests/test_experimental/test_prediction_intervals/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline_with_transforms
from tests.test_experimental.test_prediction_intervals.utils import assert_sampling_is_valid
from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original


def run_base_pipeline_compat_check(ts, pipeline, expected_columns):
Expand All @@ -35,16 +36,6 @@ def run_base_pipeline_compat_check(ts, pipeline, expected_columns):
assert np.sum(intervals_pipeline_pred.df.isna().values) == 0


@pytest.fixture()
def naive_pipeline():
return get_naive_pipeline(horizon=5)


@pytest.fixture()
def naive_pipeline_with_transforms():
return get_naive_pipeline_with_transforms(horizon=5)


def test_pipeline_ref_initialized(naive_pipeline):
intervals_pipeline = DummyPredictionIntervals(pipeline=naive_pipeline)

Expand Down Expand Up @@ -215,3 +206,26 @@ def test_default_params_to_tune_error(pipeline):

with pytest.raises(NotImplementedError, match=f"{pipeline.__class__.__name__} doesn't support"):
_ = intervals_pipeline.params_to_tune()


@pytest.mark.parametrize("load_ts", (True, False))
@pytest.mark.parametrize(
"pipeline",
(
Pipeline(model=LinearPerSegmentModel(), transforms=[DateFlagsTransform()]),
AutoRegressivePipeline(model=LinearPerSegmentModel(), transforms=[DateFlagsTransform()], horizon=1),
HierarchicalPipeline(
model=LinearPerSegmentModel(),
transforms=[DateFlagsTransform()],
horizon=1,
reconciliator=BottomUpReconciliator(target_level="total", source_level="market"),
),
DirectEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=2)]),
VotingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
StackingEnsemble(pipelines=[get_naive_pipeline(horizon=1), get_naive_pipeline_with_transforms(horizon=1)]),
),
)
def test_save_load(load_ts, pipeline, market_level_constant_hierarchical_ts_w_exog):
ts = market_level_constant_hierarchical_ts_w_exog
intervals_pipeline = DummyPredictionIntervals(pipeline=pipeline)
assert_pipeline_equals_loaded_original(pipeline=intervals_pipeline, ts=ts, load_ts=load_ts)
98 changes: 98 additions & 0 deletions tests/test_experimental/test_prediction_intervals/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pathlib
import pickle
import zipfile
from copy import deepcopy
from unittest.mock import patch

import pandas as pd
import pytest

from etna.models import NaiveModel
from etna.pipeline import Pipeline
from tests.test_experimental.test_prediction_intervals.common import DummyPredictionIntervals


@pytest.mark.parametrize("expected_filenames", ({"metadata.json", "object.pkl", "pipeline.zip"},))
def test_save(naive_pipeline_with_transforms, example_tsds, tmp_path, expected_filenames):
dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

path = pathlib.Path(tmp_path) / "dummy.zip"

initial_dummy = deepcopy(dummy)
dummy.save(path)

with zipfile.ZipFile(path, "r") as archive:
files = archive.namelist()
assert set(files) == expected_filenames

with archive.open("object.pkl", "r") as file:
loaded_obj = pickle.load(file)
assert loaded_obj.width == dummy.width

# basic check that we didn't break dummy object itself
assert dummy.width == initial_dummy.width
assert pickle.dumps(dummy.ts) == pickle.dumps(initial_dummy.ts)
assert pickle.dumps(dummy.pipeline.model) == pickle.dumps(initial_dummy.pipeline.model)
assert pickle.dumps(dummy.pipeline.transforms) == pickle.dumps(initial_dummy.pipeline.transforms)


def test_load_file_not_found_error():
non_existent_path = pathlib.Path("archive.zip")
with pytest.raises(FileNotFoundError):
DummyPredictionIntervals.load(non_existent_path)


def test_load_with_ts(naive_pipeline_with_transforms, example_tsds, recwarn, tmp_path):
dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

path = pathlib.Path(tmp_path) / "dummy.zip"
dummy.save(path)

loaded_obj = DummyPredictionIntervals.load(path=path, ts=example_tsds)

assert loaded_obj.width == dummy.width
assert loaded_obj.ts is not dummy.ts
pd.testing.assert_frame_equal(loaded_obj.ts.to_pandas(), example_tsds.to_pandas())
assert isinstance(loaded_obj.pipeline, Pipeline)
assert isinstance(loaded_obj.pipeline.model, NaiveModel)
assert len(loaded_obj.pipeline.transforms) == 2
assert len(recwarn) == 0


def test_load_without_ts(naive_pipeline_with_transforms, recwarn, tmp_path):
dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

path = pathlib.Path(tmp_path) / "dummy.zip"
dummy.save(path)

loaded_obj = DummyPredictionIntervals.load(path=path)

assert loaded_obj.width == dummy.width
assert loaded_obj.ts is None
assert isinstance(loaded_obj.pipeline, Pipeline)
assert isinstance(loaded_obj.pipeline.model, NaiveModel)
assert len(loaded_obj.pipeline.transforms) == 2
assert len(recwarn) == 0


@pytest.mark.parametrize(
"save_version, load_version", [((1, 5, 0), (2, 5, 0)), ((2, 5, 0), (1, 5, 0)), ((1, 5, 0), (1, 3, 0))]
)
@patch("etna.core.mixins.get_etna_version")
def test_save_mixin_load_warning(
get_version_mock, naive_pipeline_with_transforms, save_version, load_version, tmp_path
):
dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)
path = pathlib.Path(tmp_path) / "dummy.zip"

get_version_mock.return_value = save_version
dummy.save(path)

save_version_str = ".".join(map(str, save_version))
load_version_str = ".".join(map(str, load_version))
with pytest.warns(
UserWarning,
match=f"The object was saved under etna version {save_version_str} but running version is {load_version_str}",
):
get_version_mock.return_value = load_version
_ = DummyPredictionIntervals.load(path=path)

0 comments on commit 4d87da9

Please sign in to comment.