From ae2bfce7817b051fcdbbfcf80bdcb3ba53eb5e0c Mon Sep 17 00:00:00 2001 From: Mikhail <92105261+kenshi777@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:07:51 +0300 Subject: [PATCH] Remove sorting of ts.df by timestamps in plot methods (#410) * fix: remove sorting of ts.df by timestamps in plot methods (#389) * fix: remove sorting of ts.df by timestamps in plot methods (#389) * Update CHANGELOG.md --------- Co-authored-by: Mikhail Co-authored-by: d-a-bunin <142778107+d-a-bunin@users.noreply.github.com> --- CHANGELOG.md | 2 +- etna/analysis/forecast/plots.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e438b3cee..5ef17e885 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fix rendering in 210 tutorial ([#386](https://github.com/etna-team/etna/pull/386)) -- +- Remove sorting of `ts.df` by timestamps in `plot_forecast` and `plot_forecast_decomposition` ([#410](https://github.com/etna-team/etna/pull/410)) - - - diff --git a/etna/analysis/forecast/plots.py b/etna/analysis/forecast/plots.py index 7dd4b4c4c..bf95771f7 100644 --- a/etna/analysis/forecast/plots.py +++ b/etna/analysis/forecast/plots.py @@ -124,11 +124,6 @@ def plot_forecast( if prediction_intervals: prediction_intervals_names = _select_prediction_intervals_names(forecast_results, quantiles) - if train_ts is not None: - train_ts.df.sort_values(by="timestamp", inplace=True) - if test_ts is not None: - test_ts.df.sort_values(by="timestamp", inplace=True) - for i, segment in enumerate(segments): if train_ts is not None: segment_train_df = train_ts[:, segment, :][segment] @@ -156,7 +151,7 @@ def plot_forecast( for forecast_name, forecast in forecast_results.items(): legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else "" - segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp") + segment_forecast_df = forecast[:, segment, :][segment] line = ax[i].plot( segment_forecast_df.index.values, segment_forecast_df.target.values, @@ -914,9 +909,6 @@ def plot_forecast_decomposition( _, ax = _prepare_axes(num_plots=num_plots, columns_num=columns_num, figsize=figsize, set_grid=show_grid) - if test_ts is not None: - test_ts.df.sort_values(by="timestamp", inplace=True) - alpha = 0.5 if components_mode == ComponentsMode.joint else 1.0 ax_array = np.asarray(ax).reshape(-1, columns_num).T.ravel() @@ -927,7 +919,7 @@ def plot_forecast_decomposition( else: segment_test_df = pd.DataFrame(columns=["timestamp", "target", "segment"]) - segment_forecast_df = forecast_ts[:, segment, :][segment].sort_values(by="timestamp") + segment_forecast_df = forecast_ts[:, segment, :][segment] ax_array[i].set_title(segment)