Skip to content

Commit

Permalink
Add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Oct 6, 2024
1 parent adc910c commit 8127ca9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
6 changes: 5 additions & 1 deletion metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def copy_matrix(matrix: Matrix) -> Matrix:

def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
"""Subselect certain rows from a matrix."""
if isinstance(rows, pd.Series):
if isinstance(rows, pd.Series) or isinstance(rows, pl.Series):
rows = rows.to_numpy()
if isinstance(matrix, pd.DataFrame):
return matrix.iloc[rows]
if isinstance(matrix, pl.DataFrame):
return matrix.filter(pl.Series(rows))

Check warning on line 50 in metalearners/_utils.py

View check run for this annotation

Codecov / codecov/patch

metalearners/_utils.py#L50

Added line #L50 was not covered by tests
return matrix[rows, :]


Expand All @@ -55,6 +57,8 @@ def index_vector(vector: Vector, rows: Vector) -> Vector:
rows = rows.to_numpy()
if isinstance(vector, pd.Series):
return vector.iloc[rows]
if isinstance(vector, pl.Series):
return vector.filter(rows)
return vector[rows]


Expand Down
5 changes: 3 additions & 2 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ONNX_PROBABILITIES_OUTPUTS,
default_metric,
index_matrix,
index_vector,
safe_len,
validate_model_and_predict_method,
validate_number_positive,
Expand Down Expand Up @@ -343,9 +344,9 @@ def _validate_outcome(self, y: Vector, w: Vector) -> None:
f" Yet we found {len(np.unique(y))} classes."
)
if self.is_classification:
classes_0 = set(np.unique(y[w == 0]))
classes_0 = set(np.unique(index_vector(y, w == 0)))
for tv in range(self.n_variants):
if set(np.unique(y[w == tv])) != classes_0:
if set(np.unique(index_vector(y, w == tv))) != classes_0:
raise ValueError(
f"Variants 0 and {tv} have seen different sets of classification outcomes. Please check your data."
)
Expand Down
21 changes: 17 additions & 4 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pytest
from lightgbm import LGBMClassifier, LGBMRegressor
from scipy.sparse import csr_matrix
Expand Down Expand Up @@ -481,7 +482,7 @@ def test_combine_propensity_and_nuisance_specs(
),
],
)
@pytest.mark.parametrize("backend", ["np", "pd", "csr"])
@pytest.mark.parametrize("backend", ["np", "pd", "csr", "pl"])
def test_feature_set(feature_set, expected_n_features, backend, rng):
ml = _TestMetaLearner(
nuisance_model_factory=LGBMRegressor,
Expand All @@ -500,6 +501,10 @@ def test_feature_set(feature_set, expected_n_features, backend, rng):
X = pd.DataFrame(X)
y = pd.Series(y)
w = pd.Series(w)
elif backend == "pl":
X = pl.DataFrame(X)
y = pl.Series(y)
w = pl.Series(w)
elif backend == "csr":
X = csr_matrix(X)
ml.fit(X, y, w)
Expand Down Expand Up @@ -1081,15 +1086,19 @@ def test_n_jobs_base_learners(implementation, rng):
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
@pytest.mark.parametrize("backend", ["np", "pd", "csr"])
@pytest.mark.parametrize("backend", ["np", "pd", "csr", "pl"])
def test_validate_outcome_one_class(implementation, backend, rng):
X = rng.standard_normal((10, 2))
y = np.zeros(10)
w = rng.integers(0, 2, 10)
if backend == "pandas":
if backend == "pd":
X = pd.DataFrame(X)
y = pd.Series(y)
w = pd.Series(w)
elif backend == "pl":
X = pl.DataFrame(X)
y = pl.Series(y)
w = pl.Series(w)
elif backend == "csr":
X = csr_matrix(X)

Expand All @@ -1111,7 +1120,7 @@ def test_validate_outcome_one_class(implementation, backend, rng):
"implementation",
[TLearner, SLearner, XLearner, RLearner, DRLearner],
)
@pytest.mark.parametrize("backend", ["np", "pd", "csr"])
@pytest.mark.parametrize("backend", ["np", "pd", "csr", "pl"])
def test_validate_outcome_different_classes(implementation, backend, rng):
X = rng.standard_normal((4, 2))
y = np.array([0, 1, 0, 0])
Expand All @@ -1120,6 +1129,10 @@ def test_validate_outcome_different_classes(implementation, backend, rng):
X = pd.DataFrame(X)
y = pd.Series(y)
w = pd.Series(w)
elif backend == "pl":
X = pl.DataFrame(X)
y = pl.Series(y)
w = pl.Series(w)
elif backend == "csr":
X = csr_matrix(X)

Expand Down

0 comments on commit 8127ca9

Please sign in to comment.