Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix measure bottlenecks #337

Merged
merged 17 commits into from
Nov 20, 2023
Merged
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.3
Version: 0.5.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.8.3),
distr6 (>= 1.8.4),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.5.4

* Fix bottlenecks in Dcalib and RCLL

# mlr3proba 0.5.3

* Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist`
Expand Down
55 changes: 38 additions & 17 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@
#' @templateVar fullname MeasureSurvDCalibration
#'
#' @description
#' This calibration method is defined by calculating
#' This calibration method is defined by calculating the following statistic:
#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions,
#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval
#' [0, 1/B), [1/B, 2/B), ..., [(B-1)/B, 1).
#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion
#' of observations in the \eqn{i}th interval. An observation is assigned to the
#' \eqn{i}th bucket, if its predicted survival probability at the time of event
#' falls within the corresponding interval.
#' This statistic assumes that censoring time is independent of death time.
#'
#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test`
#' (`p > 0.05` if well-calibrated).
#' Model `i` is better calibrated than model `j` if `s_i < s_j`.
#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated).
#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)}.
#'
#' @details
#' This measure can either return the test statistic or the p-value from the `chisq.test`.
#' The former is useful for model comparison whereas the latter is useful for determining if a model
#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually
#' is well-calibrated. If `chisq = FALSE` and `m` is the predicted value then you can manually
#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`.
#'
#' NOTE: This measure is still experimental both theoretically and in implementation. Results
Expand Down Expand Up @@ -62,18 +65,36 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
private = list(
.score = function(prediction, ...) {
ps = self$param_set$values
B = ps$B

# initialize buckets
bj = numeric(ps$B)
bj = numeric(B)
true_times = prediction$truth[, 1L]

# predict individual probability of death at observed event time
if (inherits(prediction$distr, "VectorDistribution")) {
si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L)))
# bypass distr6 construction if possible
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times,
cdf = t(1 - surv), FALSE, FALSE))
} else {
si = diag(prediction$distr$survival(prediction$truth[, 1L]))
distr = prediction$distr
if (inherits(distr, c("Matdist", "Arrdist"))) {
si = diag(distr$survival(true_times))
} else { # VectorDistribution or single Distribution, e.g. WeightDisc()
si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L)))
}
}
Comment on lines +99 to 105
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is from using lrn('surv.parametric') and filtering to one observation, p$distr is a WeibullAFT() which is not a VectorDistribution, so diag(distr$survival(true_times)) returned 0 and later NaN in the score

# remove zeros
si = map_dbl(si, function(.x) max(.x, 1e-5))
# index of associated bucket
js = ceiling(ps$B * si)
js = ceiling(B * si)

# could remove loop for dead observations but needed for censored ones and minimal overhead
# in combining both
Expand All @@ -83,18 +104,18 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
# dead observations contribute 1 to their index
bj[ji] = bj[ji] + 1
} else {
# uncensored observations spread across buckets with most weighting on penultimate
# censored observations spread across buckets with most weighting on penultimate
for (k in seq.int(ji - 1)) {
bj[k] = bj[k] + 1 / (ps$B * si[[i]])
bj[k] = bj[k] + 1 / (B * si[[i]])
}
bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]]))
bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]]))
}
}

if (ps$chisq) {
return(stats::chisq.test(bj)$p.value)
} else {
return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2))
return((B / length(si)) * sum((bj - length(si) / B)^2))
}
}
)
Expand Down
56 changes: 46 additions & 10 deletions R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
distr = prediction$distr

if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
# Bypass distr6 construction if underlying distr represented by array
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

if (any(!event)) {
if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
cdf = as.matrix(1 - surv[!event, ])
} else {
cdf = t(1 - surv[!event, ])
}

