diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index eb3a25ea9..7b7557c9b 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -46,7 +46,6 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel -from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -249,12 +248,12 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request): @@ -427,12 +426,12 @@ def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request): @@ -526,12 +525,12 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request): @@ -656,12 +655,12 @@ class TestForecastInSampleSuffix: (NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request): @@ -837,12 +836,12 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request): @@ -945,12 +944,12 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1115,12 +1114,12 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request): @@ -1332,12 +1331,12 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data @pytest.mark.parametrize( "model, transforms, dataset_name", [ - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request): @@ -1486,12 +1485,12 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 (NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request): @@ -1654,12 +1653,12 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic ), (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_subset_segments(self, model, transforms, dataset_name, request): @@ -1837,12 +1836,12 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_forecast_new_segments(self, model, transforms, dataset_name, request): diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index a6ef08bd8..2c8d433dc 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -44,7 +44,6 @@ from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TFTNativeModel -from etna.models.nn import TimesFMModel from etna.models.nn.deepstate import CompositeSSM from etna.models.nn.deepstate import WeeklySeasonalitySSM from etna.transforms import LagTransform @@ -196,12 +195,12 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_full_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -333,12 +332,12 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_datetime_timestamp_failed_not_implemented_predict( @@ -485,12 +484,12 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_in_sample_suffix_int_timestamp_failed_not_implemented_predict( @@ -633,12 +632,12 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -798,12 +797,12 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_prefix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -977,12 +976,12 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_out_sample_suffix_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1161,12 +1160,12 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_mixed_in_out_sample_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1333,12 +1332,12 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_subset_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request): @@ -1470,12 +1469,12 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): (NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), (ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"), (ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"), - pytest.param( - TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), - [], - "example_tsds", - marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), - ), + # pytest.param( + # TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32), + # [], + # "example_tsds", + # marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."), + # ), ], ) def test_predict_new_segments_failed_not_implemented_predict(self, model, transforms, dataset_name, request):