Rework saving for DL models (#98)
d-a-bunin authored Oct 10, 2023
1 parent 799ccb1 commit 2651829
Showing 33 changed files with 388 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add missing classes from decomposition into API Reference, add modules into page titles in API Reference ([#61](https://github.com/etna-team/etna/pull/61))
 - Update `CONTRIBUTING.md` with scenarios of documentation updates and release instruction ([#77](https://github.com/etna-team/etna/pull/77))
 - Set up sharding for running tests ([#99](https://github.com/etna-team/etna/pull/99))
+- Rework saving DL models by separating saving model's hyperparameters and model's weights ([#98](https://github.com/etna-team/etna/pull/98))
 
 ### Fixed
 - Fix `ResampleWithDistributionTransform` working with categorical columns ([#82](https://github.com/etna-team/etna/pull/82))
49 changes: 39 additions & 10 deletions etna/core/mixins.py
@@ -27,8 +27,8 @@ def __repr__(self):
         """Get default representation of etna object."""
         # TODO: add tests default behaviour for all registered objects
         args_str_representation = ""
-        init_args = inspect.signature(self.__init__).parameters
-        for arg, param in init_args.items():
+        init_parameters = self._get_init_parameters()
+        for arg, param in init_parameters.items():
             if param.kind == param.VAR_POSITIONAL:
                 continue
             elif param.kind == param.VAR_KEYWORD:
@@ -43,6 +43,9 @@ def __repr__(self):
             args_str_representation += f"{arg} = {repr(value)}, "
         return f"{self.__class__.__name__}({args_str_representation})"
 
+    def _get_init_parameters(self):
+        return inspect.signature(self.__init__).parameters
+
     @staticmethod
     def _get_target_from_class(value: Any):
         if value is None:
@@ -84,9 +87,9 @@ def _parse_value(value: Any) -> Any:
 
     def to_dict(self):
         """Collect all information about etna object in dict."""
-        init_args = inspect.signature(self.__init__).parameters
+        init_parameters = self._get_init_parameters()
         params = {}
-        for arg in init_args.keys():
+        for arg in init_parameters.keys():
             value = self.__dict__[arg]
             if value is None:
                 continue
@@ -226,21 +229,47 @@ def _save_metadata(self, archive: zipfile.ZipFile):
         with archive.open("metadata.json", "w") as output_file:
             output_file.write(metadata_bytes)
 
-    def _save_state(self, archive: zipfile.ZipFile):
-        with archive.open("object.pkl", "w", force_zip64=True) as output_file:
-            dill.dump(self, output_file)
+    def _save_state(self, archive: zipfile.ZipFile, skip_attributes: Sequence[str] = ()):
+        saved_attributes = {}
+        try:
+            # remove attributes we can't easily save
+            saved_attributes = {attr: getattr(self, attr) for attr in skip_attributes}
+            for attr in skip_attributes:
+                delattr(self, attr)
+
+            # save the remaining part
+            with archive.open("object.pkl", "w", force_zip64=True) as output_file:
+                dill.dump(self, output_file)
+        finally:
+            # restore the skipped attributes
+            for attr, value in saved_attributes.items():
+                setattr(self, attr, value)
+
+    def _save(self, path: pathlib.Path, skip_attributes: Sequence[str] = ()):
+        """Save the object with more options.
+
+        This method is intended to use to implement ``save`` method during inheritance.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        skip_attributes:
+            Attributes to be skipped during saving state. These attributes are intended to be saved manually.
+        """
+        with zipfile.ZipFile(path, "w") as archive:
+            self._save_metadata(archive)
+            self._save_state(archive, skip_attributes=skip_attributes)
 
     def save(self, path: pathlib.Path):
         """Save the object.
 
         Parameters
         ----------
        path:
             Path to save object to.
         """
-        with zipfile.ZipFile(path, "w") as archive:
-            self._save_metadata(archive)
-            self._save_state(archive)
+        self._save(path=path)
 
     @classmethod
     def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]:
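For context, the intended inheritance pattern for the new `_save` hook looks roughly like this. A minimal sketch: `ObjectWithHandle` and its `handle` attribute are hypothetical, not etna API; only `SaveMixin` and `_save(path, skip_attributes)` come from this commit.

# Minimal sketch (hypothetical subclass): pickle everything except an
# attribute dill can't handle, then append it to the archive manually.
import pathlib
import zipfile

from etna.core.mixins import SaveMixin


class ObjectWithHandle(SaveMixin):
    def __init__(self, handle=None):
        self.handle = handle  # e.g. an open connection that dill can't pickle

    def save(self, path: pathlib.Path):
        # metadata.json and object.pkl are written without `handle`
        self._save(path=path, skip_attributes=["handle"])
        # the skipped attribute is then saved manually into the same archive
        with zipfile.ZipFile(path, "a") as archive:
            with archive.open("handle.txt", "w") as output_file:
                output_file.write(str(self.handle).encode("utf-8"))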
15 changes: 2 additions & 13 deletions etna/ensembles/mixins.py
@@ -85,18 +85,7 @@ def save(self, path: pathlib.Path):
         self.pipelines: List[BasePipeline]
         self.ts: Optional[TSDataset]
 
-        pipelines = self.pipelines
-        ts = self.ts
-        try:
-            # extract attributes we can't easily save
-            delattr(self, "pipelines")
-            delattr(self, "ts")
-
-            # save the remaining part
-            super().save(path=path)
-        finally:
-            self.pipelines = pipelines
-            self.ts = ts
+        self._save(path=path, skip_attributes=["pipelines", "ts"])
 
         with zipfile.ZipFile(path, "a") as archive:
             with tempfile.TemporaryDirectory() as _temp_dir:
@@ -106,7 +95,7 @@ def save(self, path: pathlib.Path):
                 pipelines_dir = temp_dir / "pipelines"
                 pipelines_dir.mkdir()
                 num_digits = 8
-                for i, pipeline in enumerate(pipelines):
+                for i, pipeline in enumerate(self.pipelines):
                     save_name = f"{i:0{num_digits}d}.zip"
                     pipeline_save_path = pipelines_dir / save_name
                     pipeline.save(pipeline_save_path)
14 changes: 2 additions & 12 deletions etna/experimental/prediction_intervals/mixins.py
@@ -33,25 +33,15 @@ def save(self, path: pathlib.Path):
         """
         self.pipeline: BasePipeline
 
-        pipeline = self.pipeline
-
-        try:
-            # extract pipeline to save it with its own method later
-            delattr(self, "pipeline")
-
-            # save the remaining part
-            super().save(path=path)
-
-        finally:
-            self.pipeline = pipeline
+        self._save(path=path, skip_attributes=["pipeline"])
 
         with zipfile.ZipFile(path, "a") as archive:
             with tempfile.TemporaryDirectory() as _temp_dir:
                 temp_dir = pathlib.Path(_temp_dir)
 
                 # save pipeline separately and add to the archive
                 pipeline_save_path = temp_dir / "pipeline.zip"
-                pipeline.save(path=pipeline_save_path)
+                self.pipeline.save(path=pipeline_save_path)
 
                 archive.write(pipeline_save_path, "pipeline.zip")
10 changes: 7 additions & 3 deletions etna/models/base.py
@@ -20,7 +20,7 @@
 from etna.distributions import BaseDistribution
 from etna.loggers import tslogger
 from etna.models.decorators import log_decorator
-from etna.models.mixins import SaveNNMixin
+from etna.models.mixins import SaveDeepBaseModelMixin
 
 if SETTINGS.torch_required:
     import torch
@@ -429,7 +429,11 @@ def get_model(self) -> "DeepBaseNet":
 
 
 class DeepBaseNet(DeepAbstractNet, LightningModule):
-    """Class for partially implemented LightningModule interface."""
+    """Class for partially implemented LightningModule interface.
+
+    During inheritance don't forget to add ``self.save_hyperparameters()`` to the ``__init__``.
+    Otherwise, methods ``save`` and ``load`` won't work properly for your implementation of :py:class:`~etna.models.base.DeepBaseModel`.
+    """
 
     def __init__(self):
         """Init DeepBaseNet."""
@@ -470,7 +474,7 @@ def validation_step(self, batch: dict, *args, **kwargs):  # type: ignore
         return loss
 
 
-class DeepBaseModel(DeepBaseAbstractModel, SaveNNMixin, NonPredictionIntervalContextRequiredAbstractModel):
+class DeepBaseModel(DeepBaseAbstractModel, SaveDeepBaseModelMixin, NonPredictionIntervalContextRequiredAbstractModel):
     """Class for partially implemented interfaces for holding deep models."""
 
     def __init__(
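To illustrate the new requirement from the ``DeepBaseNet`` docstring, a subclass would now start like this. A sketch only: the class name, layer, and elided method bodies are illustrative, not etna code; the one load-bearing line is ``self.save_hyperparameters()``.

import torch

from etna.models.base import DeepBaseNet


class MyNet(DeepBaseNet):
    def __init__(self, input_size: int, hidden_size: int, lr: float):
        super().__init__()
        # required by the reworked saving: the hyperparameters recorded here
        # are exactly what the loader replays to reconstruct the net
        self.save_hyperparameters()
        self.lr = lr
        self.linear = torch.nn.Linear(input_size, hidden_size)

    def step(self, batch, *args, **kwargs):
        ...  # compute loss, true values and predictions for a batch

    def make_samples(self, df, encoder_length, decoder_length):
        ...  # cut a segment dataframe into training samples

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)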
153 changes: 141 additions & 12 deletions etna/models/mixins.py
@@ -1,3 +1,4 @@
+import pathlib
 import zipfile
 from abc import ABC
 from abc import abstractmethod
@@ -11,13 +12,21 @@
 import dill
 import numpy as np
 import pandas as pd
+from hydra_slayer import get_factory
 from typing_extensions import Self
 
+from etna import SETTINGS
+from etna.core.mixins import BaseMixin
 from etna.core.mixins import SaveMixin
 from etna.datasets.tsdataset import TSDataset
 from etna.datasets.utils import match_target_quantiles
 from etna.models.decorators import log_decorator
 
+if SETTINGS.torch_required:
+    import torch
+    from pytorch_lightning import LightningModule
+    from pytorch_lightning import Trainer
+
 
 class ModelForecastingMixin(ABC):
     """Base class for model mixins."""
@@ -640,25 +649,145 @@ def get_model(self) -> Any:
         return self._base_model.get_model()
 
 
-class SaveNNMixin(SaveMixin):
-    """Implementation of ``AbstractSaveable`` torch related classes.
+def _save_pl_model(archive: zipfile.ZipFile, filename: str, model: "LightningModule"):
+    with archive.open(filename, "w", force_zip64=True) as output_file:
+        to_save = {
+            "class": BaseMixin._get_target_from_class(model),
+            "hyperparameters": dict(model.hparams),
+            "state_dict": model.state_dict(),
+        }
+        torch.save(to_save, output_file, pickle_module=dill)
+
+
+def _load_pl_model(archive: zipfile.ZipFile, filename: str) -> "LightningModule":
+    with archive.open(filename, "r") as input_file:
+        net_loaded = torch.load(input_file, pickle_module=dill)
+
+    cls = get_factory(net_loaded["class"])
+    net = cls(**net_loaded["hyperparameters"])
+    net.load_state_dict(net_loaded["state_dict"])
+
+    return net
+
+
+class SaveDeepBaseModelMixin(SaveMixin):
+    """Implementation of ``AbstractSaveable`` for :py:class:`~etna.models.base.DeepBaseModel` models.
 
-    It saves object to the zip archive with 2 files:
+    It saves object to the zip archive with files:
 
     * metadata.json: contains library version and class name.
 
-    * object.pt: object saved by ``torch.save``.
+    * object.pkl: pickled without ``self.net`` and ``self.trainer``.
+
+    * net.pt: parameters of ``self.net`` saved by ``torch.save``.
     """
 
-    def _save_state(self, archive: zipfile.ZipFile):
-        import torch
+    def save(self, path: pathlib.Path):
+        """Save the object.
 
-        with archive.open("object.pt", "w", force_zip64=True) as output_file:
-            torch.save(self, output_file, pickle_module=dill)
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        from etna.models.base import DeepBaseNet
+
+        self.trainer: Optional[Trainer]
+        self.net: DeepBaseNet
+
+        self._save(path=path, skip_attributes=["trainer", "net"])
+
+        with zipfile.ZipFile(path, "a") as archive:
+            _save_pl_model(archive=archive, filename="net.pt", model=self.net)
 
     @classmethod
-    def _load_state(cls, archive: zipfile.ZipFile) -> Self:
-        import torch
+    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
+        """Load an object.
+
+        Warning
+        -------
+        This method uses :py:mod:`dill` module which is not secure.
+        It is possible to construct malicious data which will execute arbitrary code during loading.
+        Never load data that could have come from an untrusted source, or that could have been tampered with.
+
+        Parameters
+        ----------
+        path:
+            Path to load object from.
+        ts:
+            TSDataset to set into loaded pipeline.
+
+        Returns
+        -------
+        :
+            Loaded object.
+        """
+        obj = super().load(path=path)
+
+        with zipfile.ZipFile(path, "r") as archive:
+            obj.net = _load_pl_model(archive=archive, filename="net.pt")
+            obj.trainer = None
+
+        return obj
+
+
+class SavePytorchForecastingMixin(SaveMixin):
+    """Implementation of ``AbstractSaveable`` for :py:mod:`pytorch_forecasting` models.
+
+    It saves object to the zip archive with files:
+
+    * metadata.json: contains library version and class name.
+
+    * object.pkl: pickled without ``self.model`` and ``self.trainer``.
+
+    * model.pt: parameters of ``self.model`` saved by ``torch.save`` if model was fitted.
+    """
+
+    def save(self, path: pathlib.Path):
+        """Save the object.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        self.trainer: Optional[Trainer]
+        self.model: Optional[LightningModule]
+
+        if self.model is None:
+            self._save(path=path, skip_attributes=["trainer"])
+        else:
+            self._save(path=path, skip_attributes=["trainer", "model"])
+            with zipfile.ZipFile(path, "a") as archive:
+                _save_pl_model(archive=archive, filename="model.pt", model=self.model)
 
-        with archive.open("object.pt", "r") as input_file:
-            return torch.load(input_file, pickle_module=dill)
+    @classmethod
+    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
+        """Load an object.
+
+        Warning
+        -------
+        This method uses :py:mod:`dill` module which is not secure.
+        It is possible to construct malicious data which will execute arbitrary code during loading.
+        Never load data that could have come from an untrusted source, or that could have been tampered with.
+
+        Parameters
+        ----------
+        path:
+            Path to load object from.
+        ts:
+            TSDataset to set into loaded pipeline.
+
+        Returns
+        -------
+        :
+            Loaded object.
+        """
+        obj = super().load(path=path)
+        obj.trainer = None
+
+        if not hasattr(obj, "model"):
+            with zipfile.ZipFile(path, "r") as archive:
+                obj.model = _load_pl_model(archive=archive, filename="model.pt")
+
+        return obj
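Taken together, the reworked mixins give DL models a save/load round trip along these lines. A sketch under assumptions: the model class and its constructor arguments are illustrative of any `DeepBaseModel` subclass, not a verbatim etna signature.

import pathlib

from etna.models.nn import MLPModel

model = MLPModel(input_size=8, hidden_size=[16], encoder_length=14, decoder_length=7)
# ... fit the model on a TSDataset ...

path = pathlib.Path("model.zip")
model.save(path)  # writes metadata.json, object.pkl (without net/trainer) and net.pt

# net is rebuilt from its saved hyperparameters, then the weights are
# restored from the saved state_dict
loaded = MLPModel.load(path)
assert loaded.trainer is None  # the trainer is deliberately not persisted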
