Skip to content

Commit

Permalink
fix: fix using numpy.warnings in sklearn.preprocessing.PowerTransform…
Browse files Browse the repository at this point in the history
…._check_input
  • Loading branch information
d-a-bunin committed Oct 18, 2023
1 parent 71d1ae5 commit efc5d9e
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions etna/transforms/math/power.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import functools
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from unittest.mock import patch

import pandas as pd
from sklearn.preprocessing import PowerTransformer

from etna.distributions import BaseDistribution
Expand All @@ -11,6 +14,20 @@
from etna.transforms.math.sklearn import TransformMode


def _replace_warnings_decorator(func):
@functools.wraps(func)
def inner(*args, **kwargs):
import warnings

import numpy as np

np.warnings = warnings

return func(*args, **kwargs)

return inner


class YeoJohnsonTransform(SklearnTransform):
"""YeoJohnsonTransform applies Yeo-Johns transformation to a DataFrame.
Expand Down Expand Up @@ -61,6 +78,36 @@ def __init__(
mode=mode,
)

def _fit(self, df: pd.DataFrame) -> "SklearnTransform":
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._fit(df=df)

def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._transform(df=df)

def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._inverse_transform(df=df)

def params_to_tune(self) -> Dict[str, BaseDistribution]:
"""Get default grid for tuning hyperparameters.
Expand Down Expand Up @@ -130,6 +177,36 @@ def __init__(
mode=mode,
)

def _fit(self, df: pd.DataFrame) -> "SklearnTransform":
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._fit(df=df)

def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._transform(df=df)

def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit transformer with data from df.
Here we are going to make a `patch <https://github.com/scikit-learn/scikit-learn/pull/23654>`_.
"""
new_check_input = _replace_warnings_decorator(PowerTransformer._check_input)
with patch("sklearn.preprocessing.PowerTransformer._check_input", new=new_check_input):
return super()._inverse_transform(df=df)

def params_to_tune(self) -> Dict[str, BaseDistribution]:
"""Get default grid for tuning hyperparameters.
Expand Down

0 comments on commit efc5d9e

Please sign in to comment.