Skip to content

Commit

Permalink
Remove sorting of ts.df by timestamps in plot methods (#410)
Browse files Browse the repository at this point in the history
* fix: remove sorting of ts.df by timestamps in plot methods (#389)

* fix: remove sorting of ts.df by timestamps in plot methods (#389)

* Update CHANGELOG.md

---------

Co-authored-by: Mikhail <[email protected]>
Co-authored-by: d-a-bunin <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent 0b2cd55 commit ae2bfce
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fix rendering in 210 tutorial ([#386](https://github.com/etna-team/etna/pull/386))
-
- Remove sorting of `ts.df` by timestamps in `plot_forecast` and `plot_forecast_decomposition` ([#410](https://github.com/etna-team/etna/pull/410))
-
-
-
Expand Down
12 changes: 2 additions & 10 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ def plot_forecast(
if prediction_intervals:
prediction_intervals_names = _select_prediction_intervals_names(forecast_results, quantiles)

if train_ts is not None:
train_ts.df.sort_values(by="timestamp", inplace=True)
if test_ts is not None:
test_ts.df.sort_values(by="timestamp", inplace=True)

for i, segment in enumerate(segments):
if train_ts is not None:
segment_train_df = train_ts[:, segment, :][segment]
Expand Down Expand Up @@ -156,7 +151,7 @@ def plot_forecast(
for forecast_name, forecast in forecast_results.items():
legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else ""

segment_forecast_df = forecast[:, segment, :][segment].sort_values(by="timestamp")
segment_forecast_df = forecast[:, segment, :][segment]
line = ax[i].plot(
segment_forecast_df.index.values,
segment_forecast_df.target.values,
Expand Down Expand Up @@ -914,9 +909,6 @@ def plot_forecast_decomposition(

_, ax = _prepare_axes(num_plots=num_plots, columns_num=columns_num, figsize=figsize, set_grid=show_grid)

if test_ts is not None:
test_ts.df.sort_values(by="timestamp", inplace=True)

alpha = 0.5 if components_mode == ComponentsMode.joint else 1.0
ax_array = np.asarray(ax).reshape(-1, columns_num).T.ravel()

Expand All @@ -927,7 +919,7 @@ def plot_forecast_decomposition(
else:
segment_test_df = pd.DataFrame(columns=["timestamp", "target", "segment"])

segment_forecast_df = forecast_ts[:, segment, :][segment].sort_values(by="timestamp")
segment_forecast_df = forecast_ts[:, segment, :][segment]

ax_array[i].set_title(segment)

Expand Down

0 comments on commit ae2bfce

Please sign in to comment.