Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix measure bottlenecks #337

Merged
merged 17 commits into from
Nov 20, 2023
Merged
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.3
Version: 0.5.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.8.3),
distr6 (>= 1.8.4),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3proba 0.5.4

* Fix bottlenecks in Dcalib and RCLL

# mlr3proba 0.5.3

* Add support for learners that can predict multiple posterior distributions by using `distr6::Arrdist`
Expand Down
83 changes: 58 additions & 25 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@
#' @templateVar fullname MeasureSurvDCalibration
#'
#' @description
#' This calibration method is defined by calculating
#' This calibration method is defined by calculating the following statistic:
#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
#' where \eqn{B} is number of 'buckets', \eqn{n} is the number of predictions,
#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval
#' [0, 100/B), [100/B, 50/B),....,[(B - 100)/B, 1).
#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
#' \eqn{n} is the number of predictions, and \eqn{P_i} is the observed proportion
#' of observations in the \eqn{i}th interval. An observation is assigned to the
#' \eqn{i}th bucket, if its predicted survival probability at the time of event
#' falls within the corresponding interval.
#' This statistic assumes that censoring time is independent of death time.
#'
#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test`
#' (`p > 0.05` if well-calibrated).
#' Model `i` is better calibrated than model `j` if `s_i < s_j`.
#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated).
#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)},
#' meaning that *lower values* of this measure are preferred.
#'
#' @details
#' This measure can either return the test statistic or the p-value from the `chisq.test`.
#' The former is useful for model comparison whereas the latter is useful for determining if a model
#' is well-calibration. If `chisq = FALSE` and `m` is the predicted value then you can manually
#' compute the p.value with `pchisq(m, B - 1, lower.tail = FALSE)`.
#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually
#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`.
#'
#' NOTE: This measure is still experimental both theoretically and in implementation. Results
#' should therefore only be taken as an indicator of performance and not for
Expand All @@ -34,18 +38,29 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @param B (`integer(1)`) \cr
#' Number of buckets to test for uniform predictions over. Default of `10` is recommended by
#' Haider et al. (2020).
#' Number of buckets to test for uniform predictions over.
#' Default of `10` is recommended by Haider et al. (2020).
#' Changing this parameter affects `truncate`.
#' @param chisq (`logical(1)`) \cr
#' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure.
#' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`.
#' `p > 0.05` indicates well-calibrated.
#' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure.
#' Default is `FALSE` and returns the statistic `s`.
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
#' `p > 0.05` indicates a well-calibrated model.
#' @param truncate (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
#' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10}
#' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets.
#' Values \eqn{>10} translate to even lower p-values and thus less calibrated
#' models. If the number of buckets \eqn{B} changes, you probably will want to
#' change the `truncate` value as well to correspond to the same p-value significance.
#' Initialize with `truncate = Inf` if no truncation is desired.
initialize = function() {
ps = ps(
B = p_int(1, default = 10),
chisq = p_lgl(default = FALSE)
chisq = p_lgl(default = FALSE),
truncate = p_dbl(lower = 0, upper = Inf, default = 10)
)
ps$values = list(B = 10L, chisq = FALSE)
ps$values = list(B = 10L, chisq = FALSE, truncate = 10)

super$initialize(
id = "surv.dcalib",
Expand All @@ -62,18 +77,36 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
private = list(
.score = function(prediction, ...) {
ps = self$param_set$values
B = ps$B

# initialize buckets
bj = numeric(ps$B)
bj = numeric(B)
true_times = prediction$truth[, 1L]

# predict individual probability of death at observed event time
if (inherits(prediction$distr, "VectorDistribution")) {
si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L)))
# bypass distr6 construction if possible
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

si = diag(distr6:::C_Vec_WeightedDiscreteCdf(true_times, times,
cdf = t(1 - surv), FALSE, FALSE))
} else {
si = diag(prediction$distr$survival(prediction$truth[, 1L]))
distr = prediction$distr
if (inherits(distr, c("Matdist", "Arrdist"))) {
si = diag(distr$survival(true_times))
} else { # VectorDistribution or single Distribution, e.g. WeightDisc()
si = as.numeric(distr$survival(data = matrix(true_times, nrow = 1L)))
}
}
Comment on lines +99 to 105
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is from using lrn('surv.parametric') and filtering to one observation, p$distr is a WeibullAFT() which is not a VectorDistribution, so diag(distr$survival(true_times)) returned 0 and later NaN in the score

# remove zeros
si = map_dbl(si, function(.x) max(.x, 1e-5))
# index of associated bucket
js = ceiling(ps$B * si)
js = ceiling(B * si)

# could remove loop for dead observations but needed for censored ones and minimal overhead
# in combining both
Expand All @@ -83,18 +116,18 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
# dead observations contribute 1 to their index
bj[ji] = bj[ji] + 1
} else {
# uncensored observations spread across buckets with most weighting on penultimate
# censored observations spread across buckets with most weighting on penultimate
for (k in seq.int(ji - 1)) {
bj[k] = bj[k] + 1 / (ps$B * si[[i]])
bj[k] = bj[k] + 1 / (B * si[[i]])
}
bj[ji] = bj[ji] + (1 - (ji - 1) / (ps$B * si[[i]]))
bj[ji] = bj[ji] + (1 - (ji - 1) / (B * si[[i]]))
}
}

