forked from tinkoff-ai/etna
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transform for series decomposition using ETNA models (#427)
* added implementation
* updated imports
* moved fixture
* fixed predict components for segment subset
* added inference tests
* added tests
* updated documentation
* fixed tests
* fixed tests
* save index name
* fixed typos
* fixed docstring
* updated tests
* updated changelog
* review fixes
* fixed test
* updated doc
* fixed doc
* updated doc
- Loading branch information
Showing
12 changed files
with
612 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from typing import List | ||
from typing import Union | ||
from typing import get_args | ||
|
||
import pandas as pd | ||
|
||
from etna.datasets import TSDataset | ||
from etna.datasets.utils import determine_num_steps | ||
from etna.models import BATSModel | ||
from etna.models import DeadlineMovingAverageModel | ||
from etna.models import HoltWintersModel | ||
from etna.models import ProphetModel | ||
from etna.models import SARIMAXModel | ||
from etna.models import SeasonalMovingAverageModel | ||
from etna.models import TBATSModel | ||
from etna.models.base import ContextRequiredModelType | ||
from etna.models.base import ModelType | ||
from etna.transforms import IrreversibleTransform | ||
|
||
# Model types accepted by ``ModelDecomposeTransform``. The constructor unpacks this
# alias with ``typing.get_args`` for its ``isinstance`` check and interpolates the
# alias itself into the "not supported" error message.
# The trailing notes are the original author's per-model remarks.
_SUPPORTED_MODELS = Union[
    HoltWintersModel,  # full
    ProphetModel,  # full
    SARIMAXModel,  # full
    DeadlineMovingAverageModel,  # need to account context/prediction size
    SeasonalMovingAverageModel,  # need to account context/prediction size
    BATSModel,  # dynamic components, not reliable
    TBATSModel,  # dynamic components, not reliable
]
|
||
|
||
class ModelDecomposeTransform(IrreversibleTransform):
    """Decompose a series into components estimated by an ETNA model.

    Note
    ----
    Only in-sample data is decomposed; for the future timestamps the transform produces ``NaN``.
    The dataset to be transformed should contain at least the minimum amount of in-sample
    timestamps that are required by the model.
    """

    def __init__(self, model: ModelType, in_column: str = "target", residuals: bool = False):
        """Create ``ModelDecomposeTransform``.

        Parameters
        ----------
        model:
            instance of the model that performs the decomposition. Only some models are supported:

            - ``HoltWintersModel``
            - ``ProphetModel``
            - ``SARIMAXModel``
            - ``DeadlineMovingAverageModel``
            - ``SeasonalMovingAverageModel``
            - ``BATSModel``
            - ``TBATSModel``

            Only the processed series itself is used to fit the model. Additional
            features/regressors cannot be passed to the decomposition model.
        in_column:
            name of the column to decompose.
        residuals:
            whether to add a residuals column, so that all components together with
            the residuals sum up to the series.

        Raises
        ------
        ValueError:
            if ``model`` is not an instance of one of the supported model types.

        Warning
        -------
        With :py:class:`etna.models.BATSModel` and :py:class:`etna.models.TBATSModel` the resulting
        set of components may differ from the one implied by the initialization parameters.
        In such case, a corresponding warning would be raised.
        """
        supported_types = get_args(_SUPPORTED_MODELS)
        if not isinstance(model, supported_types):
            model_name = type(model).__name__
            raise ValueError(
                f"Model type `{model_name}` is not supported! Supported models are: {_SUPPORTED_MODELS}"
            )

        self.model = model
        self.in_column = in_column
        self.residuals = residuals

        # Time range seen during ``fit``; ``None`` marks the transform as not fitted.
        self._first_timestamp = None
        self._last_timestamp = None

        super().__init__(required_features=[in_column])

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform (always empty)."""
        return []

    def _fit(self, df: pd.DataFrame):
        """Do nothing; the fitting logic of this class lives in :meth:`fit`."""

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Do nothing; the transformation logic of this class lives in :meth:`transform`."""

    def _prepare_ts(self, ts: TSDataset) -> TSDataset:
        """Build a dataset for the decomposition model with ``in_column`` renamed to ``target``.

        Raises
        ------
        KeyError:
            if ``in_column`` is not among the features of ``ts``.
        """
        if self.in_column not in ts.features:
            raise KeyError(f"Column {self.in_column} is not found in features!")

        # Keep only the processed column of every segment and expose it as the target.
        column_selector = pd.IndexSlice[:, self.in_column]
        series_df = ts.df.loc[:, column_selector].rename(columns={self.in_column: "target"}, level="feature")

        return TSDataset(df=series_df, freq=ts.freq)

    def fit(self, ts: TSDataset) -> "ModelDecomposeTransform":
        """Fit the transform and the decomposition model.

        Parameters
        ----------
        ts:
            dataset to fit the transform on.

        Returns
        -------
        :
            the fitted transform instance.
        """
        # Remember the fitted time range; ``transform`` validates incoming data against it.
        timestamps = ts.index
        self._first_timestamp = timestamps.min()
        self._last_timestamp = timestamps.max()

        self.model.fit(self._prepare_ts(ts=ts))
        return self

    def transform(self, ts: TSDataset) -> TSDataset:
        """Transform ``TSDataset`` inplace.

        Parameters
        ----------
        ts:
            Dataset to transform.

        Returns
        -------
        :
            Transformed ``TSDataset``.

        Raises
        ------
        ValueError:
            if the transform is not fitted.
        ValueError:
            if the dataset starts before the first timestamp seen during ``fit``.
        ValueError:
            if the dataset starts after the last timestamp seen during ``fit``.
        """
        if self._first_timestamp is None:
            raise ValueError("Transform is not fitted!")

        first_timestamp = ts.index.min()

        if first_timestamp < self._first_timestamp:
            raise ValueError(
                f"First index of the dataset to be transformed must be larger or equal than {self._first_timestamp}!"
            )

        if first_timestamp > self._last_timestamp:
            raise ValueError(
                f"Dataset to be transformed must contain historical observations in range {self._first_timestamp} - {self._last_timestamp}"
            )

        history_ts = self._prepare_ts(ts=ts)

        # Timestamps after the fitted range are cut off before prediction;
        # the number of removed steps is kept to pad the result afterwards.
        max_timestamp = history_ts.index.max()
        if max_timestamp > self._last_timestamp:
            future_steps = determine_num_steps(self._last_timestamp, max_timestamp, freq=history_ts.freq)
            history_ts.df = history_ts.df.loc[: self._last_timestamp]
        else:
            future_steps = 0

        target = history_ts[..., "target"].droplevel("feature", axis=1)

        # Context-required models additionally get an explicit prediction size:
        # everything except the first ``context_size`` timestamps.
        predict_kwargs = {"return_components": True}
        if isinstance(self.model, get_args(ContextRequiredModelType)):
            predict_kwargs["prediction_size"] = history_ts.size()[0] - self.model.context_size

        decomposed_ts = self.model.predict(history_ts, **predict_kwargs)

        model_component_names = decomposed_ts.target_components_names
        components_df = decomposed_ts[..., model_component_names]

        # Output columns carry the processed column name instead of the generic prefix.
        rename = {name: name.replace("target_component", self.in_column) for name in model_component_names}

        if self.residuals:
            # Residuals are the series minus the per-segment sum of the model components.
            components_sum = components_df.groupby(level="segment", axis=1).sum()
            residuals_name = f"{self.in_column}_residuals"
            for segment in ts.segments:
                components_df[segment, residuals_name] = target[segment] - components_sum[segment]

        components_df.rename(columns=rename, level="feature", inplace=True)

        if future_steps > 0:
            # Re-extend the index by the future steps that were cut off before prediction.
            components_df = TSDataset._expand_index(
                df=components_df, future_steps=future_steps, freq=decomposed_ts.freq
            )

        ts.add_columns_from_pandas(components_df)

        return ts
|
||
|
||
# Names exported by ``from <module> import *``.
__all__ = ["ModelDecomposeTransform"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from etna.datasets import TSDataset | ||
from etna.datasets import generate_ar_df | ||
|
||
|
||
@pytest.fixture()
def ts_with_exogs() -> TSDataset:
    """Build a two-segment daily dataset with exogenous columns ``exog`` and ``holiday``.

    The exogenous frame is generated for 10 periods more than the target frame
    and is passed to the dataset with ``known_future="all"``.
    """
    n_periods = 100
    n_periods_exog = n_periods + 10
    generator_params = {"start_time": "2020-01-01", "freq": "D", "n_segments": 2}

    target_df = generate_ar_df(periods=n_periods, **generator_params)

    exog_df = generate_ar_df(periods=n_periods_exog, random_seed=2, **generator_params)
    exog_df = exog_df.rename(columns={"target": "exog"})
    # One random binary flag per generated row (two segments).
    exog_df["holiday"] = np.random.choice([0, 1], size=n_periods_exog * 2)

    return TSDataset(target_df, freq="D", df_exog=exog_df, known_future="all")
|
||
|
||
@pytest.fixture()
def ts_with_exogs_train_test(ts_with_exogs):
    """Return the result of splitting ``ts_with_exogs`` with ``test_size=20``."""
    split_result = ts_with_exogs.train_test_split(test_size=20)
    return split_result
|
||
|
||
@pytest.fixture()
def forward_stride_datasets(ts_with_exogs):
    """Return overlapping train/test datasets cut from ``ts_with_exogs``.

    The train part drops the last 10 rows of the underlying frame and the test part
    keeps the last 20 rows, so the two parts overlap by 10 rows.
    """
    full_df = ts_with_exogs.df
    freq = ts_with_exogs.freq

    train_ts = TSDataset(df=full_df.iloc[:-10], freq=freq)
    test_ts = TSDataset(df=full_df.iloc[-20:], freq=freq)

    return train_ts, test_ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.