Skip to content

Commit

Permalink
Update interval metrics to work with arbitrary interval bounds (#113)
Browse files Browse the repository at this point in the history
* use prediction intervals from dataset

* added tests

* updated changelog

* removed unused fixture

* reworked default behavior

* updated tests

* review fix
  • Loading branch information
brsnw250 authored Oct 17, 2023
1 parent 71d1ae5 commit f9e1f11
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add config for Codecov to control CI ([#80](https://github.com/etna-team/etna/pull/80))
- Add `EventTransform` ([#78](https://github.com/etna-team/etna/pull/78))
- `NaiveVariancePredictionIntervals` method for prediction quantiles estimation ([#109](https://github.com/etna-team/etna/pull/109))
- Update interval metrics to work with arbitrary interval bounds ([#113](https://github.com/etna-team/etna/pull/113))

### Changed
-
Expand Down
135 changes: 116 additions & 19 deletions etna/metrics/intervals_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd

from etna.datasets import TSDataset
from etna.metrics.base import Metric
Expand All @@ -15,15 +17,30 @@ def dummy(y_true: ArrayLike, y_pred: ArrayLike) -> ArrayLike:
return np.nan


class _QuantileMetricMixin:
def _validate_tsdataset_quantiles(self, ts: TSDataset, quantiles: Sequence[float]) -> None:
"""Check if quantiles presented in y_pred."""
features = set(ts.df.columns.get_level_values("feature"))
for quantile in quantiles:
assert f"target_{quantile:.4g}" in features, f"Quantile {quantile} is not presented in tsdataset."
class _IntervalsMetricMixin:
def _validate_tsdataset_intervals(
self, ts: TSDataset, quantiles: Sequence[float], upper_name: Optional[str], lower_name: Optional[str]
) -> None:
"""Check that the interval borders are present in ``y_pred``."""
ts_intervals = set(ts.prediction_intervals_names)

borders_set = {upper_name, lower_name}
borders_presented = borders_set.issubset(ts_intervals)

class Coverage(Metric, _QuantileMetricMixin):
quantiles_set = {f"target_{quantile:.4g}" for quantile in quantiles}
quantiles_presented = quantiles_set.issubset(ts_intervals)
quantiles_presented &= len(quantiles_set) > 0

if upper_name is not None and lower_name is not None:
if not borders_presented:
raise ValueError("Provided intervals borders names must be in dataset!")

else:
if not quantiles_presented:
raise ValueError("All quantiles must be presented in the dataset!")


class Coverage(Metric, _IntervalsMetricMixin):
"""Coverage metric for prediction intervals - percentage of samples in the interval ``[lower quantile, upper quantile]``.
.. math::
Expand All @@ -32,10 +49,17 @@ class Coverage(Metric, _QuantileMetricMixin):
Notes
-----
Works only if ``quantiles`` are present in ``y_pred``.
When ``quantiles``, ``upper_name`` and ``lower_name`` are all set to ``None``, the 0.025 and 0.975 quantiles are used.
"""

def __init__(
self, quantiles: Tuple[float, float] = (0.025, 0.975), mode: str = MetricAggregationMode.per_segment, **kwargs
self,
quantiles: Optional[Tuple[float, float]] = None,
mode: str = MetricAggregationMode.per_segment,
upper_name: Optional[str] = None,
lower_name: Optional[str] = None,
**kwargs,
):
"""Init metric.
Expand All @@ -45,11 +69,32 @@ def __init__(
lower and upper quantiles
mode: 'macro' or 'per-segment'
metrics aggregation mode
upper_name:
name of column with upper border of the interval
lower_name:
name of column with lower border of the interval
kwargs:
metric's computation arguments
"""
if (lower_name is None) ^ (upper_name is None):
raise ValueError("Both `lower_name` and `upper_name` must be set if using names to specify borders!")

if not (quantiles is None or lower_name is None):
raise ValueError(
"Both `quantiles` and border names are specified. Use only one way to set interval borders!"
)

if quantiles is not None and len(quantiles) != 2:
raise ValueError(f"Expected tuple with two values for `quantiles` parameter, got {len(quantiles)}")

# default behavior
if quantiles is None and lower_name is None:
quantiles = (0.025, 0.975)

super().__init__(mode=mode, metric_fn=dummy, **kwargs)
self.quantiles = quantiles
self.quantiles = sorted(quantiles if quantiles is not None else tuple())
self.upper_name = upper_name
self.lower_name = lower_name

def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[str, float]]:
"""
Expand All @@ -74,11 +119,23 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)
self._validate_tsdataset_intervals(
ts=y_pred, quantiles=self.quantiles, lower_name=self.lower_name, upper_name=self.upper_name
)

if self.upper_name is not None:
lower_border = self.lower_name
upper_border = self.upper_name

else:
lower_border = f"target_{self.quantiles[0]:.4g}"
upper_border = f"target_{self.quantiles[1]:.4g}"

df_true = y_true[:, :, "target"].sort_index(axis=1)
df_pred_lower = y_pred[:, :, f"target_{self.quantiles[0]:.4g}"].sort_index(axis=1)
df_pred_upper = y_pred[:, :, f"target_{self.quantiles[1]:.4g}"].sort_index(axis=1)

intervals_df: pd.DataFrame = y_pred.get_prediction_intervals()
df_pred_lower = intervals_df.loc[:, pd.IndexSlice[:, lower_border]].sort_index(axis=1)
df_pred_upper = intervals_df.loc[:, pd.IndexSlice[:, upper_border]].sort_index(axis=1)

segments = df_true.columns.get_level_values("segment").unique()

Expand All @@ -96,19 +153,26 @@ def greater_is_better(self) -> None:
return None


class Width(Metric, _QuantileMetricMixin):
class Width(Metric, _IntervalsMetricMixin):
"""Mean width of prediction intervals.
.. math::
Width(y\_true, y\_pred) = \\frac{\\sum_{i=0}^{n-1}\\mid y\_pred_i^{upper\_quantile} - y\_pred_i^{lower\_quantile} \\mid}{n}
Notes
-----
Works just if quantiles presented in ``y_pred``
Works only if quantiles are present in ``y_pred``.
When ``quantiles``, ``upper_name`` and ``lower_name`` are all set to ``None``, the 0.025 and 0.975 quantiles are used.
"""

def __init__(
self, quantiles: Tuple[float, float] = (0.025, 0.975), mode: str = MetricAggregationMode.per_segment, **kwargs
self,
quantiles: Optional[Tuple[float, float]] = None,
mode: str = MetricAggregationMode.per_segment,
upper_name: Optional[str] = None,
lower_name: Optional[str] = None,
**kwargs,
):
"""Init metric.
Expand All @@ -118,11 +182,32 @@ def __init__(
lower and upper quantiles
mode: 'macro' or 'per-segment'
metrics aggregation mode
upper_name:
name of column with upper border of the interval
lower_name:
name of column with lower border of the interval
kwargs:
metric's computation arguments
"""
if (lower_name is None) ^ (upper_name is None):
raise ValueError("Both `lower_name` and `upper_name` must be set if using names to specify borders!")

if not (quantiles is None or lower_name is None):
raise ValueError(
"Both `quantiles` and border names are specified. Use only one way to set interval borders!"
)

if quantiles is not None and len(quantiles) != 2:
raise ValueError(f"Expected tuple with two values for `quantiles` parameter, got {len(quantiles)}")

# default behavior
if quantiles is None and lower_name is None:
quantiles = (0.025, 0.975)

super().__init__(mode=mode, metric_fn=dummy, **kwargs)
self.quantiles = quantiles
self.quantiles = sorted(quantiles if quantiles is not None else tuple())
self.upper_name = upper_name
self.lower_name = lower_name

def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[str, float]]:
"""
Expand All @@ -147,11 +232,23 @@ def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[st
self._validate_target_columns(y_true=y_true, y_pred=y_pred)
self._validate_index(y_true=y_true, y_pred=y_pred)
self._validate_nans(y_true=y_true, y_pred=y_pred)
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles)
self._validate_tsdataset_intervals(
ts=y_pred, quantiles=self.quantiles, lower_name=self.lower_name, upper_name=self.upper_name
)

if self.upper_name is not None:
lower_border = self.lower_name
upper_border = self.upper_name

else:
lower_border = f"target_{self.quantiles[0]:.4g}"
upper_border = f"target_{self.quantiles[1]:.4g}"

df_true = y_true[:, :, "target"].sort_index(axis=1)
df_pred_lower = y_pred[:, :, f"target_{self.quantiles[0]:.4g}"].sort_index(axis=1)
df_pred_upper = y_pred[:, :, f"target_{self.quantiles[1]:.4g}"].sort_index(axis=1)

intervals_df: pd.DataFrame = y_pred.get_prediction_intervals()
df_pred_lower = intervals_df.loc[:, pd.IndexSlice[:, lower_border]].sort_index(axis=1)
df_pred_upper = intervals_df.loc[:, pd.IndexSlice[:, upper_border]].sort_index(axis=1)

segments = df_true.columns.get_level_values("segment").unique()

Expand Down
Loading

0 comments on commit f9e1f11

Please sign in to comment.