Skip to content

Commit

Permalink
FIX make PDCD_WS solver usable in GeneralizedLinearEstimator (#274)
Browse files Browse the repository at this point in the history
Co-authored-by: Badr-MOUFAD <[email protected]>
  • Loading branch information
mathurinm and Badr-MOUFAD authored Aug 1, 2024
1 parent 6ac303c commit 495333b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 6 additions & 2 deletions skglm/experimental/pdcd_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,17 @@ class PDCD_WS(BaseSolver):
_datafit_required_attr = ('prox_conjugate',)
_penalty_required_attr = ("prox_1d",)

def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
p0=100, tol=1e-6, verbose=False):
def __init__(
self, max_iter=1000, max_epochs=1000, dual_init=None, p0=100, tol=1e-6,
fit_intercept=False, warm_start=True, verbose=False
):
self.max_iter = max_iter
self.max_epochs = max_epochs
self.dual_init = dual_init
self.p0 = p0
self.tol = tol
self.fit_intercept = fit_intercept # TODO not handled
self.warm_start = warm_start # TODO not handled
self.verbose = verbose

def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
Expand Down
8 changes: 8 additions & 0 deletions skglm/experimental/tests/test_quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numpy.linalg import norm

from skglm.penalties import L1
from skglm import GeneralizedLinearEstimator
from skglm.experimental.pdcd_ws import PDCD_WS
from skglm.experimental.quantile_regression import Pinball
from skglm.utils.jit_compilation import compiled_clone
Expand Down Expand Up @@ -37,6 +38,13 @@ def test_PDCD_WS(quantile_level):
).fit(X, y)

np.testing.assert_allclose(w, clf.coef_, atol=1e-5)
# test compatibility when inside GLM:
estimator = GeneralizedLinearEstimator(
datafit=Pinball(.2),
penalty=L1(alpha=1.),
solver=PDCD_WS(),
)
estimator.fit(X, y)


if __name__ == '__main__':
Expand Down

0 comments on commit 495333b

Please sign in to comment.