Skip to content

Commit

Permalink
Add handling sklearn version in OneHotEncoderTransform and `Hierarc…
Browse files Browse the repository at this point in the history
…hicalClustering` (#529)
  • Loading branch information
d-a-bunin authored Dec 18, 2024
1 parent 064e645 commit f527fdb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix working with NaN target in `MeanEncoderTransform` ([#492](https://github.com/etna-team/etna/pull/492))
- Fix `target` leakage in `MeanSegmentEncoderTransform` ([#503](https://github.com/etna-team/etna/pull/503))
-
-
- Add handling scikit-learn version >= 1.4 in `OneHotEncoderTransform` and `HierarchicalClustering` ([#529](https://github.com/etna-team/etna/pull/529))
-
-
-
Expand Down
10 changes: 9 additions & 1 deletion etna/clustering/hierarchical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Union

import pandas as pd
from sklearn import __version__ as sklearn_version
from sklearn.cluster import AgglomerativeClustering

from etna.clustering.base import Clustering
Expand Down Expand Up @@ -81,9 +82,16 @@ def build_clustering_algo(
"""
self.n_clusters = n_clusters
self.linkage = ClusteringLinkageMode(linkage).name

sklearn_version_tuple = tuple(map(int, sklearn_version.split(".")))
if sklearn_version_tuple < (1, 2):
clustering_algo_params["affinity"] = "precomputed"
else:
clustering_algo_params["metric"] = "precomputed"
self.clustering_algo = AgglomerativeClustering(
n_clusters=self.n_clusters, affinity="precomputed", linkage=self.linkage, **clustering_algo_params
n_clusters=self.n_clusters, linkage=self.linkage, **clustering_algo_params
)

self.clusters = None
self.segment2cluster = None
self.centroids_df = None
Expand Down
11 changes: 10 additions & 1 deletion etna/transforms/encoders/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
from sklearn import __version__ as sklearn_version
from sklearn import preprocessing
from sklearn.utils._encode import _check_unknown
from sklearn.utils._encode import _encode
Expand Down Expand Up @@ -215,7 +216,15 @@ def __init__(self, in_column: str, out_column: Optional[str] = None, return_type
self.in_column = in_column
self.out_column = out_column
self.return_type = ReturnType(return_type)
self.ohe = preprocessing.OneHotEncoder(handle_unknown="ignore", sparse=False, dtype=int)

sklearn_version_tuple = tuple(map(int, sklearn_version.split(".")))
encoder_params = {}
if sklearn_version_tuple < (1, 2):
encoder_params["sparse"] = False
else:
encoder_params["sparse_output"] = False
self.ohe = preprocessing.OneHotEncoder(handle_unknown="ignore", dtype=int, **encoder_params)

self.in_column_regressor: Optional[bool] = None

def get_regressors_info(self) -> List[str]:
Expand Down

0 comments on commit f527fdb

Please sign in to comment.