Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates in prediction type compositions #408

Merged
merged 45 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
e909389
only NAs for crank in null (default prediction type)
bblodfon Aug 11, 2024
e374a91
deprecate crank => distr composition, refactor a bit
bblodfon Aug 12, 2024
09889e2
update example, using coxph instead of rpart
bblodfon Aug 12, 2024
7aee854
updocs
bblodfon Aug 12, 2024
656280d
refactor bibentries for simplicity, add C-hacking paper
bblodfon Aug 12, 2024
9514426
add assert_surv_matrix fun
bblodfon Aug 12, 2024
caf028e
add new function to the website
bblodfon Aug 12, 2024
6cff3ea
remove distr composition from response in Surv => Regr pipeline
bblodfon Aug 12, 2024
488a3a6
small fix on the survavg pipeop
bblodfon Aug 12, 2024
1160be8
fix distrcompositor tests
bblodfon Aug 12, 2024
d190845
fix survavg tests
bblodfon Aug 12, 2024
c25e3ba
Merge branch 'main' into rmst_crankcompose_updates
bblodfon Aug 12, 2024
bb8cb50
import survivalmodels::surv_to_risk() rename as get_mortality()
bblodfon Aug 12, 2024
2b3ec02
rerun automated Rcpp compilation
bblodfon Aug 13, 2024
33c0283
add example for get_mortality()
bblodfon Aug 13, 2024
1799c13
updocs
bblodfon Aug 13, 2024
0bcac63
better doc formatting
bblodfon Aug 13, 2024
2620226
use get_mortality, remove survivalmodels dependency
bblodfon Aug 13, 2024
df2e8a5
simplify crankcompositor
bblodfon Aug 13, 2024
498a030
add new paper
bblodfon Aug 13, 2024
78e1809
add response compositor
bblodfon Aug 13, 2024
5337ac6
small doc correction
bblodfon Aug 16, 2024
3203be6
update crankcompositor pipeline
bblodfon Aug 16, 2024
76ea7d7
refactor tests, for crank and distr composition pipeops and pipelines
bblodfon Aug 16, 2024
1426c08
updocs
bblodfon Aug 16, 2024
99b7d4a
add namespace
bblodfon Aug 16, 2024
3551d5b
add function in the website
bblodfon Aug 16, 2024
c7e48c0
fix breslow test
bblodfon Aug 16, 2024
e7430f1
bug fix: overwrite behavior
bblodfon Aug 16, 2024
a2f4fcc
add tests for responsecompose pipeop
bblodfon Aug 16, 2024
1c65ef7
fix small bug
bblodfon Aug 16, 2024
c53f54d
add responsecompositor pipeline
bblodfon Aug 16, 2024
490bf87
fix test
bblodfon Aug 16, 2024
6f4dd92
rename test files
bblodfon Aug 16, 2024
6edfb68
updocs
bblodfon Aug 16, 2024
ba2ec04
fix line width
bblodfon Aug 16, 2024
c99e827
fix param names
bblodfon Aug 16, 2024
dab44ac
fix and update unloading
bblodfon Aug 16, 2024
1e675e1
updocs
bblodfon Aug 16, 2024
3afeb43
update news and description
bblodfon Aug 16, 2024
ac2a761
fix doc spelling mistake
bblodfon Aug 16, 2024
3ff4218
refactor test using new TaskSurv function
bblodfon Aug 16, 2024
7ab553b
refactor test
bblodfon Aug 16, 2024
9faae6a
remove partition() as mlr3 removed the S3 method
bblodfon Aug 17, 2024
da0a858
remove comment about stratification
bblodfon Aug 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.6
Version: 0.6.7
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -64,11 +64,9 @@ Imports:
paradox (>= 1.0.0),
R6,
Rcpp (>= 1.0.4),
survival,
survivalmodels (>= 0.1.12)
survival
Suggests:
bujar,
cubature,
GGally,
knitr,
lgr,
Expand Down Expand Up @@ -145,6 +143,7 @@ Collate:
'PipeOpPredRegrSurv.R'
'PipeOpPredSurvRegr.R'
'PipeOpProbregrCompositor.R'
'PipeOpResponseCompositor.R'
'PipeOpSurvAvg.R'
'PipeOpTaskRegrSurv.R'
'PipeOpTaskSurvClassifDiscTime.R'
Expand Down Expand Up @@ -176,7 +175,6 @@ Collate:
'histogram.R'
'integrated_scores.R'
'mlr3proba-package.R'
'partition.R'
'pecs.R'
'pipelines.R'
'plot.R'
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ S3method(check_prediction_data,PredictionDataSurv)
S3method(filter_prediction_data,PredictionDataSurv)
S3method(is_missing_prediction_data,PredictionDataDens)
S3method(is_missing_prediction_data,PredictionDataSurv)
S3method(partition,TaskSurv)
S3method(pecs,PredictionSurv)
S3method(pecs,list)
S3method(plot,LearnerSurv)
Expand Down Expand Up @@ -77,6 +76,7 @@ export(PipeOpPredRegrSurv)
export(PipeOpPredSurvRegr)
export(PipeOpPredTransformer)
export(PipeOpProbregr)
export(PipeOpResponseCompositor)
export(PipeOpSurvAvg)
export(PipeOpTaskRegrSurv)
export(PipeOpTaskSurvClassifDiscTime)
Expand All @@ -95,7 +95,9 @@ export(as_prediction_surv)
export(as_task_dens)
export(as_task_surv)
export(assert_surv)
export(assert_surv_matrix)
export(breslow)
export(get_mortality)
export(pecs)
export(pipeline_survtoclassif_disctime)
export(pipeline_survtoregr)
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# mlr3proba 0.6.7

