Skip to content

Commit

Permalink
Add MissingCounter metric (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-a-bunin authored Dec 11, 2024
1 parent db5257f commit 82c1be2
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 24 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
- Add `MissingCounter` metric ([#520](https://github.com/etna-team/etna/pull/520))
-
-
-
-
-
-
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Scalar metrics:
MaxDeviation
MedAE
Sign
MissingCounter

Interval metrics:

Expand Down
1 change: 1 addition & 0 deletions etna/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
from etna.metrics.metrics import WAPE
from etna.metrics.metrics import MaxDeviation
from etna.metrics.metrics import MedAE
from etna.metrics.metrics import MissingCounter
from etna.metrics.metrics import Sign
from etna.metrics.utils import compute_metrics
53 changes: 52 additions & 1 deletion etna/metrics/functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,55 @@ def wape(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> Ar
return np.sum(np.abs(y_true_array - y_pred_array), axis=axis) / np.sum(np.abs(y_true_array), axis=axis) # type: ignore


__all__ = ["mae", "mse", "msle", "medae", "r2_score", "mape", "smape", "sign", "max_deviation", "rmse", "wape"]
def count_missing_values(y_true: ArrayLike, y_pred: ArrayLike, multioutput: str = "joint") -> ArrayLike:
"""Count missing values in ``y_true``.
.. math::
MissingCounter(y\_true, y\_pred) = \\sum_{i=1}^{n}{isnan(y\_true_i)}
Parameters
----------
y_true:
array-like of shape (n_samples,) or (n_samples, n_outputs)
Ground truth (correct) target values.
y_pred:
array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.
multioutput:
Defines aggregating of multiple output values
(see :py:class:`~etna.metrics.functional_metrics.FunctionalMetricMultioutput`).
Returns
-------
:
A floating point value, or an array of floating point values,
one for each individual target.
"""
y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)

if len(y_true_array.shape) != len(y_pred_array.shape):
raise ValueError("Shapes of the labels must be the same")

axis = _get_axis_by_multioutput(multioutput)

return np.sum(np.isnan(y_true), axis=axis).astype(float)


__all__ = [
"mae",
"mse",
"msle",
"medae",
"r2_score",
"mape",
"smape",
"sign",
"max_deviation",
"rmse",
"wape",
"count_missing_values",
]
58 changes: 57 additions & 1 deletion etna/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from etna.metrics.base import Metric
from etna.metrics.base import MetricWithMissingHandling
from etna.metrics.functional_metrics import count_missing_values
from etna.metrics.functional_metrics import mae
from etna.metrics.functional_metrics import mape
from etna.metrics.functional_metrics import max_deviation
Expand Down Expand Up @@ -417,4 +418,59 @@ def greater_is_better(self) -> bool:
return False


__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign", "MaxDeviation", "WAPE"]
class MissingCounter(MetricWithMissingHandling):
"""Missing values counter with multi-segment computation support.
.. math::
MissingCounter(y\_true, y\_pred) = \\sum_{i=1}^{n}{isnan(y\_true_i)}
Notes
-----
You can read more about logic of multi-segment metrics in Metric docs.
"""

def __init__(self, mode: str = "per-segment", **kwargs):
"""Init metric.
Parameters
----------
mode:
"macro" or "per-segment", way to aggregate metric values over segments:
* if "macro" computes average value
* if "per-segment" -- does not aggregate metrics
See :py:class:`~etna.metrics.base.MetricAggregationMode`.
kwargs:
metric's computation arguments
"""
count_missing_values_per_output = partial(count_missing_values, multioutput="raw_values")
super().__init__(
mode=mode,
metric_fn=count_missing_values_per_output,
metric_fn_signature="matrix_to_array",
missing_mode="ignore",
**kwargs,
)

@property
def greater_is_better(self) -> None:
"""Whether higher metric value is better."""
return None


__all__ = [
"MAE",
"MSE",
"RMSE",
"R2",
"MSLE",
"MAPE",
"SMAPE",
"MedAE",
"Sign",
"MaxDeviation",
"WAPE",
"MissingCounter",
]
70 changes: 69 additions & 1 deletion tests/test_metrics/test_functional_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from etna.metrics import sign
from etna.metrics import smape
from etna.metrics import wape
from etna.metrics.functional_metrics import count_missing_values


@pytest.fixture()
Expand Down Expand Up @@ -43,6 +44,7 @@ def y_pred_1d():
(sign, -1),
(max_deviation, 2),
(wape, 1 / 2),
(count_missing_values, 0),
),
)
def test_all_1d_metrics(metric, right_metrics_value, y_true_1d, y_pred_1d):
Expand All @@ -65,6 +67,7 @@ def test_mle_metric_exception(y_true_1d, y_pred_1d):
sign,
max_deviation,
wape,
count_missing_values,
),
)
def test_all_wrong_mode(metric, y_true_1d, y_pred_1d):
Expand Down Expand Up @@ -95,6 +98,7 @@ def y_pred_2d():
(sign, 0),
(max_deviation, 2),
(wape, 1 / 6),
(count_missing_values, 0),
),
)
def test_all_2d_metrics_joint(metric, right_metrics_value, y_true_2d, y_pred_2d):
Expand All @@ -114,6 +118,7 @@ def test_all_2d_metrics_joint(metric, right_metrics_value, y_true_2d, y_pred_2d)
(sign, {"multioutput": "raw_values"}, [0, 0]),
(max_deviation, {"multioutput": "raw_values"}, [1, 1]),
(wape, {"multioutput": "raw_values"}, [0.0952381, 2 / 3]),
(count_missing_values, {"multioutput": "raw_values"}, [0, 0]),
),
)
def test_all_2d_metrics_per_output(metric, params, right_metrics_value, y_true_2d, y_pred_2d):
Expand Down Expand Up @@ -177,6 +182,69 @@ def test_all_2d_metrics_per_output(metric, params, right_metrics_value, y_true_2
),
],
)
def test_values_ok(y_true, y_pred, multioutput, expected):
def test_mse_ok(y_true, y_pred, multioutput, expected):
result = mse(y_true=y_true, y_pred=y_pred, multioutput=multioutput)
npt.assert_allclose(result, expected)


