Merge pull request #406 from mlr-org/disctime_fixes
Disctime fixes
bblodfon authored Jul 31, 2024
2 parents f74d5cf + bbd0c97 commit 21a5dae
Showing 22 changed files with 247 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/r-cmd-check.yml
@@ -9,7 +9,7 @@ on:
branches:
- main

name: r-cmd-check
name: R-CMD-check

jobs:
r-cmd-check:
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.5
Version: 0.6.6
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
# mlr3proba 0.6.6

- Small fixes and refactoring to the discrete-time pipeops

# mlr3proba 0.6.5

* Add support for discrete-time survival analysis
25 changes: 17 additions & 8 deletions R/PipeOpPredClassifSurvDiscTime.R
@@ -16,6 +16,16 @@
#' conditional probability for an event in the \eqn{k}-interval.
#' - \eqn{p_k = 1 - h_k = P(T \ge t_k | T \ge t_{k-1})}
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredClassifSurvDiscTime$new()
#' mlr_pipeops$get("trafopred_classifsurv_disctime")
#' po("trafopred_classifsurv_disctime")
#' ```
#'
#' @section Input and Output Channels:
#' The input is a [PredictionClassif] and a [data.table][data.table::data.table]
#' with the transformed data both generated by [PipeOpTaskSurvClassifDiscTime].
@@ -68,25 +78,24 @@ PipeOpPredClassifSurvDiscTime = R6Class(
cumprod(1 - data[data$id == unique_id, ][["dt_hazard"]])
}, numeric(rows_per_id)))

pred_list = list()
unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)

# select the real tend values by only selecting the last row of each id
# basically a slightly more complex unique()
real_tend = data$time2[seq_len(nrow(data)) %% rows_per_id == 0]
real_tend = data$obs_times[seq_len(nrow(data)) %% rows_per_id == 0]

# select last row for every id
data = as.data.table(data)
id = ped_status = NULL # to fix note
data = data[, .SD[.N, list(ped_status)], by = id]
ids = unique(data$id)
# select last row for every id => observed times
id = disc_status = NULL # to fix note
data = data[, .SD[.N, list(disc_status)], by = id]

# create prediction object
p = PredictionSurv$new(
row_ids = seq_row(data),
row_ids = ids,
crank = pred_list$crank, distr = pred_list$distr,
truth = Surv(real_tend, as.integer(as.character(data$ped_status))))
truth = Surv(real_tend, as.integer(as.character(data$disc_status))))

list(p)
},
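Review note: the survival computation in this hunk is the discrete-time product S(t_k) = prod_{j <= k} (1 - h_j), evaluated per observation. A minimal base R sketch of that step follows. The toy data frame, its values, and the helper name `surv_from_hazards` are illustrative assumptions, not part of mlr3proba; the sketch also assumes the rows are ordered by id and interval and that every id has the same number of rows.

```r
# Toy long-format data: 2 observations, 3 intervals each (hypothetical values)
toy = data.frame(
  id        = rep(c(1, 2), each = 3),
  tend      = rep(c(10, 20, 30), times = 2),
  dt_hazard = c(0.1, 0.2, 0.5, 0.0, 0.5, 0.5)
)

# Hypothetical helper: one row per id, one column per interval end time.
# Each row holds the cumulative product of (1 - hazard) for that id.
surv_from_hazards = function(data) {
  ids = unique(data$id)
  rows_per_id = nrow(data) / length(ids)
  surv = t(vapply(ids, function(unique_id) {
    cumprod(1 - data[data$id == unique_id, "dt_hazard"])
  }, numeric(rows_per_id)))
  dimnames(surv) = list(as.character(ids), as.character(sort(unique(data$tend))))
  surv
}

surv = surv_from_hazards(toy)
print(surv)
```

Each row of `surv` is non-increasing in time, which is the shape a survival matrix needs before it is converted to a distribution.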
88 changes: 62 additions & 26 deletions R/PipeOpTaskSurvClassifDiscTime.R
@@ -5,29 +5,40 @@
#' @description
#' Transform [TaskSurv] to [TaskClassif][mlr3::TaskClassif] by dividing continuous
#' time into multiple time intervals for each observation.
#' This transformation creates a new target variable `ped_status` that indicates
#' This transformation creates a new target variable `disc_status` that indicates
#' whether an event occurred within each time interval.
#' This approach facilitates survival analysis within a classification framework
#' using discrete time intervals (Tutz et al. 2016).
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpTaskSurvClassifDiscTime$new()
#' mlr_pipeops$get("trafotask_survclassif_disctime")
#' po("trafotask_survclassif_disctime")
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpTaskSurvClassifDiscTime] has one input channel named "input", and two
#' output channels, one named "output" and the other "transformed_data".
#'
#' During training, the "output" is the "input" [TaskSurv] transformed to a
#' [TaskClassif][mlr3::TaskClassif].
#' The target column is named `ped_status` and indicates whether an event occurred
#' The target column is named `"disc_status"` and indicates whether an event occurred
#' in each time interval.
#' An additional feature named `tend` is added to the output task, containing the
#' end time of each interval.
#' An additional feature named `"tend"` contains the end time point of each interval.
#' Lastly, the "output" task has a column with the original observation ids,
#' under the role `"original_ids"`.
#' The "transformed_data" is an empty [data.table][data.table::data.table].
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskClassif][mlr3::TaskClassif] with `ped_status` as target and the `tend`
#' [TaskClassif][mlr3::TaskClassif] with `"disc_status"` as target and the `"tend"`
#' feature included.
#' The "transformed_data" is a [data.table] which has all the features of the
#' "output" task, including an additional column `time2` containing the
#' original times.
#' The "transformed_data" is a [data.table] with the following columns: the `"disc_status"`
#' target of the "output" task, the `"id"` (original observation ids),
#' `"obs_times"` (observed times per `"id"`) and `"tend"` (end time of each interval).
#' This "transformed_data" is only meant to be used with the [PipeOpPredClassifSurvDiscTime].
#'
#' @section State:
Expand Down Expand Up @@ -110,6 +121,10 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
assert_true(task$censtype == "right")
data = task$data()

if ("disc_status" %in% colnames(task$data())) {
stop("\"disc_status\" can not be a column in the input data.")
}

cut = assert_numeric(self$param_set$values$cut, null.ok = TRUE, lower = 0)
max_time = self$param_set$values$max_time

@@ -129,14 +144,20 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
long_data = pammtools::as_ped(data = data, formula = form, cut = cut, max_time = max_time)
self$state$cut = attributes(long_data)$trafo_args$cut
long_data = as.data.table(long_data)
long_data$ped_status = factor(long_data$ped_status, levels = c("0", "1"))
setnames(long_data, old = "ped_status", new = "disc_status")
long_data$disc_status = factor(long_data$disc_status, levels = c("0", "1"))

# remove offset, tstart, interval for dataframe long_data
# remove some columns from `long_data`
long_data[, c("offset", "tstart", "interval") := NULL]
# keep id mapping
reps = table(long_data$id)
ids = rep(task$row_ids, times = reps)
id = NULL
long_data[, id := ids]

task_disc = TaskClassif$new(paste0(task$id, "_disc"), long_data,
target = "ped_status", positive = "1")
task_disc$set_col_roles("id", roles = "name")
target = "disc_status", positive = "1")
task_disc$set_col_roles("id", roles = "original_ids")

list(task_disc, data.table())
},
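Review note: the "keep id mapping" step in the training hunk counts the rows per id with `table()` and expands the task's row ids by those counts. The sketch below restates that idea in base R with made-up values. It assumes the long data carries sequential ids (1, 2, ...) in sorted order while the task's row ids are arbitrary; `as.integer()` is added only to make the counts explicit.

```r
# Sequential ids with a varying number of interval rows per observation (hypothetical)
long_id = c(1, 1, 1, 2, 3, 3)
# Original task row ids, e.g. after subsetting a task (hypothetical)
row_ids = c(5L, 12L, 40L)

# Rows per sequential id, in sorted id order
reps = table(long_id)
# Repeat each original row id once per interval row
mapped = rep(row_ids, times = as.integer(reps))
print(mapped)
```

Unlike the prediction path, the number of rows per id can differ here, which is why `times =` is used rather than `each =`.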
@@ -161,22 +182,37 @@ PipeOpTaskSurvClassifDiscTime = R6Class("PipeOpTaskSurvClassifDiscTime",
# update form
form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")

new_data = pammtools::as_ped(data, formula = form, cut = cut)
new_data = as.data.table(new_data)

ped_status = id = NULL # fixing global binding notes of data.table
new_data[, ped_status := 0]
new_data[new_data[, .I[.N], by = id]$V1, ped_status := status]
new_data$ped_status = factor(new_data$ped_status, levels = c("0", "1"))
long_data = as.data.table(pammtools::as_ped(data, formula = form, cut = cut))
setnames(long_data, old = "ped_status", new = "disc_status")

disc_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
long_data[, disc_status := 0]
# set correct id
rows_per_id = nrow(long_data) / length(unique(long_data$id))
long_data$obs_times = rep(time, each = rows_per_id)
ids = rep(task$row_ids, each = rows_per_id)
long_data[, id := ids]

# set correct disc_status
reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
status = rep(status, times = reps)
long_data[long_data[, .I[tend >= obs_times], by = id]$V1, disc_status := status]
long_data$disc_status = factor(long_data$disc_status, levels = c("0", "1"))

# remove some columns from `long_data`
long_data[, c("offset", "tstart", "interval", "obs_times") := NULL]
task_disc = TaskClassif$new(paste0(task$id, "_disc"), long_data,
target = "disc_status", positive = "1")
task_disc$set_col_roles("id", roles = "original_ids")

# remove offset, tstart, interval for dataframe long_data
new_data[, c("offset", "tstart", "interval") := NULL]
task_disc = TaskClassif$new(paste0(task$id, "_disc"), new_data,
target = "ped_status", positive = "1")
task_disc$set_col_roles("id", roles = "name")
# map observed times back
reps = table(long_data$id)
long_data$obs_times = rep(time, each = rows_per_id)
# subset transformed data
columns_to_keep = c("id", "obs_times", "tend", "disc_status")
long_data = long_data[, columns_to_keep, with = FALSE]

new_data$time2 = rep(time, each = sum(new_data$id == 1))
list(task_disc, new_data)
list(task_disc, long_data)
}
)
)
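Review note: in the prediction hunk, every interval whose end time `tend` is at or after the observed time receives the observation's status, and all earlier intervals stay at 0. The base R sketch below is a restatement of that rule, not the package code. The cut points, times, statuses and ids are made up, and it assumes every id is expanded over the same cut points.

```r
# Hypothetical test data: 2 observations evaluated on the same 3 cut points
cut    = c(10, 20, 30)
time   = c(15, 30)   # observed times
status = c(1, 0)     # 1 = event, 0 = censored
ids    = c(7L, 9L)   # original row ids

long = data.frame(
  id        = rep(ids, each = length(cut)),
  tend      = rep(cut, times = length(ids)),
  obs_times = rep(time, each = length(cut))
)

# Intervals ending at or after the observed time inherit the status;
# earlier intervals are event-free.
hit = long$tend >= long$obs_times
long$disc_status = 0
long$disc_status[hit] = rep(status, each = length(cut))[hit]
long$disc_status = factor(long$disc_status, levels = c("0", "1"))
print(long)
```

For the censored observation (id 9) all rows remain "0", since the assigned status is itself 0.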
2 changes: 1 addition & 1 deletion R/TaskSurv.R
@@ -378,7 +378,7 @@ TaskSurv = R6::R6Class("TaskSurv",

#' @description
#' Checks if the data satisfy the *proportional hazards (PH)* assumption using
#' the Grambsch-Therneau test, `r mlr3misc::cite_bib("grambsch_1994")`.
#' the Grambsch-Therneau test, `r cite_bib("grambsch_1994")`.
#' Uses [cox.zph][survival::cox.zph()].
#' This method should be used only for **low-dimensional datasets** where
#' the number of features is relatively small compared to the number of
1 change: 1 addition & 0 deletions R/aaa.R
@@ -51,6 +51,7 @@ register_reflections = function() {

x$task_col_roles$surv = x$task_col_roles$regr
x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum")
x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time
x$task_properties$surv = x$task_properties$regr
x$task_properties$dens = x$task_properties$regr

2 changes: 1 addition & 1 deletion R/pipelines.R
@@ -598,7 +598,7 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL,

if (!is.null(rhs)) {
gr$edges = gr$edges[-1, ]
gr$add_pipeop(mlr3pipelines::po("modelmatrix", formula = mlr3misc::formulate(rhs = rhs, quote = "left")))
gr$add_pipeop(mlr3pipelines::po("modelmatrix", formula = formulate(rhs = rhs, quote = "left")))
gr$add_edge(src_id = "trafotask_survclassif_disctime", dst_id = "modelmatrix", src_channel = "output")
gr$add_edge(src_id = "modelmatrix", dst_id = learner$id, src_channel = "output", dst_channel = "input")
}
26 changes: 26 additions & 0 deletions R/zzz.R
@@ -63,6 +63,7 @@ utils::globalVariables(c(
setHook(event, hooks[pkgname != "mlr3proba"], action = "replace")

# unregister
unregister_reflections()
walk(names(mlr3proba_learners), function(nm) mlr_learners$remove(nm))
walk(names(mlr3proba_tasks), function(nm) mlr_tasks$remove(nm))
walk(names(mlr3proba_measures), function(nm) mlr_measures$remove(nm))
@@ -75,4 +76,29 @@ utils::globalVariables(c(
library.dynam.unload("mlr3proba", libpath)
}

unregister_reflections = function() {
x = utils::getFromNamespace("mlr_reflections", ns = "mlr3")

# task
package = NULL # silence data.table notes
x$task_types = x$task_types[package != "mlr3proba"] # assign the result; a bare subset is discarded
x$task_col_roles$surv = NULL
x$task_col_roles$dens = NULL
x$task_col_roles$classif = setdiff(x$task_col_roles$classif, "original_ids")
x$task_properties$surv = NULL
x$task_properties$dens = NULL

# learner
x$learner_properties$surv = NULL
x$learner_properties$dens = NULL
x$learner_predict_types$surv = NULL
x$learner_predict_types$dens = NULL

# measure
x$measure_properties$surv = NULL
x$measure_properties$dens = NULL
x$default_measures$surv = NULL
x$default_measures$dens = NULL
}

leanify_package()
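Review note: the `"original_ids"` column role is added with `unique(c(...))` on load (R/aaa.R) and removed with `setdiff()` on unload (R/zzz.R). The sketch below shows why that pairing is symmetric and safe to repeat. It uses a plain environment as a stand-in for mlr3's `mlr_reflections`; the initial role vector and the helper names `register_role` and `unregister_role` are assumptions made for illustration.

```r
# Stand-in for mlr_reflections: a plain environment with hypothetical content
reflections = new.env()
reflections$task_col_roles = list(classif = c("feature", "target", "name"))

# Hypothetical helper: unique() keeps registration idempotent on repeated loads
register_role = function(x, role) {
  x$task_col_roles$classif = unique(c(x$task_col_roles$classif, role))
  invisible(x)
}

# Hypothetical helper: setdiff() removes only the role that was added
unregister_role = function(x, role) {
  x$task_col_roles$classif = setdiff(x$task_col_roles$classif, role)
  invisible(x)
}

before = reflections$task_col_roles$classif
register_role(reflections, "original_ids")
register_role(reflections, "original_ids") # second load adds nothing
after_load = reflections$task_col_roles$classif
unregister_role(reflections, "original_ids")
after_unload = reflections$task_col_roles$classif
```

Because environments have reference semantics, the helpers modify the shared object in place, which is the same reason the package functions do not need to return `x`.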
11 changes: 5 additions & 6 deletions README.Rmd
@@ -12,16 +12,15 @@ knitr::opts_chunk$set(

# mlr3proba

Package website: [release](https://mlr3proba.mlr-org.com/)

Probabilistic Supervised Learning for **[mlr3](https://github.com/mlr-org/mlr3/)**.
Probabilistic Supervised Learning for **[mlr3](https://github.com/mlr-org/mlr3/)** ([website](https://mlr3proba.mlr-org.com/)).

<!-- badges: start -->
[![r-cmd-check](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml/badge.svg)](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml)
[![R-CMD-check](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml/badge.svg)](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml)
[![runiverse](https://mlr-org.r-universe.dev/badges/mlr3proba)](https://mlr-org.r-universe.dev/mlr3proba)
[![GitHub Discussions](https://img.shields.io/github/discussions/mlr-org/mlr3proba?logo=github&label=Discussions%20Q%26A&color=FFE600)](https://github.com/mlr-org/mlr3proba/discussions)
[![Article](https://img.shields.io/badge/Article-10.1093%2Fbioinformatics%2Fbtab039-brightgreen)](https://doi.org/10.1093/bioinformatics/btab039)
[![StackOverflow](https://img.shields.io/badge/stackoverflow-mlr3-orange.svg)](https://stackoverflow.com/questions/tagged/mlr3)
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
[![StackOverflow](https://img.shields.io/badge/stackoverflow-mlr3-orange.svg?color=pink)](https://stackoverflow.com/questions/tagged/mlr3)
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg?color=pink)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
<!-- badges: end -->

## What is mlr3proba?
12 changes: 6 additions & 6 deletions README.md
@@ -1,19 +1,19 @@

# mlr3proba

Package website: [release](https://mlr3proba.mlr-org.com/)

Probabilistic Supervised Learning for
**[mlr3](https://github.com/mlr-org/mlr3/)**.
**[mlr3](https://github.com/mlr-org/mlr3/)**
([website](https://mlr3proba.mlr-org.com/)).

<!-- badges: start -->

[![r-cmd-check](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml/badge.svg)](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml)
[![R-CMD-check](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml/badge.svg)](https://github.com/mlr-org/mlr3proba/actions/workflows/r-cmd-check.yml)
[![runiverse](https://mlr-org.r-universe.dev/badges/mlr3proba)](https://mlr-org.r-universe.dev/mlr3proba)
[![GitHub
Discussions](https://img.shields.io/github/discussions/mlr-org/mlr3proba?logo=github&label=Discussions%20Q%26A&color=FFE600)](https://github.com/mlr-org/mlr3proba/discussions)
[![Article](https://img.shields.io/badge/Article-10.1093%2Fbioinformatics%2Fbtab039-brightgreen)](https://doi.org/10.1093/bioinformatics/btab039)
[![StackOverflow](https://img.shields.io/badge/stackoverflow-mlr3-orange.svg)](https://stackoverflow.com/questions/tagged/mlr3)
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
[![StackOverflow](https://img.shields.io/badge/stackoverflow-mlr3-orange.svg?color=pink)](https://stackoverflow.com/questions/tagged/mlr3)
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg?color=pink)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
<!-- badges: end -->

## What is mlr3proba?
2 changes: 1 addition & 1 deletion inst/testthat/helper_expectations.R
@@ -21,7 +21,7 @@ expect_task_surv = function(task) {

f = task$formula()
expect_formula(f)
expect_set_equal(mlr3misc::extract_vars(f)$lhs, task$target_names)
expect_setequal(extract_vars(f)$lhs, task$target_names)
expect_class(task$kaplan(), "survfit")
}

12 changes: 12 additions & 0 deletions man/mlr_pipeops_trafopred_classifsurv_disctime.Rd

Some generated files are not rendered by default.