diff --git a/CHANGELOG.md b/CHANGELOG.md index ffa9a0e5d..4d9bc9fe2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,7 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed -- +- Fix `ModelDecomposeTransform` import without `prophet` module ([#459](https://github.com/etna-team/etna/pull/459)) - - - diff --git a/etna/transforms/decomposition/model_based.py b/etna/transforms/decomposition/model_based.py index 00804ff35..f6da433ed 100644 --- a/etna/transforms/decomposition/model_based.py +++ b/etna/transforms/decomposition/model_based.py @@ -4,14 +4,16 @@ import pandas as pd +from etna import SETTINGS from etna.datasets import TSDataset from etna.datasets.utils import determine_num_steps from etna.models import BATSModel from etna.models import DeadlineMovingAverageModel +from etna.models import HoltModel from etna.models import HoltWintersModel -from etna.models import ProphetModel from etna.models import SARIMAXModel from etna.models import SeasonalMovingAverageModel +from etna.models import SimpleExpSmoothingModel from etna.models import TBATSModel from etna.models.base import ContextRequiredModelType from etna.models.base import ModelType @@ -19,7 +21,8 @@ _SUPPORTED_MODELS = Union[ HoltWintersModel, # full - ProphetModel, # full + HoltModel, # full + SimpleExpSmoothingModel, # full SARIMAXModel, # full DeadlineMovingAverageModel, # need to account context/prediction size SeasonalMovingAverageModel, # need to account context/prediction size @@ -27,6 +30,14 @@ TBATSModel, # dynamic components, not reliable ] +if SETTINGS.prophet_required: + from etna.models import ProphetModel + + _SUPPORTED_MODELS = Union[ # type: ignore + _SUPPORTED_MODELS, + ProphetModel, # full + ] + class ModelDecomposeTransform(IrreversibleTransform): """Transform that uses ETNA models to estimate series decomposition. diff --git a/tests/test_transforms/test_decomposition/test_model_based.py b/tests/test_transforms/test_decomposition/test_model_based.py index fed86d28f..3944ef9fa 100644 --- a/tests/test_transforms/test_decomposition/test_model_based.py +++ b/tests/test_transforms/test_decomposition/test_model_based.py @@ -6,10 +6,12 @@ from etna.models import BATSModel from etna.models import CatBoostPerSegmentModel from etna.models import DeadlineMovingAverageModel +from etna.models import HoltModel from etna.models import HoltWintersModel from etna.models import ProphetModel from etna.models import SARIMAXModel from etna.models import SeasonalMovingAverageModel +from etna.models import SimpleExpSmoothingModel from etna.models import TBATSModel from etna.pipeline import Pipeline from etna.transforms import IForestOutlierTransform @@ -253,6 +255,9 @@ def test_pipeline_models(ts_name, in_column, decompose_model, forecast_model, re "decompose_model", ( HoltWintersModel(), + HoltModel(), + SimpleExpSmoothingModel(), + HoltWintersModel(trend="add", seasonal="add"), ProphetModel(), SARIMAXModel(), DeadlineMovingAverageModel(),