diff --git a/metricsifter/sifter.py b/metricsifter/sifter.py index 03eeb21..d5ac198 100644 --- a/metricsifter/sifter.py +++ b/metricsifter/sifter.py @@ -41,6 +41,26 @@ def filter(x: pd.Series) -> bool: return X.loc[:, utils.parallel_apply(X, filter, n_jobs)] return X.loc[:, X.apply(filter)] + def run_upto_cpd(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.DataFrame: + """ Run up to change point detection""" + if without_simple_filter: + X = data + else: + # STEP0: simple filter + X = self._filter_no_changes(data, n_jobs=self.n_jobs) + + # STEP1: detect change points + _, _, metric_to_cps = detection.detect_multi_changepoints( + X, + search_method=self.search_method, + cost_model=self.cost_model, + penalty=self.penalty, + penalty_adjust=self.penalty_adjust, + n_jobs=self.n_jobs, + ) + remained_metrics = set(metric for metric, cps in metric_to_cps.items() if len(cps) > 0) + return X.loc[:, list(remained_metrics)] + def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.DataFrame: if without_simple_filter: X = data @@ -48,7 +68,6 @@ def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.Dat # STEP0: simple filter X = self._filter_no_changes(data, n_jobs=self.n_jobs) - metrics: list[str] = X.columns.tolist() # STEP1: detect change points flatten_change_points, cp_to_metrics, metric_to_cps = detection.detect_multi_changepoints( @@ -71,14 +90,12 @@ def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.Dat ) # STEP3: select the largest segment - remained_metrics = self.select_largest_segment(cluster_label_to_metrics, metrics, metric_to_cps) + remained_metrics = self.select_largest_segment(cluster_label_to_metrics, metric_to_cps) return X.loc[:, list(remained_metrics)] - def select_largest_segment( self, cluster_label_to_metrics: dict, - metrics: list[str], metric_to_cps: dict[str, list[int]], ) -> set[str]: if not cluster_label_to_metrics: diff --git a/tests/test_sifter.py b/tests/test_sifter.py index 3b5fa92..7135add 100644 --- a/tests/test_sifter.py +++ b/tests/test_sifter.py @@ -19,3 +19,9 @@ def test_sifter_run(synthetic_data_20): sifter = Sifter(n_jobs=1) siftered_data = sifter.run(data) assert siftered_data.shape[1] < data.shape[1], "The number of columns should be reduced." + +def test_sifter_run_upto_cpd(synthetic_data_20): + data = synthetic_data_20 + sifter = Sifter(n_jobs=1) + siftered_data = sifter.run_upto_cpd(data) + assert siftered_data.shape[1] < data.shape[1], "The number of columns should be reduced."