From f2ac94e28fa2a9dbb1d03bbb8979536152186c39 Mon Sep 17 00:00:00 2001
From: mathurinm
Date: Thu, 7 Nov 2024 17:05:49 +0100
Subject: [PATCH] add UT + rename to QuadraticHessian + rm estimator class

---
 doc/api.rst                   |  3 +-
 skglm/datafits/__init__.py    |  4 +--
 skglm/datafits/single_task.py | 19 ++++++-----
 skglm/estimators.py           | 59 +++++------------------------------
 skglm/tests/test_datafits.py  | 20 +++++++++++-
 5 files changed, 40 insertions(+), 65 deletions(-)

diff --git a/doc/api.rst b/doc/api.rst
index 6781ac9df..5ad56081b 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -68,6 +68,7 @@ Datafits
    Poisson
    Quadratic
    QuadraticGroup
+   QuadraticHessian
    QuadraticSVC
    WeightedQuadratic
 
@@ -102,4 +103,4 @@ Experimental
    PDCD_WS
    Pinball
    SqrtQuadratic
-   SqrtLasso
\ No newline at end of file
+   SqrtLasso
diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py
index 4a728ce15..74e0e5d75 100644
--- a/skglm/datafits/__init__.py
+++ b/skglm/datafits/__init__.py
@@ -1,6 +1,6 @@
 from .base import BaseDatafit, BaseMultitaskDatafit
 from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
-                          Cox, WeightedQuadratic, HessianQuadratic,)
+                          Cox, WeightedQuadratic, QuadraticHessian,)
 from .multi_task import QuadraticMultiTask
 from .group import QuadraticGroup, LogisticGroup
 
@@ -10,5 +10,5 @@
     Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
     QuadraticMultiTask,
     QuadraticGroup, LogisticGroup, WeightedQuadratic,
-    HessianQuadratic
+    QuadraticHessian
 ]
diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py
index 0a6917b7a..4d413d5e0 100644
--- a/skglm/datafits/single_task.py
+++ b/skglm/datafits/single_task.py
@@ -240,15 +240,21 @@ def intercept_update_step(self, y, Xw):
         return np.sum(self.sample_weights * (Xw - y)) / self.sample_weights.sum()
 
 
-class HessianQuadratic(BaseDatafit):
-    r"""_summary_
+class QuadraticHessian(BaseDatafit):
+    r"""Quadratic datafit where we pass the Hessian A directly.
 
     The datafit reads:
 
-    .. math:: 1 / 2 x^\\top A x + \\langle b, x \\rangle
-
-    for A symmetric
+    .. math:: 1 / 2 x^\top A x + \langle b, x \rangle
+
+    for a symmetric A. Up to a constant, it is the same as a Quadratic datafit
+    with :math:`A = 1 / n_\text{samples} X^\top X` and
+    :math:`b = - 1 / n_\text{samples} X^\top y`. When the Hessian is already
+    available, this datafit is more efficient than using Quadratic.
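+
+    When used with :class:`~skglm.GeneralizedLinearEstimator`, ``fit`` takes
+    ``(A, b)`` in place of ``(X, y)`` (see the new unit test below).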
""" def __init__(self): @@ -264,7 +265,7 @@ def get_lipschitz(self, A, b): n_features = A.shape[0] lipschitz = np.zeros(n_features, dtype=A.dtype) for j in range(n_features): - lipschitz[j] = np.sqrt((A[:, j]**2).sum()) + lipschitz[j] = A[j, j] return lipschitz def gradient_scalar(self, A, b, w, Ax, j): @@ -887,8 +888,7 @@ def _A_dot_vec(self, vec): for idx in range(n_H): current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]] size_current_H = current_H_idx.shape[0] - frac_range = np.arange( - size_current_H, dtype=vec.dtype) / size_current_H + frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H sum_vec_H = np.sum(vec[current_H_idx]) out[current_H_idx] = sum_vec_H * frac_range @@ -903,8 +903,7 @@ def _AT_dot_vec(self, vec): for idx in range(n_H): current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]] size_current_H = current_H_idx.shape[0] - frac_range = np.arange( - size_current_H, dtype=vec.dtype) / size_current_H + frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H weighted_sum_vec_H = vec[current_H_idx] @ frac_range out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H) diff --git a/skglm/estimators.py b/skglm/estimators.py index 692b934d0..85a735ffd 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -21,7 +21,7 @@ from skglm.utils.jit_compilation import compiled_clone from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC, - QuadraticMultiTask, QuadraticGroup, HessianQuadratic) + QuadraticMultiTask, QuadraticGroup, QuadraticHessian) from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) from skglm.utils.data import grp_converter @@ -126,8 +126,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver): w = np.zeros(n_features + fit_intercept, dtype=X_.dtype) Xw = np.zeros(n_samples, dtype=X_.dtype) else: # multitask - w = np.zeros((n_features + fit_intercept, - y.shape[1]), dtype=X_.dtype) + w = np.zeros((n_features + fit_intercept, y.shape[1]), dtype=X_.dtype) Xw = np.zeros(y.shape, dtype=X_.dtype) # check consistency of weights for WeightedL1 @@ -450,42 +449,6 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): return solver.path(X, y, datafit, penalty, alphas, coef_init, return_n_iter) -class L1PenalizedQP(BaseEstimator): - def __init__(self, alpha=1., max_iter=50, max_epochs=50_000, p0=10, verbose=0, - tol=1e-4, positive=False, fit_intercept=True, warm_start=False, - ws_strategy="subdiff"): - super().__init__() - self.alpha = alpha - self.tol = tol - self.max_iter = max_iter - self.max_epochs = max_epochs - self.p0 = p0 - self.ws_strategy = ws_strategy - self.positive = positive - self.fit_intercept = fit_intercept - self.warm_start = warm_start - self.verbose = verbose - - def fit(self, A, b): - """Fit the model according to the given training data. - - Parameters - ---------- - A : array-like, shape (n_features, n_features) - b : array-like, shape (n_samples,) - - Returns - ------- - self : - Fitted estimator. 
- """ - solver = AndersonCD( - self.max_iter, self.max_epochs, self.p0, tol=self.tol, - ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept, - warm_start=self.warm_start, verbose=self.verbose) - return _glm_fit(A, b, self, HessianQuadratic(), L1(self.alpha, self.positive), solver) - - class WeightedLasso(LinearModel, RegressorMixin): r"""WeightedLasso estimator based on Celer solver and primal extrapolation. @@ -613,8 +576,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): raise ValueError("The number of weights must match the number of \ features. Got %s, expected %s." % ( len(weights), X.shape[1])) - penalty = compiled_clone(WeightedL1( - self.alpha, weights, self.positive)) + penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive)) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( self.max_iter, self.max_epochs, self.p0, tol=self.tol, @@ -952,8 +914,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): f"Got {len(self.weights)}, expected {X.shape[1]}." ) penalty = compiled_clone( - WeightedMCPenalty(self.alpha, self.gamma, - self.weights, self.positive) + WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive) ) datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32) solver = AndersonCD( @@ -1348,8 +1309,7 @@ def fit(self, X, y): # copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \ # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \ # estimator_checks.py#L3886 - raise ValueError( - "requires y to be passed, but the target y is None") + raise ValueError("requires y to be passed, but the target y is None") y = check_array( y, accept_sparse=False, @@ -1390,8 +1350,7 @@ def fit(self, X, y): # init solver if self.l1_ratio == 0.: - solver = LBFGS(max_iter=self.max_iter, - tol=self.tol, verbose=self.verbose) + solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose) else: solver = ProxNewton( max_iter=self.max_iter, tol=self.tol, verbose=self.verbose, @@ -1529,8 +1488,7 @@ def fit(self, X, Y): if not self.warm_start or not hasattr(self, "coef_"): self.coef_ = None - datafit_jit = compiled_clone( - QuadraticMultiTask(), X.dtype == np.float32) + datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32) penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32) solver = MultiTaskBCD( @@ -1710,8 +1668,7 @@ def fit(self, X, y): "The total number of group members must equal the number of features. 
" f"Got {n_features}, expected {X.shape[1]}.") - weights = np.ones( - len(group_sizes)) if self.weights is None else self.weights + weights = np.ones(len(group_sizes)) if self.weights is None else self.weights group_penalty = WeightedGroupL2(alpha=self.alpha, grp_ptr=grp_ptr, grp_indices=grp_indices, weights=weights, positive=self.positive) diff --git a/skglm/tests/test_datafits.py b/skglm/tests/test_datafits.py index a6068b2ec..e5ca43235 100644 --- a/skglm/tests/test_datafits.py +++ b/skglm/tests/test_datafits.py @@ -6,7 +6,7 @@ from numpy.testing import assert_allclose, assert_array_less from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic, - Quadratic,) + Quadratic, QuadraticHessian) from skglm.penalties import L1, WeightedL1 from skglm.solvers import AndersonCD, ProxNewton from skglm import GeneralizedLinearEstimator @@ -219,5 +219,23 @@ def test_sample_weights(fit_intercept): # np.testing.assert_equal(n_iter, n_iter_overs) +def test_HessianQuadratic(): + n_samples = 20 + n_features = 10 + X, y, _ = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) + A = X.T @ X / n_samples + b = -X.T @ y / n_samples + alpha = np.max(np.abs(b)) / 10 + + pen = L1(alpha) + solv = AndersonCD(warm_start=False, verbose=2, fit_intercept=False) + lasso = GeneralizedLinearEstimator(Quadratic(), pen, solv).fit(X, y) + qpl1 = GeneralizedLinearEstimator(QuadraticHessian(), pen, solv).fit(A, b) + + np.testing.assert_allclose(lasso.coef_, qpl1.coef_) + # check that it's not just because we got alpha too high and thus 0 coef + np.testing.assert_array_less(0.1, np.max(np.abs(qpl1.coef_))) + if __name__ == '__main__': pass