added support for Tweedie, Poisson and Gamma regressors (#650)
interesaaat authored Nov 4, 2022
1 parent abfeee4 commit 29c3f68
Showing 4 changed files with 56 additions and 2 deletions.
@@ -32,7 +32,9 @@ def __init__(
         self.multi_class = multi_class
         self.regression = is_linear_regression
         self.classification = not is_linear_regression
-        self.loss = loss if loss is not None else "log"
+        self.loss = loss
+        if self.loss is None and self.classification:
+            self.loss = "log"
 
         self.perform_class_select = False
         if min(classes) != 0 or max(classes) != len(classes) - 1:
@@ -48,6 +50,8 @@ def forward(self, x):
         if self.multi_class == "multinomial":
            output = torch.softmax(output, dim=1)
         elif self.regression:
+            if self.loss == "log":
+                return torch.exp(output)
             return output
         else:
             if self.loss == "modified_huber":
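The two hunks above (presumably `hummingbird/ml/operator_converters/_linear_implementations.py`, given the `LinearModel` import in `linear.py` below) let `LinearModel` carry an explicit `loss` and apply the inverse log link at prediction time: for a log-link GLM the mean prediction is exp(X @ coefficients + intercept). A minimal sketch of that equivalence, reusing the Poisson test data added later in this commit (illustrative code, not part of the Hummingbird sources):

```python
# Sketch of the new log-link branch in forward(): assumes only numpy, torch, and
# scikit-learn; the data matches test_poisson_regressor added in this commit.
import numpy as np
import torch
from sklearn.linear_model import PoissonRegressor

X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 3.0]])
y = np.array([12.0, 17.0, 22.0, 21.0])

skl_model = PoissonRegressor().fit(X, y)

coefficients = torch.tensor(skl_model.coef_, dtype=torch.float32)
intercepts = torch.tensor(skl_model.intercept_, dtype=torch.float32)

linear_output = torch.from_numpy(X).float() @ coefficients + intercepts  # the `output` tensor
prediction = torch.exp(linear_output)  # what the new `loss == "log"` branch returns

np.testing.assert_allclose(skl_model.predict(X), prediction.numpy(), rtol=1e-4, atol=1e-3)
```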
10 changes: 9 additions & 1 deletion hummingbird/ml/operator_converters/sklearn/linear.py
@@ -10,6 +10,7 @@
 
 import numpy as np
 from onnxconverter_common.registration import register_converter
+from sklearn._loss.link import LogLink
 
 from .._linear_implementations import LinearModel
 
@@ -81,6 +82,7 @@ def convert_sklearn_linear_regression_model(operator, device, extra_config):
     """
     assert operator is not None, "Cannot convert None operator"
 
+    loss = None
     coefficients = operator.raw_operator.coef_.transpose().astype("float32")
     if len(coefficients.shape) == 1:
         coefficients = coefficients.reshape(-1, 1)
@@ -91,7 +93,10 @@ def convert_sklearn_linear_regression_model(operator, device, extra_config):
     else:
         intercepts = intercepts.reshape(1, -1).astype("float32")
 
-    return LinearModel(operator, coefficients, intercepts, device, is_linear_regression=True)
+    if hasattr(operator.raw_operator, "_base_loss") and type(operator.raw_operator._base_loss.link) == LogLink:
+        loss = "log"
+
+    return LinearModel(operator, coefficients, intercepts, device, loss=loss, is_linear_regression=True)
 
 
 register_converter("SklearnLinearRegression", convert_sklearn_linear_regression_model)
@@ -104,3 +109,6 @@ def convert_sklearn_linear_regression_model(operator, device, extra_config):
 register_converter("SklearnSGDClassifier", convert_sklearn_linear_model)
 register_converter("SklearnLogisticRegressionCV", convert_sklearn_linear_model)
 register_converter("SklearnRidgeCV", convert_sklearn_linear_regression_model)
+register_converter("SklearnTweedieRegressor", convert_sklearn_linear_regression_model)
+register_converter("SklearnPoissonRegressor", convert_sklearn_linear_regression_model)
+register_converter("SklearnGammaRegressor", convert_sklearn_linear_regression_model)
9 changes: 9 additions & 0 deletions hummingbird/ml/supported.py
@@ -21,6 +21,7 @@
 ExtraTreesClassifier,
 ExtraTreesRegressor,
 FastICA,
+GammaRegressor,
 GaussianNB,
 GradientBoostingClassifier,
 GradientBoostingRegressor,
@@ -51,6 +52,7 @@
 OneHotEncoder,
 PCA,
 PLSRegression,
+PoissonRegressor,
 PolynomialFeatures,
 RandomForestClassifier,
 RandomForestRegressor,
@@ -64,6 +66,7 @@
 TreeEnsembleClassifier,
 TreeEnsembleRegressor,
 TruncatedSVD,
+TweedieRegressor,
 VarianceThreshold,
 
 **Supported Operators (LGBM)**
@@ -145,6 +148,9 @@ def _build_sklearn_operator_list():
         ElasticNet,
         Ridge,
         Lasso,
+        TweedieRegressor,
+        PoissonRegressor,
+        GammaRegressor,
     )
 
     # SVM-based models
Expand Down Expand Up @@ -223,6 +229,9 @@ def _build_sklearn_operator_list():
Lasso,
ElasticNet,
Ridge,
TweedieRegressor,
PoissonRegressor,
GammaRegressor,
# Clustering
KMeans,
MeanShift,
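With the three regressors added to the supported operator list, they convert like any other linear model. A small end-to-end sketch mirroring the new tests below (data taken from `test_gamma_regressor`):

```python
import numpy as np
import hummingbird.ml
from sklearn.linear_model import GammaRegressor

X = [[1, 2], [2, 3], [3, 4], [4, 3]]
y = [19, 26, 33, 30]

skl_model = GammaRegressor().fit(X, y)
hb_model = hummingbird.ml.convert(skl_model, "torch")  # now routed to the linear-regression converter

np.testing.assert_allclose(skl_model.predict(X), hb_model.predict(X), rtol=1e-6, atol=1e-3)
```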
33 changes: 33 additions & 0 deletions tests/test_sklearn_linear_converter.py
@@ -16,6 +16,9 @@
     Lasso,
     ElasticNet,
     Ridge,
+    TweedieRegressor,
+    PoissonRegressor,
+    GammaRegressor,
 )
 from sklearn import datasets
 
@@ -495,6 +498,36 @@ def test_lr_tvm(self):
 
         np.testing.assert_allclose(model.predict(X), tvm_model.predict(X), rtol=1e-6, atol=1e-3)
 
+    def test_tweedie_regressor(self):
+        clf = TweedieRegressor()
+        X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+        y = [2, 3.5, 5, 5.5]
+
+        clf.fit(X, y)
+        hb_model = hummingbird.ml.convert(clf, "torch")
+
+        np.testing.assert_allclose(clf.predict([[1, 1], [3, 4]]), hb_model.predict([[1, 1], [3, 4]]), rtol=1e-6, atol=1e-3)
+
+    def test_poisson_regressor(self):
+        clf = PoissonRegressor()
+        X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+        y = [12, 17, 22, 21]
+
+        clf.fit(X, y)
+        hb_model = hummingbird.ml.convert(clf, "torch")
+
+        np.testing.assert_allclose(clf.predict([[1, 1], [3, 4]]), hb_model.predict([[1, 1], [3, 4]]), rtol=1e-6, atol=1e-3)
+
+    def test_gamma_regressor(self):
+        clf = GammaRegressor()
+        X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+        y = [19, 26, 33, 30]
+
+        clf.fit(X, y)
+        hb_model = hummingbird.ml.convert(clf, "torch")
+
+        np.testing.assert_allclose(clf.predict([[1, 1], [3, 4]]), hb_model.predict([[1, 1], [3, 4]]), rtol=1e-6, atol=1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
