Skip to content

Commit

Permalink
add reverse methods
Browse files Browse the repository at this point in the history
  • Loading branch information
RaphaelS1 committed Dec 7, 2023
1 parent 083e685 commit 803a86d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.4
Version: 0.5.5
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# mlr3proba 0.5.5

* Add `$reverse()` method to `TaskSurv`, which returns the same task but with 1-status.
* Add `reverse` parameter to `TaskSurv$kaplan()` method, which calculates Kaplan-Meier on the censoring distribution of the task (1-status).

# mlr3proba 0.5.4

* Fix bottlenecks in Dcalib and RCLL
Expand Down
31 changes: 27 additions & 4 deletions R/TaskSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,19 @@ TaskSurv = R6::R6Class("TaskSurv",
#'
#' @param rhs
#' If `NULL` RHS is `.`, otherwise gives RHS of formula.
#' @param reverse
#' If `TRUE` then formula calculated with 1 - status.
#'
#' @return `numeric()`.
formula = function(rhs = NULL) {
formula = function(rhs = NULL, reverse = FALSE) {
# formula appends the rhs argument to Surv(time, event)~
tn = self$target_names
if (length(tn) == 2) {
lhs = sprintf("Surv(%s, %s, type = '%s')", tn[1L], tn[2L], self$censtype)
if (reverse) {
lhs = sprintf("Surv(%s, 1 - %s, type = '%s')", tn[1L], tn[2L], self$censtype)
} else {
lhs = sprintf("Surv(%s, %s, type = '%s')", tn[1L], tn[2L], self$censtype)
}
} else {
lhs = sprintf("Surv(%s, %s, %s, type = '%s')", tn[1L], tn[2L], tn[3L], self$censtype)
}
Expand Down Expand Up @@ -203,15 +209,32 @@ TaskSurv = R6::R6Class("TaskSurv",
#' Stratification variables to use.
#' @param rows (`integer()`)\cr
#' Subset of row indices.
#' @param reverse (`logical()`)\cr
#' If `TRUE` calculates Kaplan-Meier of censoring distribution (1-status). Default `FALSE`.
#' @param ... (any)\cr
#' Additional arguments passed down to [survival::survfit.formula()].
#' @return [survival::survfit.object].
kaplan = function(strata = NULL, rows = NULL, ...) {
kaplan = function(strata = NULL, rows = NULL, reverse = FALSE, ...) {
assert_character(strata, null.ok = TRUE)
f = self$formula(strata %??% 1)
f = self$formula(strata %??% 1, reverse)
cols = c(self$target_names, intersect(self$backend$colnames, strata))
data = self$data(cols = cols, rows = rows)
survival::survfit(f, data = data, ...)
},

#' @description
#' Returns the same task with the status variable reversed, i.e., 1 - status.
#' Only designed for left and right censoring.
#'
#' @return [mlr3proba::TaskSurv].
reverse = function() {
assert(self$censtype %in% c("left", "right"))
d = copy(self$data())
d[, (self$target_names[2L]) := 1 - get(self$target_names[2L])]
as_task_surv(d, self$target_names[1L],
self$target_names[2L],
type = self$censtype, id = paste0(self$id, "_reverse")
)
}
),

Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_TaskSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,15 @@ test_that("as_task_surv", {
expect_false("litter" %in% names(t1$data()))
expect_false("litter" %in% names(t2$data()))
})

test_that("reverse", {
t = tsk("rats")
expect_equal(t$kaplan()$surv,
survival::survfit(Surv(time, status) ~ 1, t$data())$surv)
expect_equal(t$kaplan(reverse = TRUE)$surv,
survival::survfit(Surv(time, 1 - status) ~ 1, t$data())$surv)

t2 = tsk("rats")$reverse()
expect_equal(t$kaplan(reverse = TRUE)$surv, t2$kaplan()$surv)
expect_equal(t2$data()$status, 1 - t$data()$status)
})

0 comments on commit 803a86d

Please sign in to comment.