Skip to content

Commit

Permalink
add run_upto_cpd method
Browse files Browse the repository at this point in the history
  • Loading branch information
yuuki committed Dec 3, 2023
1 parent ebf03c1 commit 2669290
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 21 additions & 4 deletions metricsifter/sifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,33 @@ def filter(x: pd.Series) -> bool:
return X.loc[:, utils.parallel_apply(X, filter, n_jobs)]
return X.loc[:, X.apply(filter)]

def run_upto_cpd(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.DataFrame:
""" Run up to change point detection"""
if without_simple_filter:
X = data
else:
# STEP0: simple filter
X = self._filter_no_changes(data, n_jobs=self.n_jobs)

# STEP1: detect change points
_, _, metric_to_cps = detection.detect_multi_changepoints(
X,
search_method=self.search_method,
cost_model=self.cost_model,
penalty=self.penalty,
penalty_adjust=self.penalty_adjust,
n_jobs=self.n_jobs,
)
remained_metrics = set(metric for metric, cps in metric_to_cps.items() if len(cps) > 0)
return X.loc[:, list(remained_metrics)]

def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.DataFrame:
if without_simple_filter:
X = data
else:
# STEP0: simple filter
X = self._filter_no_changes(data, n_jobs=self.n_jobs)

metrics: list[str] = X.columns.tolist()

# STEP1: detect change points
flatten_change_points, cp_to_metrics, metric_to_cps = detection.detect_multi_changepoints(
Expand All @@ -71,14 +90,12 @@ def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.Dat
)

# STEP3: select the largest segment
remained_metrics = self.select_largest_segment(cluster_label_to_metrics, metrics, metric_to_cps)
remained_metrics = self.select_largest_segment(cluster_label_to_metrics, metric_to_cps)
return X.loc[:, list(remained_metrics)]


def select_largest_segment(
self,
cluster_label_to_metrics: dict,
metrics: list[str],
metric_to_cps: dict[str, list[int]],
) -> set[str]:
if not cluster_label_to_metrics:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_sifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ def test_sifter_run(synthetic_data_20):
sifter = Sifter(n_jobs=1)
siftered_data = sifter.run(data)
assert siftered_data.shape[1] < data.shape[1], "The number of columns should be reduced."

def test_sifter_run_upto_cpd(synthetic_data_20):
data = synthetic_data_20
sifter = Sifter(n_jobs=1)
siftered_data = sifter.run_upto_cpd(data)
assert siftered_data.shape[1] < data.shape[1], "The number of columns should be reduced."

0 comments on commit 2669290

Please sign in to comment.