out[!event] = diag(
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE)
)
}
if (any(event)) {
pdf = distr6:::cdfpdf(1 - surv)
if (sum(event) == 1) { # fix subsetting issue in case of 1 event
pdf = as.matrix(pdf[event, ])
} else {
pdf = t(pdf[event, ])
}

out[event] = diag(
distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf)
)
}
} else {
distr = prediction$distr

# Splitting in this way bypasses unnecessary distr extraction
if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out = diag(as.matrix(distr$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out = diag(as.matrix(distr$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
}
}

stopifnot(!any(out == -99L)) # safety check
Expand Down
21 changes: 16 additions & 5 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,23 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
}

if (!is.null(pdata$distr)) {
if (inherits(pdata$distr, "matrix")) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
} else { # array
pdata$distr = pdata$distr[keep, , , drop = FALSE]
distr = pdata$distr

if (testDistribution(distr)) { # distribution
ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) &&
length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE
if (ok) {
pdata$distr = distr[keep] # we can subset row/samples like this
} else {
pdata$distr = base::switch(keep, distr) # one distribution only
}
} else {
if (length(dim(distr)) == 2) { # 2d matrix
pdata$distr = distr[keep, , drop = FALSE]
} else { # 3d array
pdata$distr = distr[keep, , , drop = FALSE]
}
}

}

pdata
Expand Down
4 changes: 3 additions & 1 deletion R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv",
}
},
.distrify_survarray = function(x) {
  # Convert a survival matrix/array to a distr6 distribution object.
  # `x` is a 2d (observations x times) survival matrix or a 3d survival
  # array; `inherits(x, "array")` is TRUE for both, hence the comment below.
  # Guard on nrow(x) > 0: filtering can leave zero observations, in which
  # case no distribution can be constructed and NULL is returned so that
  # `$distr` is simply absent rather than erroring.
  if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well
    # create Matdist or Arrdist (default => median curve)
    distr6::as.Distribution(1 - x, fun = "cdf",
      decorators = c("CoreStatistics", "ExoticStatistics"))
  } else {
    NULL
  }
}
)
Expand Down
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) {
"response", "distr", "lp", "crank"))
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
if ("distr" %in% p$predict_types && !is.null(p$distr)) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete"))
}
expect_true(inherits(p, "PredictionSurv"))
}
19 changes: 11 additions & 8 deletions man/mlr_measures_surv.dcalib.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 48 additions & 2 deletions tests/testthat/test_PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,21 +176,67 @@ test_that("as_prediction_surv", {
})

test_that("filtering", {
  # Four prediction representations of the same model fit:
  # matrix, 3d array, Matdist and Arrdist distr objects.
  p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) # survival matrix
  p2 = reshape_distr_to_3d(p) # survival array
  p3 = p$clone()
  p4 = p2$clone()
  p3$data$distr = p3$distr # Matdist
  p4$data$distr = p4$distr # Arrdist

  p$filter(c(20, 37, 42))
  p2$filter(c(20, 37, 42))
  p3$filter(c(20, 37, 42))
  p4$filter(c(20, 37, 42))
  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)

  expect_set_equal(p$data$row_ids, c(20, 37, 42))
  expect_set_equal(p2$data$row_ids, c(20, 37, 42))
  expect_set_equal(p3$data$row_ids, c(20, 37, 42))
  expect_set_equal(p4$data$row_ids, c(20, 37, 42))
  expect_numeric(p$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p2$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p3$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p4$data$crank, any.missing = FALSE, len = 3)
  expect_numeric(p$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p2$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p3$data$lp, any.missing = FALSE, len = 3)
  expect_numeric(p4$data$lp, any.missing = FALSE, len = 3)
  expect_matrix(p$data$distr, nrows = 3)
  expect_array(p2$data$distr, d = 3)
  expect_equal(nrow(p2$data$distr), 3)
  expect_true(inherits(p3$data$distr, "Matdist"))
  expect_true(inherits(p4$data$distr, "Arrdist"))

  # edge case: filter to 1 observation
  p$filter(20)
  p2$filter(20)
  p3$filter(20)
  p4$filter(20)
  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)
  expect_matrix(p$data$distr, nrows = 1)
  expect_array(p2$data$distr, d = 3)
  expect_equal(nrow(p2$data$distr), 1)
  expect_true(inherits(p3$data$distr, "WeightedDiscrete")) # from Matdist!
  expect_true(inherits(p4$data$distr, "Arrdist")) # remains an Arrdist!

  # filter to 0 observations using non-existent (positive) id
  p$filter(42)
  p2$filter(42)
  p3$filter(42)
  p4$filter(42)

  expect_prediction_surv(p)
  expect_prediction_surv(p2)
  expect_prediction_surv(p3)
  expect_prediction_surv(p4)
  expect_null(p$distr)
  expect_null(p2$distr)
  expect_null(p3$distr)
  expect_null(p4$distr)
})
Loading
Loading