Skip to content

Commit

Permalink
MNT use dataclass for IRLSData (#881)
Browse files Browse the repository at this point in the history
* MNT use dataclass for IRLSData

* CLN import order

* CLN up 2 fileds

* CLN import order
  • Loading branch information
lorentzenchr authored Nov 22, 2024
1 parent e5b26e1 commit 4ac443b
Showing 1 changed file with 27 additions and 49 deletions.
76 changes: 27 additions & 49 deletions src/glum/_solvers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import time
import warnings
from typing import Optional, Union
from dataclasses import InitVar, dataclass
from typing import Any, Optional, Union

import numpy as np
from scipy import linalg, sparse
Expand Down Expand Up @@ -408,53 +409,33 @@ def _update(self, n_iter, iteration_runtime, cur_grad_norm):
self.t.update(0)


@dataclass
class IRLSData:
"""Store parameters for the IRLS optimizer."""

def __init__(
self,
X,
y: np.ndarray,
sample_weight: np.ndarray,
P1: Union[np.ndarray, sparse.spmatrix],
P2: Union[np.ndarray, sparse.spmatrix],
fit_intercept: bool,
family: ExponentialDispersionModel,
link: Link,
max_iter: int = 100,
max_inner_iter: int = 100000,
gradient_tol: Optional[float] = 1e-4,
step_size_tol: Optional[float] = 1e-4,
hessian_approx: float = 0.0,
fixed_inner_tol: Optional[tuple] = None,
selection="cyclic",
random_state=None,
offset: Optional[np.ndarray] = None,
lower_bounds: Optional[np.ndarray] = None,
upper_bounds: Optional[np.ndarray] = None,
verbose: bool = False,
):
self.X = X
self.y = y
self.sample_weight = sample_weight
self.P1 = P1

# Note: we already set P2 = l2*P2, P1 = l1*P1
# Note: we already symmetrized P2 = 1/2 (P2 + P2')
self.P2 = P2

self.fit_intercept = fit_intercept
self.family = family
self.link = link
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.gradient_tol = gradient_tol
self.step_size_tol = step_size_tol
self.hessian_approx = hessian_approx
self.fixed_inner_tol = fixed_inner_tol
self.selection = selection
self.random_state = random_state
self.offset = offset
X: Any
y: np.ndarray
sample_weight: np.ndarray
# Note: we already set P2 = l2*P2, P1 = l1*P1 and symmetrized P2 = 1/2 (P2 + P2')
P1: Union[np.ndarray, sparse.spmatrix]
P2: Union[np.ndarray, sparse.spmatrix]
fit_intercept: bool
family: ExponentialDispersionModel
link: Link
max_iter: int = 100
max_inner_iter: int = 100000
gradient_tol: Optional[float] = 1e-4
step_size_tol: Optional[float] = 1e-4
hessian_approx: float = 0.0
fixed_inner_tol: Optional[tuple] = None
selection: str = "cyclic"
random_state: Union[None, int, np.random.RandomState] = None
offset: Optional[np.ndarray] = None
lower_bounds: InitVar[Optional[np.ndarray]] = None
upper_bounds: InitVar[Optional[np.ndarray]] = None
verbose: bool = False

def __post_init__(self, lower_bounds, upper_bounds):
self.has_lower_bounds, self._lower_bounds = _setup_bounds(
lower_bounds, self.X.dtype
)
Expand All @@ -463,11 +444,8 @@ def __init__(
)

self.intercept_offset = 1 if self.fit_intercept else 0
self.verbose = verbose

self._check_data()

def _check_data(self):
# Check data
if self.P2.ndim == 2:
self.P2 = check_array(self.P2, "csc", dtype=[np.float64, np.float32])

Expand Down

0 comments on commit 4ac443b

Please sign in to comment.