Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix analysis.forecast.plots.metric_per_segment_distribution_plot to handle None from metrics #543

Merged
merged 4 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add parameter `missing_mode` into `R2` and `MedAE` metrics ([#537](https://github.com/etna-team/etna/pull/537))
- Update `analysis.forecast.plots.plot_metric_per_segment` to handle `None` from metrics ([#540](https://github.com/etna-team/etna/pull/540))
- Add parameter `missing_mode` into `RMSE` and `MSLE` metrics ([#542](https://github.com/etna-team/etna/pull/542))
-
- Update `analysis.forecast.plots.metric_per_segment_distribution_plot` to handle `None` from metrics ([#543](https://github.com/etna-team/etna/pull/543))
-
-

Expand Down
19 changes: 17 additions & 2 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def plot_metric_per_segment(
Warnings
--------
UserWarning:
There are segments without non-missing metric values.
There are segments with all missing metric values.
UserWarning:
Some segments have different set of folds to be aggregated on due to missing values.
"""
Expand Down Expand Up @@ -803,7 +803,12 @@ def metric_per_segment_distribution_plot(
seaborn_params: Optional[Dict[str, Any]] = None,
figsize: Tuple[int, int] = (10, 5),
):
"""Plot per-segment metrics distribution.
"""Plot distribution of metric values over all segments.

If for some segment all metric values are missing, it isn't plotted, and the warning is raised.

If some segments have different set of folds with non-missing metrics,
it can lead to incompatible values between folds. The warning is raised in such case.

Parameters
----------
Expand Down Expand Up @@ -831,6 +836,13 @@ def metric_per_segment_distribution_plot(
if ``metric_name`` isn't present in ``metrics_df``
NotImplementedError:
unknown ``per_fold_aggregation_mode`` is given

Warnings
--------
UserWarning:
There are segments with all missing metric values.
UserWarning:
Some segments have different set of folds to be aggregated on due to missing values.
"""
if seaborn_params is None:
seaborn_params = {}
Expand All @@ -844,6 +856,9 @@ def metric_per_segment_distribution_plot(
if metric_name not in metrics_df.columns:
raise ValueError("Given metric_name isn't present in metrics_df")

_check_metrics_df_empty_segments(metrics_df=metrics_df, metric_name=metric_name)
_check_metrics_df_same_folds_for_each_segment(metrics_df=metrics_df, metric_name=metric_name)

# draw plot for each fold
if per_fold_aggregation_mode is None and "fold_number" in metrics_df.columns:
if plot_type_enum == MetricPlotType.hist:
Expand Down
2 changes: 1 addition & 1 deletion etna/analysis/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ def _check_metrics_df_same_folds_for_each_segment(metrics_df: pd.DataFrame, metr
df = metrics_df[["segment", "fold_number", metric_name]]
# we don't take into account segments without any non-missing metrics, they are handled by other check
df = df.dropna(subset=[metric_name])
num_unique = df.groupby("segment")["fold_number"].apply(frozenset).nunique()
num_unique = df.groupby("segment", group_keys=False)["fold_number"].apply(frozenset).nunique()
if num_unique > 1:
warnings.warn("Some segments have different set of folds to be aggregated on due to missing values.")
49 changes: 48 additions & 1 deletion tests/test_analysis/test_forecast/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pytest

from etna.analysis import metric_per_segment_distribution_plot
from etna.analysis import plot_metric_per_segment
from etna.analysis import plot_residuals
from etna.analysis.forecast.plots import _get_borders_comparator
Expand Down Expand Up @@ -80,7 +81,7 @@ def metrics_df_no_folds(metrics_df_with_folds) -> pd.DataFrame:
"df_name, metric_name",
[
("metrics_df_with_folds", "MAE"),
("metrics_df_no_folds", "MSE"),
("metrics_df_no_folds", "MAE"),
("metrics_df_no_folds", "MSE"),
],
)
Expand Down Expand Up @@ -112,3 +113,49 @@ def test_plot_metric_per_segment_warning_non_comparable_segments(df_name, metric
metrics_df = request.getfixturevalue(df_name)
with pytest.warns(UserWarning, match="Some segments have different set of folds to be aggregated on"):
plot_metric_per_segment(metrics_df=metrics_df, metric_name=metric_name)


@pytest.mark.parametrize("plot_type", ["hist", "box", "violin"])
@pytest.mark.parametrize(
"df_name, metric_name, per_fold_aggregation_mode",
[
("metrics_df_with_folds", "MAE", None),
("metrics_df_with_folds", "MAE", "mean"),
("metrics_df_with_folds", "MAE", "median"),
("metrics_df_no_folds", "MAE", None),
("metrics_df_no_folds", "MSE", None),
],
)
def test_plot_metric_per_segment_ok(df_name, metric_name, per_fold_aggregation_mode, plot_type, request):
metrics_df = request.getfixturevalue(df_name)
metric_per_segment_distribution_plot(
metrics_df=metrics_df,
metric_name=metric_name,
per_fold_aggregation_mode=per_fold_aggregation_mode,
plot_type=plot_type,
)


@pytest.mark.parametrize(
"df_name, metric_name",
[
("metrics_df_with_folds", "MAPE"),
("metrics_df_no_folds", "RMSE"),
],
)
def test_plot_metric_per_segment_warning_empty_segments(df_name, metric_name, request):
metrics_df = request.getfixturevalue(df_name)
with pytest.warns(UserWarning, match="There are segments with all missing metric values"):
metric_per_segment_distribution_plot(metrics_df=metrics_df, metric_name=metric_name)


@pytest.mark.parametrize(
"df_name, metric_name",
[
("metrics_df_with_folds", "MSE"),
],
)
def test_plot_metric_per_segment_warning_non_comparable_segments(df_name, metric_name, request):
metrics_df = request.getfixturevalue(df_name)
with pytest.warns(UserWarning, match="Some segments have different set of folds to be aggregated on"):
metric_per_segment_distribution_plot(metrics_df=metrics_df, metric_name=metric_name)
Loading