- Deprecate `crank` to `distr` composition in `distrcompose` pipeop (only from `lp` => `distr` works now)
- Add `get_mortality()` function (from `survivalmodels::surv_to_risk()`
- Add Rcpp function `assert_surv_matrix()`
- Update and simplify `crankcompose` pipeop and respective pipeline (no `response` is created anymore)
- Add `responsecompositor` pipeline with `rmst` and `median`

# mlr3proba 0.6.6

- Small fixes and refactoring to the discrete-time pipeops
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvCindex.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' library(mlr3)
#' task = tsk("rats")
#' learner = lrn("surv.coxph")
#' part = partition(task) # train/test split, stratified on `status` by default
#' part = partition(task) # train/test split
#' learner$train(task, part$train)
#' p = learner$predict(task, part$test)
#'
Expand Down
150 changes: 45 additions & 105 deletions R/PipeOpCrankCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,36 @@
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpCrankCompositor] has one input channel named "input", which takes
#' `NULL` during training and [PredictionSurv] during prediction.
#' [PipeOpCrankCompositor] has one input channel named `"input"`, which takes `NULL` during training and [PredictionSurv] during prediction.
#'
#' [PipeOpCrankCompositor] has one output channel named "output", producing `NULL` during training
#' and a [PredictionSurv] during prediction.
#' [PipeOpCrankCompositor] has one output channel named `"output"`, producing `NULL` during training and a [PredictionSurv] during prediction.
#'
#' The output during prediction is the [PredictionSurv] from the "pred" input but with the `crank`
#' predict type overwritten by the given estimation method.
#' The output during prediction is the [PredictionSurv] from the input but with the `crank` predict type overwritten by the given estimation method.
#'
#' @section State:
#' The `$state` is left empty (`list()`).
#'
#' @section Parameters:
#' * `method` :: `character(1)` \cr
#' Determines what method should be used to produce a continuous ranking from the distribution.
#' One of `sum_haz`, `median`, `mode`, or `mean` corresponding to the
#' respective functions in the predicted survival distribution. Note that
#' for models with a proportional hazards form, the ranking implied by
#' `mean` and `median` will be identical (but not the value of `crank`
#' itself). `sum_haz` (default) uses [survivalmodels::surv_to_risk()].
#' * `which` :: `numeric(1)`\cr
#' If `method = "mode"` then specifies which mode to use if multi-modal, default is the first.
#' * `response` :: `logical(1)`\cr
#' If `TRUE` then the `response` predict type is estimated with the same values as `crank`.
#' Currently only `mort` is supported, which is the sum of the cumulative hazard, also called *expected/ensemble mortality*, see Ishwaran et al. (2008).
#' For more details, see [get_mortality()].
#' * `overwrite` :: `logical(1)` \cr
#' If `FALSE` (default) then if the "pred" input already has a `crank`, the compositor only
#' composes a `response` type if `response = TRUE` and does not already exist. If `TRUE` then
#' both the `crank` and `response` are overwritten.
#'
#' @section Internals:
#' The `median`, `mode`, or `mean` will use analytical expressions if possible but if not they are
#' calculated using methods from [distr6]. `mean` requires \CRANpkg{cubature}.
#' If `FALSE` (default) and the prediction already has a `crank` prediction, then the compositor returns the input prediction unchanged.
#' If `TRUE`, then the `crank` will be overwritten.
#'
#' @seealso [pipeline_crankcompositor]
#' @family survival compositors
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' task = tsk("rats")
#'
#' learn = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(method = "sum_haz"))
#' poc$predict(list(learn))[[1]]
#'
#' if (requireNamespace("cubature", quietly = TRUE)) {
#' learn = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(method = "sum_haz"))
#' poc$predict(list(learn))[[1]]
#' }
#' # change the crank prediction type of a Cox's model predictions
#' pred = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(overwrite = TRUE))
#' poc$predict(list(pred))[[1L]]
#' }
#' }
#' @export
Expand All @@ -77,21 +56,18 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "crankcompose", param_vals = list()) {
param_set = ps(
method = p_fct(default = "sum_haz", levels = c("sum_haz", "mean", "median", "mode"),
tags = "predict"),
which = p_int(1L, default = 1L, tags = "predict", depends = quote(method == "mode")),
response = p_lgl(default = FALSE, tags = "predict"),
method = p_fct(default = "mort", levels = c("mort"), tags = "predict"),
overwrite = p_lgl(default = FALSE, tags = "predict")
)
param_set$set_values(method = "sum_haz", response = FALSE, overwrite = FALSE)
param_set$set_values(method = "mort", overwrite = FALSE)

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
input = data.table(name = "input", train = "NULL", predict = "PredictionSurv"),
output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"),
packages = c("mlr3proba", "distr6")
packages = c("mlr3proba")
)
}
),
Expand All @@ -103,83 +79,47 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
},

