Draft Surv to classif pipeline #194 #391

Merged: 72 commits, Jul 25, 2024
Changes from 48 commits
Commits
a991f70
draft survival to classification pipeline
studener Jun 8, 2024
cda6c3a
implement comments
studener Jun 12, 2024
f55e64d
minor code styling
bblodfon Jun 14, 2024
3a08895
draft pipeline
studener Jun 18, 2024
5d5645b
Merge branch 'main' into surv_to_classif_pipeline
studener Jun 20, 2024
02dbcdd
add tests / add documentation
studener Jun 22, 2024
ab723bf
add mlr3learners to suggests
studener Jun 23, 2024
b8b2526
add pammtools to suggests
studener Jun 25, 2024
4c0b3b6
update cut parameter
studener Jun 25, 2024
19f0a34
minor fix
studener Jun 25, 2024
6e3b51a
better code align
bblodfon Jun 27, 2024
1338abd
fix .predict for potasksurvclassif
studener Jun 27, 2024
7919228
add positive to taskclassif
studener Jun 27, 2024
58c18f8
refactor cox test
bblodfon Jun 27, 2024
17f2ede
increase threshold to error (avoid seeing cox fit warnings)
bblodfon Jun 27, 2024
1792197
add ordered_features function
bblodfon Jun 27, 2024
3f2c01f
* refactor (especially in the learners cox and rpart)
bblodfon Jun 27, 2024
2580481
fix namespace
bblodfon Jun 27, 2024
c046847
Merge branch 'surv_to_classif_pipeline' of https://github.com/mlr-org…
bblodfon Jun 27, 2024
cc5df76
refactor pipeop surv => classif
bblodfon Jun 27, 2024
c053269
fix typo
bblodfon Jun 27, 2024
128e8c7
code styling
bblodfon Jun 27, 2024
7f0c5f3
move pipeline code to appropriate file
bblodfon Jun 27, 2024
9c17471
doc refactor
bblodfon Jun 27, 2024
d1336a5
puting the tests to the appropriate file
bblodfon Jun 27, 2024
dd3c1b5
inherit from PipeOp
bblodfon Jun 27, 2024
fe6c9a3
fix build notes
bblodfon Jun 27, 2024
a1c5e8d
manually remove Breslow pipeop
bblodfon Jun 27, 2024
2b99127
refactor unload test for readability
bblodfon Jun 27, 2024
ace5d58
fix note
bblodfon Jun 27, 2024
9945f22
update PipeOpTaskSurvClassif.R
studener Jun 28, 2024
e4e19d3
skip unloading test for now
bblodfon Jul 1, 2024
8c6eb67
Merge branch 'surv_to_classif_pipeline' of https://github.com/mlr-org…
bblodfon Jul 1, 2024
229c623
inherit from PipeOp not PipeOpTransformer
bblodfon Jul 2, 2024
e197735
small update on the example
bblodfon Jul 2, 2024
534a452
add rhs param to pipeline
studener Jul 2, 2024
d2940d5
Merge branch 'surv_to_classif_pipeline' of https://github.com/mlr-org…
studener Jul 2, 2024
4448790
implement changes
studener Jul 2, 2024
d0551a6
refactor pipeline_survtoclassif
studener Jul 4, 2024
0c1f65b
update tests
studener Jul 4, 2024
2dbbce6
update example
studener Jul 4, 2024
dc3b192
add ref
studener Jul 7, 2024
1727750
update NEWS.md / add myself to authors
studener Jul 7, 2024
65be39f
minor fixes
studener Jul 10, 2024
cf955f5
Merge branch 'main' into surv_to_classif_pipeline
bblodfon Jul 11, 2024
69d38d5
fix authors
bblodfon Jul 11, 2024
82bf006
long format for numbers
bblodfon Jul 11, 2024
a487c3f
refactor tests or clarity + testthat v3 compatibility
bblodfon Jul 11, 2024
cf5068b
rename graph learner
bblodfon Jul 11, 2024
cb5245f
run document()
bblodfon Jul 11, 2024
e71d3c4
some refactoring
bblodfon Jul 13, 2024
9b6dc9e
update docs + names
studener Jul 16, 2024
274b0af
update pipe to handle no events in test data
studener Jul 17, 2024
92405b1
rename files to match new class names
bblodfon Jul 22, 2024
8e0b176
fix data.table notes
bblodfon Jul 22, 2024
cfff122
add small comment in test
bblodfon Jul 22, 2024
546b9d2
update examples
bblodfon Jul 22, 2024
9fa6456
refactor => hazards to surv conversion (better readability)
bblodfon Jul 22, 2024
e91e3a0
add more doc
bblodfon Jul 22, 2024
0fae37c
force cut as integer in some cases (avoid continuous times)
bblodfon Jul 22, 2024
1f8301e
better doc for pipeline
bblodfon Jul 22, 2024
a2cb938
updocs
bblodfon Jul 22, 2024
959985d
revert this back from now
bblodfon Jul 22, 2024
2369da2
fix ped_status in transformed data / refactor
studener Jul 23, 2024
1be6914
add details to pipeline doc
studener Jul 23, 2024
3ec8186
fix typo
studener Jul 23, 2024
4101e9b
revert to non-integer cut
bblodfon Jul 25, 2024
c9502fa
better doc, remove example
bblodfon Jul 25, 2024
82d4392
minimize example
bblodfon Jul 25, 2024
718c7a9
updocs
bblodfon Jul 25, 2024
510648d
small test refactoring
bblodfon Jul 25, 2024
80bf473
update version and news
bblodfon Jul 25, 2024
10 changes: 9 additions & 1 deletion DESCRIPTION
@@ -35,6 +35,10 @@ Authors@R:
 email = "[email protected]",
 role = "ctb",
 comment = c(ORCID = "0000-0001-7528-3795")),
+person(given = "Philip",
+family = "Studener",
+role = "aut",
+email = "[email protected]"),
 person(given = "Maximilian",
 family = "Muecke",
 email = "[email protected]",
@@ -79,7 +83,9 @@ Suggests:
 vdiffr,
 abind,
 Ecdat,
-coxed
+coxed,
+mlr3learners,
+pammtools
 LinkingTo:
 Rcpp
 Remotes:
@@ -133,13 +139,15 @@ Collate:
 'PipeOpBreslow.R'
 'PipeOpCrankCompositor.R'
 'PipeOpDistrCompositor.R'
+'PipeOpPredClassifSurv.R'
 'PipeOpTransformer.R'
 'PipeOpPredTransformer.R'
 'PipeOpPredRegrSurv.R'
 'PipeOpPredSurvRegr.R'
 'PipeOpProbregrCompositor.R'
 'PipeOpSurvAvg.R'
 'PipeOpTaskRegrSurv.R'
+'PipeOpTaskSurvClassif.R'
 'PipeOpTaskSurvRegr.R'
 'PipeOpTaskTransformer.R'
 'PredictionDataDens.R'
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -72,12 +72,14 @@ export(MeasureSurvXuR2)
 export(PipeOpBreslow)
 export(PipeOpCrankCompositor)
 export(PipeOpDistrCompositor)
+export(PipeOpPredClassifSurv)
 export(PipeOpPredRegrSurv)
 export(PipeOpPredSurvRegr)
 export(PipeOpPredTransformer)
 export(PipeOpProbregr)
 export(PipeOpSurvAvg)
 export(PipeOpTaskRegrSurv)
+export(PipeOpTaskSurvClassif)
 export(PipeOpTaskSurvRegr)
 export(PipeOpTaskTransformer)
 export(PipeOpTransformer)
@@ -95,6 +97,7 @@ export(as_task_surv)
 export(assert_surv)
 export(breslow)
 export(pecs)
+export(pipeline_survtoclassif)
 export(pipeline_survtoregr)
 export(plot_probregr)
 import(checkmate)
3 changes: 2 additions & 1 deletion NEWS.md
@@ -1,7 +1,8 @@
 # mlr3proba 0.6.4

+* Add `PipeTaskSurvClassif`, `PipeOpPredClassifSurv` and `pipeline_survtoclassif` to transform a survival task into a classification task by discretizing the status.
 * Add useR! 2024 tutorial
-* Lots of refactoring, improve code quality (thanks to @m-muecke)
+* Lots of refactoring, improving code quality, migration to testthat v3, etc. (thanks to @m-muecke)

 # mlr3proba 0.6.3

24 changes: 7 additions & 17 deletions R/LearnerSurvCoxPH.R
@@ -42,30 +42,20 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH",
 pv$weights = task$weights$weight
 }

-invoke(survival::coxph, formula = task$formula(), data = task$data(), .args = pv, x = TRUE)
+invoke(survival::coxph, formula = task$formula(), data = task$data(),
+.args = pv, x = TRUE)
 },

 .predict = function(task) {
-
-newdata = task$data(cols = task$feature_names)
-
-# We move the missingness checks here manually as if any NAs are made in predictions then the
-# distribution object cannot be create (initialization of distr6 objects does not handle NAs)
-if (anyMissing(newdata)) {
-stopf(
-"Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n",
-self$id, task$id,
-toString(which(!complete.cases(newdata)))
-)
-}
-
+newdata = ordered_features(task, self)
 pv = self$param_set$get_values(tags = "predict")

-# Get predicted values
+# Get survival predictions via `survfit`
 fit = invoke(survival::survfit, formula = self$model, newdata = newdata,
-se.fit = FALSE, .args = pv)
+se.fit = FALSE, .args = pv)

-lp = predict(self$model, type = "lp", newdata = newdata)
+# Get linear predictors
+lp = invoke(predict, self$model, type = "lp", newdata = newdata)

 .surv_return(times = fit$time, surv = t(fit$surv), lp = lp)
 }
11 changes: 6 additions & 5 deletions R/LearnerSurvRpart.R
@@ -79,14 +79,15 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
 pv = insert_named(pv, list(weights = task$weights$weight))
 }

-invoke(rpart::rpart,
-formula = task$formula(), data = task$data(),
-method = "exp", .args = pv)
+invoke(rpart::rpart, formula = task$formula(), data = task$data(),
+method = "exp", .args = pv)
 },

 .predict = function(task) {
-preds = invoke(predict, object = self$model, newdata = task$data(cols = task$feature_names))
-list(crank = preds)
+newdata = ordered_features(task, self)
+p = invoke(predict, object = self$model, newdata = newdata)
+
+list(crank = p)
 }
 )
 )
112 changes: 112 additions & 0 deletions R/PipeOpPredClassifSurv.R
@@ -0,0 +1,112 @@
#' @title PipeOpPredClassifSurv
#' @name mlr_pipeops_trafopred_classifsurv
#'
#' @description
#' Transform [PredictionClassif] to [PredictionSurv] by converting
#' event probabilities of a pseudo status variable (discrete time hazards)
#' to survival probabilities.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [PipeOp][mlr3pipelines::PipeOp].
#'
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
#'
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#'
#' task = tsk("rats")
#'
#' if (requireNamespace("mlr3learners", quietly = TRUE)) {
#' library(mlr3learners)
#' po_tasktoclassif = po("trafotask_survclassif")
#' po_tasktoclassif$train(list(task))
#' task_classif = po_tasktoclassif$predict(list(task))[[1]]
#' trafo_data = po_tasktoclassif$predict(list(task))[[2]]
#'
#' learner = lrn("classif.log_reg", predict_type = "prob")
#' learner$train(task_classif)
#' pred = learner$predict(task_classif)
#'
#' po_predtosurv = po("trafopred_classifsurv")
#' po_predtosurv$train(list(pred, trafo_data))
#' po_predtosurv$predict(list(pred, trafo_data))
#' }
#' }
#' }
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurv = R6Class(
"PipeOpPredClassifSurv",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param id (character(1))\cr
#' Identifier of the resulting object.
initialize = function(id = "trafopred_classifsurv") {
super$initialize(
id = id,
input = data.table::data.table(
name = c("input", "transformed_data"),
train = c("NULL", "data.table"),
predict = c("PredictionClassif", "data.table")
),
output = data.table::data.table(
name = "output",
train = "NULL",
predict = "PredictionSurv"
)
)
}
),

private = list(
.predict = function(input) {
pred = input[[1]]
data = input[[2]]
assert_true(!is.null(pred$prob))
data = cbind(data, pred = pred$prob[, "0"])

## convert hazards to surv as prod(1 - h(t))
rows_per_id = nrow(data)/length(unique(data$id))
surv = t(vapply(unique(data$id), function(unique_id) {
x = cumprod((data[data$id == unique_id, ][["pred"]]))
x
}, numeric(rows_per_id)))

pred_list = list()
unique_end_times = sort(unique(data$tend))
## coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)

# select the real tend values by only selecting the last row of each id
# basically a slightly more complex unique()
real_tend = data$time2[seq_len(nrow(data)) %% rows_per_id == 0]

# select last row for every id
data = as.data.table(data)
id = ped_status = NULL # to fix note
data = data[, .SD[.N, list(ped_status)], by = id]

## create prediction object
p = PredictionSurv$new(
row_ids = seq_row(data),
crank = pred_list$crank, distr = pred_list$distr,
truth = Surv(real_tend, as.integer(as.character(data$ped_status))))

list(p)
},

.train = function(input) {
self$state = list()
list(input)
}
)
)

register_pipeop("trafopred_classifsurv", PipeOpPredClassifSurv)
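Reviewer note: the `.predict` method above takes the predicted probability of class `"0"` (no event in the interval, i.e. `1 - h(t)`) and accumulates it per subject with `cumprod` to obtain survival probabilities. The following is a minimal base-R sketch of that step on made-up data; the helper `hazards_to_surv` and the column `p_no_event` are hypothetical names used only for illustration and are not part of this PR.

```r
# Toy long-format data: 2 subjects, 3 intervals each (values are invented).
# `p_no_event` stands in for pred$prob[, "0"], i.e. 1 - h(t) per interval.
long = data.frame(
  id = rep(1:2, each = 3),
  tend = rep(c(1, 2, 3), times = 2),
  p_no_event = c(0.9, 0.8, 0.5, 1.0, 0.5, 0.5)
)

# Hypothetical helper: S(t_k) = prod_{j <= k} (1 - h(t_j)) for each subject.
# Assumes every subject has the same number of rows, ordered by `tend`.
hazards_to_surv = function(long) {
  ids = unique(long$id)
  n_times = nrow(long) / length(ids)
  surv = t(vapply(ids, function(i) {
    cumprod(long$p_no_event[long$id == i])
  }, numeric(n_times)))
  # rows = subjects, columns = interval end points
  dimnames(surv) = list(ids, sort(unique(long$tend)))
  surv
}

surv = hazards_to_surv(long)
print(surv)
```

Each row of `surv` is non-increasing in time, which is what the survival matrix handed to `.surv_return` needs to satisfy.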
150 changes: 150 additions & 0 deletions R/PipeOpTaskSurvClassif.R
@@ -0,0 +1,150 @@
#' @title PipeOpTaskSurvClassif
#' @name mlr_pipeops_trafotask_survclassif
#' @template param_pipelines
#'
#' @description
#' Transform [TaskSurv] to [TaskClassif][mlr3::TaskClassif] by creating multiple
#' interval observations for each subject based on `cut`, with a `ped_status` variable
#' indicating whether an event occurred in each interval.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [PipeOp][mlr3pipelines::PipeOp].
#'
#' The output is the input [TaskSurv] transformed to a [TaskClassif][mlr3::TaskClassif]
#' as well as the transformed data during prediction.
#'
#' @section State:
#' The `$state` contains information about the `cut` parameter used
#' as well as `time_var` and `event_var`, the names of the two target
#' columns of the survival task.
#'
#' @section Parameters:
#' The parameters are
#'
#' * `cut :: numeric()`\cr
#' Split points, used to partition the data into intervals.
#' If unspecified, all unique event times will be used.
#' If `cut` is a single integer, it will be interpreted as the number of equidistant
#' intervals from 0 until the maximum event time.
#' * `max_time :: numeric(1)`\cr
#' If cut is unspecified, this will be the last possible event time.
#' All event times after max_time will be administratively censored at max_time.
#' Needs to be greater than the minimum event time.
#'
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#'
#' task = tsk("lung")
#' po = po("trafotask_survclassif")
#' po$train(list(task))
#' po$predict(list(task))[[1]]
#' }
#' }
#'
#' @references
#' `r format_bib("tutz_2016")`
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpTaskSurvClassif = R6Class("PipeOpTaskSurvClassif",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "trafotask_survclassif") {
param_set = ps(
cut = p_uty(default = NULL),
max_time = p_dbl(default = NULL, special_vals = list(NULL))
)
super$initialize(
id = id,
param_set = param_set,
input = data.table::data.table(
name = "input",
train = "TaskSurv",
predict = "TaskSurv"
),
output = data.table::data.table(
name = c("output", "transformed_data"),
train = c("TaskClassif", "data.table"),
predict = c("TaskClassif", "data.table")
)
)
}
),

private = list(
.train = function(input) {
task = input[[1]]
data = task$data()
assert_true(task$censtype == "right")

cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0)
max_time = self$param_set$values$max_time

time_var = task$target_names[1]
event_var = task$target_names[2]
if (testInt(cut, lower = 1)) {
cut = seq(0, data[get(event_var) == 1, max(get(time_var))], length.out = cut + 1)
}
if (!is.null(max_time)) {
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
"max_time must be greater than the minimum event time.")
}

form = mlr3misc::formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")

# TODO: do without pammtools
long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time)
self$state$cut = attributes(long_data)$trafo_args$cut
self$state$event_var = event_var
self$state$time_var = time_var
long_data = as.data.table(long_data)
long_data$ped_status = factor(long_data$ped_status, levels = c("0", "1"))

# remove offset, tstart, interval for dataframe long_data
long_data[, c("offset", "tstart", "interval") := NULL]

task = TaskClassif$new(paste0(task$id, "_disc"), long_data, target = "ped_status", positive = "1")
task$set_col_roles("id", roles = "name")

list(task, data.table())
},

.predict = function(input) {
task = input[[1]]
data = task$data()

# extract required data from `state`
cut = self$state$cut
time_var = self$state$time_var
event_var = self$state$event_var

max_time = max(cut)
time = data[[time_var]]
data[[time_var]] = max_time

# update form
form = mlr3misc::formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")

new_data = pammtools::as_ped(data, formula = form, cut = cut)
new_data = as.data.table(new_data)
new_data$ped_status = factor(new_data$ped_status, levels = c("0", "1"))

# remove offset, tstart, interval for dataframe long_data
new_data[, c("offset", "tstart", "interval") := NULL]
task = TaskClassif$new(paste0(task$id, "_disc"), new_data, target = "ped_status", positive = "1")
task$set_col_roles("id", roles = "name")

new_data$time2 = rep(time, each = sum(new_data$id == 1))
list(task, new_data)
}
)
)

register_pipeop("trafotask_survclassif", PipeOpTaskSurvClassif)
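Reviewer note: the PipeOp above delegates the survival-to-long-format conversion to `pammtools::as_ped`. To make the idea concrete, here is a rough base-R sketch of the discretization, not pammtools' actual implementation. The helper `to_long` is hypothetical, and it assumes `cut` holds sorted, positive interval end points (without the leading 0).

```r
# Hypothetical helper (not part of mlr3proba or pammtools).
# Each subject gets one row per interval (tstart, tend] it enters;
# ped_status is 1 only in the interval where an observed event falls.
to_long = function(time, status, cut) {
  tstart = c(0, head(cut, -1))
  rows = lapply(seq_along(time), function(i) {
    at_risk = tstart < time[i] # intervals the subject is still at risk in
    tend = cut[at_risk]
    data.frame(
      id = rep(i, length(tend)),
      tend = tend,
      ped_status = as.integer(status[i] == 1 & time[i] <= tend)
    )
  })
  do.call(rbind, rows)
}

# Subject 1: event at t = 2.5; subject 2: censored at t = 1.
long = to_long(time = c(2.5, 1), status = c(1, 0), cut = c(1, 2, 3))
print(long)
```

In this sketch a subject whose time exceeds the last cut point contributes a row for every interval with `ped_status = 0`, i.e. it is treated as censored at the last cut point. The `.predict` method above relies on the same mechanism: it sets every observed time to `max(cut)` so that each test subject is expanded over all intervals.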