Skip to content

Commit

Permalink
unlock all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 27, 2024
1 parent c911b2f commit 2b3e3fc
Showing 1 changed file with 0 additions and 19 deletions.
19 changes: 0 additions & 19 deletions tests/test_models/test_nn/test_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def test_url(tmp_path):
assert os.path.exists(tmp_path / model_name)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_cache_dir(tmp_path):
path_or_url = "google/timesfm-1.0-200m-pytorch"
Expand All @@ -96,36 +95,31 @@ def test_cache_dir(tmp_path):
assert os.path.exists(tmp_path / f"models--google--{model_name}")


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_context_size():
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=10)
assert model.context_size == 10


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_get_model(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
assert isinstance(model.get_model(), TimesFmTorch)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_fit(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
model.fit(example_tsds)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_predict(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
with pytest.raises(NotImplementedError, match="Method predict isn't currently implemented!"):
model.predict(ts=example_tsds, prediction_size=1)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
def test_forecast_warns_big_context_size(ts_increasing_integers):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512)
pipeline = Pipeline(model=model, horizon=1)
Expand All @@ -134,7 +128,6 @@ def test_forecast_warns_big_context_size(ts_increasing_integers):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.parametrize("encoder_length", [32, 64, 128])
@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"])
def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request):
Expand All @@ -146,7 +139,6 @@ def test_forecast(ts, expected_ts_increasing_integers, encoder_length, request):
assert_frame_equal(forecast.df, expected_ts_increasing_integers.df, atol=1)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
def test_forecast_failed_nan_middle_target(ts_nan_middle):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128)
pipeline = Pipeline(model=model, horizon=2)
Expand All @@ -155,7 +147,6 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.parametrize("encoder_length", [32, 64, 128])
@pytest.mark.parametrize("ts", ["ts_increasing_integers", "ts_nan_start"])
def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encoder_length, request):
Expand All @@ -180,7 +171,6 @@ def test_forecast_exogenous_features(ts, expected_ts_increasing_integers, encode
assert_frame_equal(forecast.df.loc[:, pd.IndexSlice[:, "target"]], expected_ts_increasing_integers.df, atol=1)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle):
horizon = 2
transforms = [
Expand All @@ -201,7 +191,6 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"])
def test_forecast_exog_features_failed_exog_nan(ts, request):
ts = request.getfixturevalue(ts)
Expand All @@ -220,7 +209,6 @@ def test_forecast_exog_features_failed_exog_nan(ts, request):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32)
Expand All @@ -233,7 +221,6 @@ def test_forecast_only_target_failed_int_timestamps(example_tsds_int_timestamp):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_forecast_exog_int_timestamps(example_tsds_int_timestamp):
horizon = 2
Expand All @@ -253,7 +240,6 @@ def test_forecast_exog_int_timestamps(example_tsds_int_timestamp):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.parametrize("encoder_length", [16, 33])
def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length)
Expand All @@ -263,15 +249,13 @@ def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length):
_ = pipeline.forecast()


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_forecast_without_fit(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=32)
pipeline = Pipeline(model=model, horizon=1)
_ = pipeline.forecast(example_tsds)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_forecast_fails_components(example_tsds):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
Expand All @@ -280,13 +264,11 @@ def test_forecast_fails_components(example_tsds):
pipeline.forecast(ts=example_tsds, return_components=True)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_list_models():
assert TimesFMModel.list_models() == ["google/timesfm-1.0-200m-pytorch"]


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_save_load(tmp_path, ts_increasing_integers):
path = Path(tmp_path) / "tmp.zip"
Expand All @@ -300,7 +282,6 @@ def test_save_load(tmp_path, ts_increasing_integers):
assert isinstance(loaded_model, TimesFMModel)


@pytest.mark.skip(reason="Model causes OOM in GitHub Actions.")
@pytest.mark.smoke
def test_params_to_tune():
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch")
Expand Down

0 comments on commit 2b3e3fc

Please sign in to comment.