if (ps$chisq) {
return(stats::chisq.test(bj)$p.value)
} else {
return((ps$B / length(si)) * sum((bj - length(si) / ps$B)^2))
return(min(ps$truncate, (B / length(si)) * sum((bj - length(si) / B)^2)))
}
}
)
Expand Down
56 changes: 46 additions & 10 deletions R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,53 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
distr = prediction$distr

if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
# Bypass distr6 construction if underlying distr represented by array
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

if (any(!event)) {
if (sum(!event) == 1) { # fix subsetting issue in case of 1 censored
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
cdf = as.matrix(1 - surv[!event, ])
} else {
cdf = t(1 - surv[!event, ])
}

out[!event] = diag(
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
distr6:::C_Vec_WeightedDiscreteCdf(cens_times, times, cdf = cdf, FALSE, FALSE)
)
}
if (any(event)) {
pdf = distr6:::cdfpdf(1 - surv)
if (sum(event) == 1) { # fix subsetting issue in case of 1 event
pdf = as.matrix(pdf[event, ])
} else {
pdf = t(pdf[event, ])
}

out[event] = diag(
distr6:::C_Vec_WeightedDiscretePdf(event_times, times, pdf = pdf)
)
}
} else {
distr = prediction$distr

# Splitting in this way bypasses unnecessary distr extraction
if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out = diag(as.matrix(distr$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out = diag(as.matrix(distr$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
}
}

stopifnot(!any(out == -99L)) # safety check
Expand Down
21 changes: 16 additions & 5 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,23 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
}

if (!is.null(pdata$distr)) {
if (inherits(pdata$distr, "matrix")) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
} else { # array
pdata$distr = pdata$distr[keep, , , drop = FALSE]
distr = pdata$distr

if (testDistribution(distr)) { # distribution
ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) &&
length(keep) > 1 # e.g.: Arrdist(1xYxZ) and keep = FALSE
if (ok) {
pdata$distr = distr[keep] # we can subset row/samples like this
} else {
pdata$distr = base::switch(keep, distr) # one distribution only
}
} else {
if (length(dim(distr)) == 2) { # 2d matrix
pdata$distr = distr[keep, , drop = FALSE]
} else { # 3d array
pdata$distr = distr[keep, , , drop = FALSE]
}
}

}

pdata
Expand Down
4 changes: 3 additions & 1 deletion R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv",
}
},
.distrify_survarray = function(x) {
if (inherits(x, "array")) { # can be matrix as well
if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well
# create Matdist or Arrdist (default => median curve)
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
} else {
NULL
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
}
}
)
Expand Down
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) {
"response", "distr", "lp", "crank"))
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
if ("distr" %in% p$predict_types && !is.null(p$distr)) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete"))
}
expect_true(inherits(p, "PredictionSurv"))
}
43 changes: 29 additions & 14 deletions man/mlr_measures_surv.dcalib.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading