Skip to content

Commit

Permalink
Fix analysis.forecast.plots.metric_per_segment_distribution_plot to…
Browse files Browse the repository at this point in the history
… handle `None` from metrics (#543)
  • Loading branch information
d-a-bunin authored Dec 24, 2024
1 parent 7b37537 commit e73d138
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add parameter `missing_mode` into `R2` and `MedAE` metrics ([#537](https://github.com/etna-team/etna/pull/537))
- Update `analysis.forecast.plots.plot_metric_per_segment` to handle `None` from metrics ([#540](https://github.com/etna-team/etna/pull/540))
- Add parameter `missing_mode` into `RMSE` and `MSLE` metrics ([#542](https://github.com/etna-team/etna/pull/542))
-
- Update `analysis.forecast.plots.metric_per_segment_distribution_plot` to handle `None` from metrics ([#543](https://github.com/etna-team/etna/pull/543))
-
-

Expand Down
19 changes: 17 additions & 2 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def plot_metric_per_segment(
Warnings
--------
UserWarning:
There are segments without non-missing metric values.
There are segments with all missing metric values.
UserWarning:
Some segments have different set of folds to be aggregated on due to missing values.
"""
Expand Down Expand Up @@ -803,7 +803,12 @@ def metric_per_segment_distribution_plot(
seaborn_params: Optional[Dict[str, Any]] = None,
figsize: Tuple[int, int] = (10, 5),
):
"""Plot per-segment metrics distribution.
"""Plot distribution of metric values over all segments.
If all metric values are missing for some segment, that segment isn't plotted and a warning is raised.
If some segments have different sets of folds with non-missing metrics,
it can lead to incompatible values between folds. A warning is raised in such a case.
Parameters
----------
Expand Down Expand Up @@ -831,6 +836,13 @@ def metric_per_segment_distribution_plot(
if ``metric_name`` isn't present in ``metrics_df``
NotImplementedError:
unknown ``per_fold_aggregation_mode`` is given
Warnings
--------
UserWarning:
There are segments with all missing metric values.
UserWarning:
Some segments have different set of folds to be aggregated on due to missing values.
"""
if seaborn_params is None:
seaborn_params = {}
Expand All @@ -844,6 +856,9 @@ def metric_per_segment_distribution_plot(
if metric_name not in metrics_df.columns:
raise ValueError("Given metric_name isn't present in metrics_df")

_check_metrics_df_empty_segments(metrics_df=metrics_df, metric_name=metric_name)
_check_metrics_df_same_folds_for_each_segment(metrics_df=metrics_df, metric_name=metric_name)

# draw plot for each fold
if per_fold_aggregation_mode is None and "fold_number" in metrics_df.columns:
if plot_type_enum == MetricPlotType.hist:
Expand Down
2 changes: 1 addition & 1 deletion etna/analysis/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ def _check_metrics_df_same_folds_for_each_segment(metrics_df: pd.DataFrame, metr
df = metrics_df[["segment", "fold_number", metric_name]]
# we don't take into account segments without any non-missing metrics, they are handled by other check
df = df.dropna(subset=[metric_name])
num_unique = df.groupby("segment")["fold_number"].apply(frozenset).nunique()
num_unique = df.groupby("segment", group_keys=False)["fold_number"].apply(frozenset).nunique()
if num_unique > 1:
warnings.warn("Some segments have different set of folds to be aggregated on due to missing values.")
49 changes: 48 additions & 1 deletion tests/test_analysis/test_forecast/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pytest

from etna.analysis import metric_per_segment_distribution_plot
from etna.analysis import plot_metric_per_segment
from etna.analysis import plot_residuals
from etna.analysis.forecast.plots import _get_borders_comparator
Expand Down Expand Up @@ -80,7 +81,7 @@ def metrics_df_no_folds(metrics_df_with_folds) -> pd.DataFrame:
"df_name, metric_name",
[
("metrics_df_with_folds", "MAE"),
("metrics_df_no_folds", "MSE"),
("metrics_df_no_folds", "MAE"),
("metrics_df_no_folds", "MSE"),
],
)
Expand Down Expand Up @@ -112,3 +113,49 @@ def test_plot_metric_per_segment_warning_non_comparable_segments(df_name, metric
metrics_df = request.getfixturevalue(df_name)
with pytest.warns(UserWarning, match="Some segments have different set of folds to be aggregated on"):
plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name)


@pytest.mark.parametrize("plot_type", ["hist", "box", "violin"])
@pytest.mark.parametrize(
    "df_name, metric_name, per_fold_aggregation_mode",
    [
        ("metrics_df_with_folds", "MAE", None),
        ("metrics_df_with_folds", "MAE", "mean"),
        ("metrics_df_with_folds", "MAE", "median"),
        ("metrics_df_no_folds", "MAE", None),
        ("metrics_df_no_folds", "MSE", None),
    ],
)
def test_metric_per_segment_distribution_plot_ok(df_name, metric_name, per_fold_aggregation_mode, plot_type, request):
    """Smoke test: ``metric_per_segment_distribution_plot`` runs for every plot type.

    The test is named after the function it exercises rather than after
    ``plot_metric_per_segment``, so it does not reuse the naming of that
    function's tests, which live in the same module.
    """
    # Fixtures are selected by name so one test body covers several dataframes.
    metrics_df = request.getfixturevalue(df_name)
    metric_per_segment_distribution_plot(
        metrics_df=metrics_df,
        metric_name=metric_name,
        per_fold_aggregation_mode=per_fold_aggregation_mode,
        plot_type=plot_type,
    )


@pytest.mark.parametrize(
    "df_name, metric_name",
    [
        ("metrics_df_with_folds", "MAPE"),
        ("metrics_df_no_folds", "RMSE"),
    ],
)
def test_metric_per_segment_distribution_plot_warning_empty_segments(df_name, metric_name, request):
    """Check that ``metric_per_segment_distribution_plot`` warns about segments with all metric values missing.

    The test is named after the function it exercises rather than after
    ``plot_metric_per_segment``, so it does not reuse the naming of that
    function's tests, which live in the same module.
    """
    metrics_df = request.getfixturevalue(df_name)
    with pytest.warns(UserWarning, match="There are segments with all missing metric values"):
        metric_per_segment_distribution_plot(metrics_df=metrics_df, metric_name=metric_name)


@pytest.mark.parametrize(
    "df_name, metric_name",
    [
        ("metrics_df_with_folds", "MSE"),
    ],
)
def test_metric_per_segment_distribution_plot_warning_non_comparable_segments(df_name, metric_name, request):
    """Check that ``metric_per_segment_distribution_plot`` warns when segments have different sets of folds.

    The name is deliberately distinct from
    ``test_plot_metric_per_segment_warning_non_comparable_segments``, which
    already exists in this module for ``plot_metric_per_segment``. Reusing
    that name would rebind it at module level, so pytest would never collect
    the earlier test.
    """
    metrics_df = request.getfixturevalue(df_name)
    with pytest.warns(UserWarning, match="Some segments have different set of folds to be aggregated on"):
        metric_per_segment_distribution_plot(metrics_df=metrics_df, metric_name=metric_name)

0 comments on commit e73d138

Please sign in to comment.