Skip to content

Commit

Permalink
fix notebook, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 25, 2024
1 parent 2b89f55 commit 10d60b4
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 50 deletions.
164 changes: 116 additions & 48 deletions examples/202-NN_examples.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,7 @@ def test_forecast_out_sample_suffix_failed_chronos(self, model, transforms, data
],
)
def test_forecast_out_sample_suffix_failed_timesfm(self, model, transforms, dataset_name, request):
"""This test is expected to fail due to patch strategy in TimesFM"""
ts = request.getfixturevalue(dataset_name)
with pytest.raises(AssertionError):
self._test_forecast_out_sample_suffix(ts, model, transforms)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_models/test_nn/test_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_predict(example_tsds):
model.predict(ts=example_tsds, prediction_size=1)


@pytest.mark.smoke
def test_forecast_warns_big_context_size(ts_increasing_integers):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=512)
pipeline = Pipeline(model=model, horizon=1)
Expand Down Expand Up @@ -142,7 +141,6 @@ def test_forecast_exog_int_timestamps(example_tsds_int_timestamp):
_ = pipeline.forecast()


@pytest.mark.smoke
@pytest.mark.parametrize("encoder_length", [16, 33])
def test_forecast_wrong_context_len(ts_increasing_integers, encoder_length):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=encoder_length)
Expand Down

0 comments on commit 10d60b4

Please sign in to comment.