Skip to content

Commit

Permalink
comment inference timesfm tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 27, 2024
1 parent 0d27c55 commit bdd0ec6
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 122 deletions.
133 changes: 66 additions & 67 deletions tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from etna.models.nn import RNNModel
from etna.models.nn import TFTModel
from etna.models.nn import TFTNativeModel
from etna.models.nn import TimesFMModel
from etna.models.nn.deepstate import CompositeSSM
from etna.models.nn.deepstate import WeeklySeasonalitySSM
from etna.transforms import LagTransform
Expand Down Expand Up @@ -249,12 +248,12 @@ def test_forecast_in_sample_full_no_target_failed_chronos(self, model, transform
@pytest.mark.parametrize(
"model, transforms, dataset_name",
[
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_in_sample_full_no_target_failed_timesfm(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -427,12 +426,12 @@ def test_forecast_in_sample_full_failed_chronos(self, model, transforms, dataset
@pytest.mark.parametrize(
"model, transforms, dataset_name",
[
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_in_sample_full_failed_timesfm(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -526,12 +525,12 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po
(NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_in_sample_suffix_no_target(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -656,12 +655,12 @@ class TestForecastInSampleSuffix:
(NBeatsGenericModel(input_size=7, output_size=50, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_in_sample_suffix(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -837,12 +836,12 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5):
(NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -945,12 +944,12 @@ def test_forecast_out_sample_int_timestamp(self, model, transforms, dataset_name
@pytest.mark.parametrize(
"model, transforms, dataset_name",
[
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_out_sample_int_timestamp_failed_timesfm(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -1115,12 +1114,12 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size
(NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_out_sample_prefix(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -1332,12 +1331,12 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data
@pytest.mark.parametrize(
"model, transforms, dataset_name",
[
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -1486,12 +1485,12 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50
(NBeatsGenericModel(input_size=7, output_size=55, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_mixed_in_out_sample(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -1654,12 +1653,12 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic
),
(NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_subset_segments(self, model, transforms, dataset_name, request):
Expand Down Expand Up @@ -1837,12 +1836,12 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre
(NBeatsGenericModel(input_size=7, output_size=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(ChronosModel(path_or_url="amazon/chronos-t5-tiny", encoder_length=7), [], "example_tsds"),
(ChronosBoltModel(path_or_url="amazon/chronos-bolt-tiny", encoder_length=7), [], "example_tsds"),
pytest.param(
TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
[],
"example_tsds",
marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
),
# pytest.param(
# TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32),
# [],
# "example_tsds",
# marks=pytest.mark.skip(reason="Model causes OOM in GitHub Actions."),
# ),
],
)
def test_forecast_new_segments(self, model, transforms, dataset_name, request):
Expand Down
Loading

0 comments on commit bdd0ec6

Please sign in to comment.