From 4ac443b4f43efcf01337e09012aaf67d4a43131f Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 22 Nov 2024 16:53:29 +0100 Subject: [PATCH] MNT use dataclass for IRLSData (#881) * MNT use dataclass for IRLSData * CLN import order * CLN up 2 fileds * CLN import order --- src/glum/_solvers.py | 76 ++++++++++++++++---------------------------- 1 file changed, 27 insertions(+), 49 deletions(-) diff --git a/src/glum/_solvers.py b/src/glum/_solvers.py index 43b3de50..e169bdf4 100644 --- a/src/glum/_solvers.py +++ b/src/glum/_solvers.py @@ -1,7 +1,8 @@ import functools import time import warnings -from typing import Optional, Union +from dataclasses import InitVar, dataclass +from typing import Any, Optional, Union import numpy as np from scipy import linalg, sparse @@ -408,53 +409,33 @@ def _update(self, n_iter, iteration_runtime, cur_grad_norm): self.t.update(0) +@dataclass class IRLSData: """Store parameters for the IRLS optimizer.""" - def __init__( - self, - X, - y: np.ndarray, - sample_weight: np.ndarray, - P1: Union[np.ndarray, sparse.spmatrix], - P2: Union[np.ndarray, sparse.spmatrix], - fit_intercept: bool, - family: ExponentialDispersionModel, - link: Link, - max_iter: int = 100, - max_inner_iter: int = 100000, - gradient_tol: Optional[float] = 1e-4, - step_size_tol: Optional[float] = 1e-4, - hessian_approx: float = 0.0, - fixed_inner_tol: Optional[tuple] = None, - selection="cyclic", - random_state=None, - offset: Optional[np.ndarray] = None, - lower_bounds: Optional[np.ndarray] = None, - upper_bounds: Optional[np.ndarray] = None, - verbose: bool = False, - ): - self.X = X - self.y = y - self.sample_weight = sample_weight - self.P1 = P1 - - # Note: we already set P2 = l2*P2, P1 = l1*P1 - # Note: we already symmetrized P2 = 1/2 (P2 + P2') - self.P2 = P2 - - self.fit_intercept = fit_intercept - self.family = family - self.link = link - self.max_iter = max_iter - self.max_inner_iter = max_inner_iter - self.gradient_tol = gradient_tol - self.step_size_tol = step_size_tol - self.hessian_approx = hessian_approx - self.fixed_inner_tol = fixed_inner_tol - self.selection = selection - self.random_state = random_state - self.offset = offset + X: Any + y: np.ndarray + sample_weight: np.ndarray + # Note: we already set P2 = l2*P2, P1 = l1*P1 and symmetrized P2 = 1/2 (P2 + P2') + P1: Union[np.ndarray, sparse.spmatrix] + P2: Union[np.ndarray, sparse.spmatrix] + fit_intercept: bool + family: ExponentialDispersionModel + link: Link + max_iter: int = 100 + max_inner_iter: int = 100000 + gradient_tol: Optional[float] = 1e-4 + step_size_tol: Optional[float] = 1e-4 + hessian_approx: float = 0.0 + fixed_inner_tol: Optional[tuple] = None + selection: str = "cyclic" + random_state: Union[None, int, np.random.RandomState] = None + offset: Optional[np.ndarray] = None + lower_bounds: InitVar[Optional[np.ndarray]] = None + upper_bounds: InitVar[Optional[np.ndarray]] = None + verbose: bool = False + + def __post_init__(self, lower_bounds, upper_bounds): self.has_lower_bounds, self._lower_bounds = _setup_bounds( lower_bounds, self.X.dtype ) @@ -463,11 +444,8 @@ def __init__( ) self.intercept_offset = 1 if self.fit_intercept else 0 - self.verbose = verbose - - self._check_data() - def _check_data(self): + # Check data if self.P2.ndim == 2: self.P2 = check_array(self.P2, "csc", dtype=[np.float64, np.float32])