.predict = function(inputs) {

inpred = inputs[[1L]]

response = self$param_set$values$response
b_response = !anyMissing(inpred$response)
if (!length(response)) response = FALSE

pred = inputs[[1L]]
overwrite = self$param_set$values$overwrite
if (!length(overwrite)) overwrite = FALSE
# it's impossible for a learner not to predict crank in mlr3proba,
# but let's check either way:
has_crank = !all(is.na(pred$crank))

# if crank and response already exist and not overwriting then return prediction
if (!overwrite && (!response || (response && b_response))) {
return(list(inpred))
if (!overwrite & has_crank) {
# return prediction as is
return(list(pred))
} else {
assert("distr" %in% inpred$predict_types)
method = self$param_set$values$method
if (length(method) == 0L) method = "sum_haz"
if (method == "sum_haz") {
if (inherits(inpred$data$distr, "matrix") ||
!requireNamespace("survivalmodels", quietly = TRUE)) {
comp = survivalmodels::surv_to_risk(inpred$data$distr)
} else {
comp = as.numeric(
colSums(inpred$distr$cumHazard(sort(unique(inpred$truth[, 1]))))
)
}
} else if (method == "mean") {
comp = try(inpred$distr$mean(), silent = TRUE)
if (inherits(comp, "try-error")) {
requireNamespace("cubature")
comp = try(inpred$distr$mean(cubature = TRUE), silent = TRUE)
}
if (inherits(comp, "try-error")) {
comp = numeric(length(inpred$crank))
}
} else {
comp = switch(method,
median = inpred$distr$median(),
mode = inpred$distr$mode(self$param_set$values$which))
}
# compose crank from distr prediction
assert("distr" %in% pred$predict_types)

comp = as.numeric(comp)

# if crank exists and not overwriting then return predicted crank, otherwise compose
if (!overwrite) {
crank = inpred$crank
# get survival matrix
if (inherits(pred$data$distr, "array")) {
surv = pred$data$distr
if (length(dim(surv)) == 3L) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
} else {
crank = -comp
# missing imputed with median
crank[is.na(crank)] = stats::median(crank[!is.na(crank)])
crank[crank == Inf] = 1e3
crank[crank == -Inf] = -1e3
stop("Distribution prediction does not have a survival matrix or array
in the $data$distr slot")
}

# i) not overwriting or requesting response, and already predicted
if (b_response && (!overwrite || !response)) {
response = inpred$response
# ii) not requesting response and doesn't exist
} else if (!response) {
response = NULL
# iii) requesting response and happy to overwrite
# iv) requesting response and doesn't exist
} else {
response = comp
response[is.na(response)] = 0
response[response == Inf | response == -Inf] = 0
method = self$param_set$values$method
if (method == "mort") {
crank = get_mortality(surv)
}

if (!anyMissing(inpred$lp)) {
lp = inpred$lp
} else {
lp = NULL
}
# update only `crank`
p = PredictionSurv$new(
row_ids = pred$row_ids,
truth = pred$truth,
crank = crank,
distr = pred$distr,
lp = pred$lp,
response = pred$response
)

return(list(PredictionSurv$new(
row_ids = inpred$row_ids, truth = inpred$truth, crank = crank,
distr = inpred$distr, lp = lp, response = response)))
return(list(p))
}
}
)
Expand Down
Loading