Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 26, 2024
1 parent e99e687 commit 85fbe3c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 10 deletions.
30 changes: 22 additions & 8 deletions etna/models/nn/timesfm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import reprlib
import warnings
from pathlib import Path
from typing import Dict
Expand Down Expand Up @@ -32,7 +33,8 @@ class TimesFMModel(NonPredictionIntervalContextRequiredAbstractModel):
This model doesn't support forecasting on misaligned data with `freq=None` without exogenous features.
Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill NaNs for stable behaviour.
This model doesn't support NaN in the middle or at the end of time series.
Use :py:class:`~etna.transforms.TimeSeriesImputerTransform` to fill them.
Official implementation: https://github.com/google-research/timesfm
Expand Down Expand Up @@ -179,7 +181,7 @@ def predict(
"""
raise NotImplementedError("Method predict isn't currently implemented!")

def _get_exog_features(self) -> List[str]:
def _exog_сolumns(self) -> List[str]:
static_reals = [] if self.static_reals is None else self.static_reals
static_categoricals = [] if self.static_categoricals is None else self.static_categoricals
time_varying_reals = [] if self.time_varying_reals is None else self.time_varying_reals
Expand Down Expand Up @@ -235,22 +237,34 @@ def forecast(

end_idx = len(ts.index)

all_exog = self._get_exog_features()
all_exog = self._exog_сolumns()
df_slice = ts.df.loc[:, pd.IndexSlice[:, all_exog + ["target"]]]
first_valid_index = df_slice.isna().any(axis=1).idxmin()
first_valid_index = (
df_slice.isna().any(axis=1).idxmin()
) # If all timestamps contains NaNs, idxmin() returns the first timestamp

target_df = df_slice.loc[first_valid_index : ts.index[-prediction_size - 1], pd.IndexSlice[:, "target"]]
if target_df.isna().any().any():
raise ValueError("There are NaNs in the middle or end of the time series.")

nan_segment_mask = target_df.isna().any()
if nan_segment_mask.any():
nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist()
raise ValueError(
f"There are NaNs in the middle or at the end of target. Segments with NaNs: {reprlib.repr(nan_segments)}."
)

future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx)

if len(all_exog) > 0:
target = target_df.values.swapaxes(1, 0).tolist()

exog_df = df_slice.loc[first_valid_index:, pd.IndexSlice[:, all_exog]]
if exog_df.isna().any().any():
raise ValueError("There are NaNs in the middle or end of the exogenous features.")

nan_segment_mask = exog_df.isna().any()
if nan_segment_mask.any():
nan_segments = nan_segment_mask.loc[:, nan_segment_mask].index.get_level_values(0).tolist()
raise ValueError(
f"There are NaNs in the middle or at the end of exogenous features. Segments with NaNs: {reprlib.repr(nan_segments)}."
)

static_reals_dict = (
{
Expand Down
47 changes: 45 additions & 2 deletions tests/test_models/test_nn/test_timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def generate_increasing_df():
return df


def generate_exog():
n = 128
df_exog = generate_ar_df(start_time="2001-01-01", periods=n + 2, n_segments=2)
df_exog.rename(columns={"target": "exog"}, inplace=True)
return df_exog


@pytest.fixture
def ts_increasing_integers():
df = generate_increasing_df()
Expand Down Expand Up @@ -54,6 +61,24 @@ def expected_ts_increasing_integers():
return ts


@pytest.fixture
def ts_exog_middle_nan():
df = generate_increasing_df()
df_exog = generate_exog()
df_exog.loc[120, "exog"] = np.NaN
ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all")
return ts


@pytest.fixture
def ts_exog_all_nan():
df = generate_increasing_df()
df_exog = generate_exog()
df_exog["exog"] = np.NaN
ts = TSDataset(df, df_exog=df_exog, freq="D", known_future="all")
return ts


@pytest.mark.smoke
def test_url(tmp_path):
model_name = "timesfm-1.0-200m-pytorch.ckpt"
Expand Down Expand Up @@ -118,7 +143,7 @@ def test_forecast_failed_nan_middle_target(ts_nan_middle):
model = TimesFMModel(path_or_url="google/timesfm-1.0-200m-pytorch", encoder_length=128)
pipeline = Pipeline(model=model, horizon=2)
pipeline.fit(ts_nan_middle)
with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."):
with pytest.raises(ValueError, match=r"There are NaNs in the middle or at the end of target. Segments with NaNs:"):
_ = pipeline.forecast()


Expand Down Expand Up @@ -162,7 +187,25 @@ def test_forecast_exog_features_failed_nan_middle_target(ts_nan_middle):
)
pipeline = Pipeline(model=model, transforms=transforms, horizon=horizon)
pipeline.fit(ts_nan_middle)
with pytest.raises(ValueError, match="There are NaNs in the middle or end of the time series."):
with pytest.raises(ValueError, match="There are NaNs in the middle or at the end of target. Segments with NaNs:"):
_ = pipeline.forecast()


@pytest.mark.parametrize("ts", ["ts_exog_middle_nan", "ts_exog_all_nan"])
def test_forecast_exog_features_failed_exog_nan(ts, request):
ts = request.getfixturevalue(ts)

horizon = 2
model = TimesFMModel(
path_or_url="google/timesfm-1.0-200m-pytorch",
encoder_length=128,
time_varying_reals=["exog"],
)
pipeline = Pipeline(model=model, transforms=[], horizon=horizon)
pipeline.fit(ts)
with pytest.raises(
ValueError, match="There are NaNs in the middle or at the end of exogenous features. Segments with NaNs:"
):
_ = pipeline.forecast()


Expand Down

0 comments on commit 85fbe3c

Please sign in to comment.