diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index a869ec491..0fc504360 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -1,4 +1,5 @@ import os +import reprlib import warnings from pathlib import Path from typing import Dict @@ -32,7 +33,8 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel): This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features. - Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour. + This model doesn't support NaN in the middle or at the end of time series. + Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them. Official implementation: https://github.com/google-research/timesfm @@ -179,7 +181,7 @@ def predict( """ raise NotImplementedError("Method predict isn't currently implemented!") - def _get_exog_features(self) -> List[str]: + def _exog_columns(self) -> List[str]: static_reals = [] if self.static_reals is None else self.static_reals static_categoricals = [] if self.static_categoricals is None else self.static_categoricals time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals @@ -235,13 +237,20 @@ def forecast( end_idx = len(ts.index) - all_exog = self._get_exog_features() + all_exog = self._exog_columns() df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]] - first_valid_index = df_slice.isna().any(axis=1).idxmin() + first_valid_index = ( + df_slice.isna().any(axis=1).idxmin() + ) # If all timestamps contain NaNs, idxmin() returns the first timestamp target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]] - if target_df.isna().any().any(): - raise ValueError("There are NaNs in the middle or end of the time series.") + + nan_segment_mask = target_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, 
nan_segment_mask].index.get_level_values(0).tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}." + ) future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx) @@ -249,8 +258,13 @@ def forecast( target = target_df.values.swapaxes(1, 0).tolist() exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]] - if exog_df.isna().any().any(): - raise ValueError("There are NaNs in the middle or end of the exogenous features.") + + nan_segment_mask = exog_df.isna().any() + if nan_segment_mask.any(): + nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).unique().tolist() + raise ValueError( + f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}." + ) static_reals_dict = ( { diff --git a/tests/test_models/test_nn/test_timesfm.py b/tests/test_models/test_nn/test_timesfm.py index 27f0629ef..3778c432b 100644 --- a/tests/test_models/test_nn/test_timesfm.py +++ b/tests/test_models/test_nn/test_timesfm.py @@ -23,6 +23,13 @@ def generate_increasing_df(): return df +def generate_exog(): + n = 128 + df_exog = generate_ar_df(start_time="2001-01-01", periods=n + 2, n_segments=2) + df_exog.rename(columns={"target": "exog"}, inplace=True) + return df_exog + + @pytest.fixture def ts_increasing_integers(): df = generate_increasing_df() @@ -54,6 +61,24 @@ def expected_ts_increasing_integers(): return ts +@pytest.fixture +def ts_exog_middle_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog.loc[120, "exog"] = np.nan + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + +@pytest.fixture +def ts_exog_all_nan(): + df = generate_increasing_df() + df_exog = generate_exog() + df_exog["exog"] = np.nan + ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all") + return ts + + @pytest.mark.smoke def test_url(tmp_path): model_name 
= "timesfm-1.0-200m-pytorch.ckpt" @@ -118,7 +143,7 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle): model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128) pipeline = Pipeline(model=model, horizon=2) pipeline.fit(ts_nan_middle) - with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + with pytest.raises(ValueError, match=r"There are NaNs in the middle or at the end of target. Segments with NaNs:"): _ = pipeline.forecast() @@ -162,7 +187,25 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle): ) pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon) pipeline.fit(ts_nan_middle) - with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."): + with pytest.raises(ValueError, match="There are NaNs in the middle or at the end of target. Segments with NaNs:"): + _ = pipeline.forecast() + + +@pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"]) +def test_forecast_exog_features_failed_exog_nan(ts, request): + ts = request.getfixturevalue(ts) + + horizon = 2 + model = TimesFMModel( + path_or_url="google/timesfm-1.0-200m-pytorch", + encoder_length=128, + time_varying_reals=["exog"], + ) + pipeline = Pipeline(model=model, transforms=[], horizon=horizon) + pipeline.fit(ts) + with pytest.raises( + ValueError, match="There are NaNs in the middle or at the end of exogenous features. Segments with NaNs:" + ): _ = pipeline.forecast()