forked from tinkoff-ai/etna
Commit: Implement SavePredictionIntervalsMixin (#87)

* added save/load mixin
* removed placeholders from base class
* moved fixtures
* added tests
* updated changelog
* review fixes
Showing 6 changed files with 232 additions and 22 deletions.
@@ -0,0 +1,92 @@
import pathlib
import tempfile
import zipfile
from typing import Optional

from typing_extensions import Self

from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.pipeline import BasePipeline


class SavePredictionIntervalsMixin(SaveMixin):
    """Implementation of ``AbstractSaveable`` abstract class for prediction intervals with pipelines inside.

    It saves the object to a zip archive with 3 entities:

    * metadata.json: contains library version and class name.
    * object.pkl: the object pickled without pipeline and ts.
    * pipeline.zip: pipeline archive, saved with its own method.
    """

    def save(self, path: pathlib.Path):
        """Save the object.

        Parameters
        ----------
        path:
            Path to save object to.
        """
        self.pipeline: BasePipeline

        pipeline = self.pipeline

        try:
            # extract pipeline to save it with its own method later
            delattr(self, "pipeline")

            # save the remaining part
            super().save(path=path)
        finally:
            self.pipeline = pipeline

        with zipfile.ZipFile(path, "a") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)

                # save pipeline separately and add to the archive
                pipeline_save_path = temp_dir / "pipeline.zip"
                pipeline.save(path=pipeline_save_path)

                archive.write(pipeline_save_path, "pipeline.zip")

    @classmethod
    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
        """Load an object.

        Warning
        -------
        This method uses :py:mod:`dill` module which is not secure.
        It is possible to construct malicious data which will execute arbitrary code during loading.
        Never load data that could have come from an untrusted source, or that could have been tampered with.

        Parameters
        ----------
        path:
            Path to load object from.
        ts:
            TSDataset to set into loaded pipeline.

        Returns
        -------
        :
            Loaded object.
        """
        obj = super().load(path=path)

        with zipfile.ZipFile(path, "r") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)

                archive.extractall(temp_dir)

                # load pipeline and add to the object
                pipeline_path = temp_dir / "pipeline.zip"
                obj.pipeline = load(path=pipeline_path, ts=ts)

        return obj
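For orientation, below is a minimal usage sketch of the mixin. It is not part of this commit: `MyIntervals` and its `width` argument are hypothetical stand-ins for a real prediction-intervals wrapper (similar in spirit to the `DummyPredictionIntervals` used in the tests further down), and the sketch assumes `SavePredictionIntervalsMixin` from the file above is in scope.

# Usage sketch (illustrative only): the mixin needs nothing from the wrapper
# except a ``pipeline`` attribute holding an etna pipeline.
import pathlib

from etna.models import NaiveModel
from etna.pipeline import Pipeline


class MyIntervals(SavePredictionIntervalsMixin):
    """Hypothetical wrapper around a pipeline that produces prediction intervals."""

    def __init__(self, pipeline: Pipeline, width: float):
        self.pipeline = pipeline
        self.width = width


intervals = MyIntervals(pipeline=Pipeline(model=NaiveModel(), horizon=5), width=4)
path = pathlib.Path("intervals.zip")

# Writes metadata.json and object.pkl via SaveMixin, then appends pipeline.zip
# produced by the pipeline's own save().
intervals.save(path)

# Restores the pickled wrapper and re-attaches a pipeline loaded from pipeline.zip;
# ``ts`` is optional and is forwarded to the pipeline loader.
restored = MyIntervals.load(path=path)
assert restored.width == intervals.width

The try/finally in `save` is what keeps the live object usable after saving: the pipeline is detached only for the duration of the pickling step and always re-attached, which is exactly what the checks at the end of `test_save` below verify.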
tests/test_experimental/test_prediction_intervals/conftest.py (14 additions & 0 deletions)
@@ -0,0 +1,14 @@
import pytest

from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline
from tests.test_experimental.test_prediction_intervals.common import get_naive_pipeline_with_transforms


@pytest.fixture()
def naive_pipeline():
    return get_naive_pipeline(horizon=5)


@pytest.fixture()
def naive_pipeline_with_transforms():
    return get_naive_pipeline_with_transforms(horizon=5)
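These fixtures (and the tests below) rely on helpers from `tests/test_experimental/test_prediction_intervals/common.py`, which is not included in this diff view: `get_naive_pipeline`, `get_naive_pipeline_with_transforms`, and the `DummyPredictionIntervals` class. The following is only a guess at what the two pipeline helpers look like, inferred from the assertions in `test_mixins.py` (a `NaiveModel` pipeline, and exactly two transforms in the second helper); the concrete transforms are assumptions.

# Hypothetical reconstruction of the pipeline helpers in common.py (not shown in this commit).
from etna.models import NaiveModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform
from etna.transforms import LagTransform


def get_naive_pipeline(horizon):
    # plain NaiveModel pipeline without transforms
    return Pipeline(model=NaiveModel(), transforms=[], horizon=horizon)


def get_naive_pipeline_with_transforms(horizon):
    # same pipeline with two transforms attached; the specific transforms are a guess,
    # the tests only check that len(pipeline.transforms) == 2
    transforms = [AddConstTransform(in_column="target", value=0), LagTransform(in_column="target", lags=[1])]
    return Pipeline(model=NaiveModel(), transforms=transforms, horizon=horizon)

`DummyPredictionIntervals` itself is only exercised through its public surface in the tests (a `pipeline` argument, a `width` attribute, and the save/load methods from the mixin), so it is not sketched here.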
tests/test_experimental/test_prediction_intervals/test_mixins.py (98 additions & 0 deletions)
@@ -0,0 +1,98 @@
import pathlib
import pickle
import zipfile
from copy import deepcopy
from unittest.mock import patch

import pandas as pd
import pytest

from etna.models import NaiveModel
from etna.pipeline import Pipeline
from tests.test_experimental.test_prediction_intervals.common import DummyPredictionIntervals


@pytest.mark.parametrize("expected_filenames", ({"metadata.json", "object.pkl", "pipeline.zip"},))
def test_save(naive_pipeline_with_transforms, example_tsds, tmp_path, expected_filenames):
    dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

    path = pathlib.Path(tmp_path) / "dummy.zip"

    initial_dummy = deepcopy(dummy)
    dummy.save(path)

    with zipfile.ZipFile(path, "r") as archive:
        files = archive.namelist()
        assert set(files) == expected_filenames

        with archive.open("object.pkl", "r") as file:
            loaded_obj = pickle.load(file)
        assert loaded_obj.width == dummy.width

    # basic check that we didn't break dummy object itself
    assert dummy.width == initial_dummy.width
    assert pickle.dumps(dummy.ts) == pickle.dumps(initial_dummy.ts)
    assert pickle.dumps(dummy.pipeline.model) == pickle.dumps(initial_dummy.pipeline.model)
    assert pickle.dumps(dummy.pipeline.transforms) == pickle.dumps(initial_dummy.pipeline.transforms)


def test_load_file_not_found_error():
    non_existent_path = pathlib.Path("archive.zip")
    with pytest.raises(FileNotFoundError):
        DummyPredictionIntervals.load(non_existent_path)


def test_load_with_ts(naive_pipeline_with_transforms, example_tsds, recwarn, tmp_path):
    dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

    path = pathlib.Path(tmp_path) / "dummy.zip"
    dummy.save(path)

    loaded_obj = DummyPredictionIntervals.load(path=path, ts=example_tsds)

    assert loaded_obj.width == dummy.width
    assert loaded_obj.ts is not dummy.ts
    pd.testing.assert_frame_equal(loaded_obj.ts.to_pandas(), example_tsds.to_pandas())
    assert isinstance(loaded_obj.pipeline, Pipeline)
    assert isinstance(loaded_obj.pipeline.model, NaiveModel)
    assert len(loaded_obj.pipeline.transforms) == 2
    assert len(recwarn) == 0


def test_load_without_ts(naive_pipeline_with_transforms, recwarn, tmp_path):
    dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)

    path = pathlib.Path(tmp_path) / "dummy.zip"
    dummy.save(path)

    loaded_obj = DummyPredictionIntervals.load(path=path)

    assert loaded_obj.width == dummy.width
    assert loaded_obj.ts is None
    assert isinstance(loaded_obj.pipeline, Pipeline)
    assert isinstance(loaded_obj.pipeline.model, NaiveModel)
    assert len(loaded_obj.pipeline.transforms) == 2
    assert len(recwarn) == 0


@pytest.mark.parametrize(
    "save_version, load_version", [((1, 5, 0), (2, 5, 0)), ((2, 5, 0), (1, 5, 0)), ((1, 5, 0), (1, 3, 0))]
)
@patch("etna.core.mixins.get_etna_version")
def test_save_mixin_load_warning(
    get_version_mock, naive_pipeline_with_transforms, save_version, load_version, tmp_path
):
    dummy = DummyPredictionIntervals(pipeline=naive_pipeline_with_transforms, width=4)
    path = pathlib.Path(tmp_path) / "dummy.zip"

    get_version_mock.return_value = save_version
    dummy.save(path)

    save_version_str = ".".join(map(str, save_version))
    load_version_str = ".".join(map(str, load_version))
    with pytest.warns(
        UserWarning,
        match=f"The object was saved under etna version {save_version_str} but running version is {load_version_str}",
    ):
        get_version_mock.return_value = load_version
        _ = DummyPredictionIntervals.load(path=path)
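A note on the last test: the warning it expects comes from comparing the etna version recorded in metadata.json at save time (see the mixin docstring) with the running version at load time, which is why `get_etna_version` is patched around both calls. If you want to look at that recorded metadata by hand, something like the sketch below works; the exact keys inside metadata.json are not asserted here and may differ.

# Inspect the metadata written by save(); assumes a dummy.zip produced as in the tests above.
import json
import zipfile

with zipfile.ZipFile("dummy.zip", "r") as archive:
    with archive.open("metadata.json") as file:
        metadata = json.load(file)

print(metadata)  # class name and the library version used for saving, per the mixin docstring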