@pytest.mark.parametrize(
"y_true, y_pred, multioutput, expected",
[
# 1d
(np.array([1.0]), np.array([1.0]), "joint", 0.0),
(np.array([1.0, 2.0, 3.0]), np.array([3.0, 1.0, 2.0]), "joint", 0.0),
(np.array([1.0, np.NaN, 3.0]), np.array([3.0, 1.0, 2.0]), "joint", 1.0),
(np.array([1.0, 2.0, 3.0]), np.array([3.0, np.NaN, 2.0]), "joint", 0.0),
(np.array([1.0, np.NaN, 3.0]), np.array([3.0, np.NaN, 2.0]), "joint", 1.0),
(np.array([1.0, np.NaN, 3.0]), np.array([3.0, 1.0, np.NaN]), "joint", 1.0),
(np.array([1.0, np.NaN, np.NaN]), np.array([np.NaN, np.NaN, 2.0]), "joint", 2.0),
(np.array([np.NaN, np.NaN, np.NaN]), np.array([3.0, 1.0, 2.0]), "joint", 3.0),
# 2d
(np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]).T, np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T, "joint", 0.0),
(
np.array([[1.0, np.NaN, 3.0], [3.0, 4.0, np.NaN]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"joint",
2.0,
),
(
np.array([[np.NaN, np.NaN, np.NaN], [3.0, 4.0, 5.0]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"joint",
3.0,
),
(
np.array([[np.NaN, np.NaN, np.NaN], [np.NaN, np.NaN, np.NaN]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"joint",
6.0,
),
(
np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"raw_values",
np.array([0.0, 0.0]),
),
(
np.array([[1.0, np.NaN, 3.0], [3.0, 4.0, np.NaN]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"raw_values",
np.array([1.0, 1.0]),
),
(
np.array([[np.NaN, np.NaN, np.NaN], [3.0, 4.0, 5.0]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"raw_values",
np.array([3.0, 0.0]),
),
(
np.array([[np.NaN, np.NaN, np.NaN], [np.NaN, np.NaN, np.NaN]]).T,
np.array([[3.0, 1.0, 2.0], [5.0, 2.0, 4.0]]).T,
"raw_values",
np.array([3.0, 3.0]),
),
],
)
def test_count_missing_values_ok(y_true, y_pred, multioutput, expected):
result = count_missing_values(y_true=y_true, y_pred=y_pred, multioutput=multioutput)
npt.assert_allclose(result, expected)
Loading

0 comments on commit 82c1be2

Please sign in to comment.