Skip to content

Commit

Permalink
[BUG] Forecast visualization with horizon=1 (#426)
Browse files Browse the repository at this point in the history
* Forecast visualization with horizon=1

* fix interval plots

* changelog

---------

Co-authored-by: Egor Baturin <[email protected]>
  • Loading branch information
egoriyaa and Egor Baturin authored Jul 18, 2024
1 parent 12f19fb commit 3553761
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix rendering in 210 tutorial ([#386](https://github.com/etna-team/etna/pull/386))
- Fix typo in 103 tutorial ([#408](https://github.com/etna-team/etna/pull/408))
- Remove sorting of `ts.df` by timestamps in `plot_forecast` and `plot_forecast_decomposition` ([#410](https://github.com/etna-team/etna/pull/410))
-
- Fix forecast visualization with `horizon=1` ([#426](https://github.com/etna-team/etna/pull/426))
-
-
- Fix passing custom model to `STLTransform` ([#412](https://github.com/etna-team/etna/pull/412))
Expand Down
23 changes: 17 additions & 6 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,27 @@ def plot_forecast(
plot_df = pd.DataFrame(columns=["timestamp", "target", "segment"])

if (train_ts is not None) and (n_train_samples != 0):
ax[i].plot(plot_df.index.values, plot_df.target.values, label="train")
marker = None if len(plot_df) > 1 else "o"
ax[i].plot(plot_df.index.values, plot_df.target.values, label="train", marker=marker)
if test_ts is not None:
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test")
marker = None if len(segment_test_df) > 1 else "o"
ax[i].plot(
segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test", marker=marker
)

# plot forecast plot for each of given forecasts
for forecast_name, forecast in forecast_results.items():
legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else ""

segment_forecast_df = forecast[:, segment, :][segment]
marker = None if len(segment_forecast_df) > 1 else "o"

line = ax[i].plot(
segment_forecast_df.index.values,
segment_forecast_df.target.values,
linewidth=1,
label=f"{legend_prefix}forecast",
marker=marker,
)
forecast_color = line[0].get_color()

Expand All @@ -182,7 +189,8 @@ def plot_forecast(
segment_borders_df.index.values,
values_low,
values_high,
facecolor=forecast_color,
linewidth=3,
color=forecast_color,
alpha=alpha[interval_idx],
label=f"{legend_prefix}{low_border}-{high_border}",
)
Expand All @@ -196,7 +204,8 @@ def plot_forecast(
segment_borders_df.index.values,
values_low,
values_next,
facecolor=forecast_color,
linewidth=3,
color=forecast_color,
alpha=alpha[interval_idx],
label=f"{legend_prefix}{low_border}-{high_border}",
)
Expand All @@ -205,17 +214,19 @@ def plot_forecast(
segment_borders_df.index.values,
values_high,
values_prev,
facecolor=forecast_color,
linewidth=3,
color=forecast_color,
alpha=alpha[interval_idx],
)
# when we can't find pair for border, we plot it separately
if len(prediction_intervals_names) % 2 != 0:
remaining_border = prediction_intervals_names[len(prediction_intervals_names) // 2]
values = segment_borders_df[remaining_border].values
marker = "--" if len(values) > 1 else "d"
ax[i].plot(
segment_borders_df.index.values,
values,
"--",
marker,
color=forecast_color,
label=f"{legend_prefix}{remaining_border}",
)
Expand Down

0 comments on commit 3553761

Please sign in to comment.