Skip to content

Commit

Permalink
feat: add quantile regression to gbm (#380)
Browse files Browse the repository at this point in the history
* feat: add quantile regression to gbm

* fix: description

* refactor: rename

* fix: ignore weight when quantil regression

* chore: remotes
  • Loading branch information
be-marc authored Aug 22, 2024
1 parent f496537 commit 98d790f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ Remotes:
xoopR/distr6,
xoopR/param6,
xoopR/set6,
ropensci/aorsf
ropensci/aorsf,
mlr-org/mlr3
VignetteBuilder:
knitr
Config/testthat/edition: 3
Expand Down
27 changes: 23 additions & 4 deletions R/learner_gbm_regr_gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#' Gradient Boosting Regression Algorithm.
#' Calls [gbm::gbm()] from \CRANpkg{gbm}.
#'
#' Weights are ignored for quantile prediction.
#'
#' @templateVar id regr.gbm
#' @template learner
#'
Expand Down Expand Up @@ -59,7 +61,7 @@ LearnerRegrGBM = R6Class("LearnerRegrGBM",
id = "regr.gbm",
packages = c("mlr3extralearners", "gbm"),
feature_types = c("integer", "numeric", "factor", "ordered"),
predict_types = "response",
predict_types = c("response", "quantiles"),
param_set = ps,
properties = c("weights", "importance", "missings"),
man = "mlr3extralearners::mlr_learners_regr.gbm",
Expand Down Expand Up @@ -94,10 +96,19 @@ LearnerRegrGBM = R6Class("LearnerRegrGBM",
f = task$formula()
data = task$data()

if ("weights" %in% task$properties) {
if ("weights" %in% task$properties && self$predict_type != "quantiles") {
pars = insert_named(pars, list(weights = task$weights$weight))
}

if (self$predict_type == "quantiles") {

if (length(self$quantiles) > 1) {
stop("Only one quantile is supported")
}

pars$distribution = list(name = "quantile", alpha = self$quantiles)
}

invoke(gbm::gbm, formula = f, data = data, .args = pars)
},

Expand All @@ -110,8 +121,16 @@ LearnerRegrGBM = R6Class("LearnerRegrGBM",
}
newdata = ordered_features(task, self)

p = invoke(predict, self$model, newdata = newdata, .args = pars)
list(response = p)
pred = invoke(predict, self$model, newdata = newdata, .args = pars)

if (self$predict_type == "quantiles") {
quantiles = matrix(pred, ncol = 1)
attr(quantiles, "probs") = private$.quantiles
attr(quantiles, "response") = private$.quantile_response
return(list(quantiles = quantiles))
}

list(response = pred)
}
)
)
Expand Down
4 changes: 3 additions & 1 deletion man/mlr_learners_regr.gbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test_gbm_regr_gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,22 @@ test_that("autotest", {
result = run_autotest(learner)
expect_true(result, info = result$error)
})

test_that("quantile prediction", {
learner = mlr3::lrn("regr.gbm",
predict_type = "quantiles",
quantiles = 0.1,
n.minobsinnode = 1)
task = tsk("mtcars")

learner$train(task)
pred = learner$predict(task)

expect_matrix(pred$quantiles, ncol = 1L)
expect_true(!any(apply(pred$quantiles, 1L, is.unsorted)))
expect_equal(pred$response, pred$quantiles[, 1L])

tab = as.data.table(pred)
expect_names(names(tab), identical.to = c("row_ids", "truth", "q0.1", "response"))
expect_equal(tab$response, tab$q0.1)
})

0 comments on commit 98d790f

Please sign in to comment.