From 495333b47b0781b9a564943a1aa4eaf578d42218 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 1 Aug 2024 16:44:23 +0200 Subject: [PATCH] FIX make PDCD_WS solver usable in GeneralizedLinearEstimator (#274) Co-authored-by: Badr-MOUFAD --- skglm/experimental/pdcd_ws.py | 8 ++++++-- skglm/experimental/tests/test_quantile_regression.py | 8 ++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py index 81e72da8c..5ef49e5d4 100644 --- a/skglm/experimental/pdcd_ws.py +++ b/skglm/experimental/pdcd_ws.py @@ -82,13 +82,17 @@ class PDCD_WS(BaseSolver): _datafit_required_attr = ('prox_conjugate',) _penalty_required_attr = ("prox_1d",) - def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, - p0=100, tol=1e-6, verbose=False): + def __init__( + self, max_iter=1000, max_epochs=1000, dual_init=None, p0=100, tol=1e-6, + fit_intercept=False, warm_start=True, verbose=False + ): self.max_iter = max_iter self.max_epochs = max_epochs self.dual_init = dual_init self.p0 = p0 self.tol = tol + self.fit_intercept = fit_intercept # TODO not handled + self.warm_start = warm_start # TODO not handled self.verbose = verbose def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): diff --git a/skglm/experimental/tests/test_quantile_regression.py b/skglm/experimental/tests/test_quantile_regression.py index 65e0c1e65..f4d1aa914 100644 --- a/skglm/experimental/tests/test_quantile_regression.py +++ b/skglm/experimental/tests/test_quantile_regression.py @@ -3,6 +3,7 @@ from numpy.linalg import norm from skglm.penalties import L1 +from skglm import GeneralizedLinearEstimator from skglm.experimental.pdcd_ws import PDCD_WS from skglm.experimental.quantile_regression import Pinball from skglm.utils.jit_compilation import compiled_clone @@ -37,6 +38,13 @@ def test_PDCD_WS(quantile_level): ).fit(X, y) np.testing.assert_allclose(w, clf.coef_, atol=1e-5) + # test compatibility when inside GLM: + estimator = GeneralizedLinearEstimator( + datafit=Pinball(.2), + penalty=L1(alpha=1.), + solver=PDCD_WS(), + ) + estimator.fit(X, y) if __name__ == '__main__':