Connect discSurv as a reduction pipeline #194
You discretize the follow-up into intervals and create pseudo data with a new "status" variable that is 0 if the subject survived the interval and 1 if the subject experienced the event in that interval. Then fit any classifier to the pseudo status variable (with the interval as a covariate). The event probability in each interval is the discrete-time hazard, from which you can also calculate survival probabilities. If the grid is fine enough, you can approximate any survival time distribution reasonably well.
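A minimal base-R sketch of that reduction (the toy subjects, interval grid, and variable names here are made up purely for illustration):
## two toy subjects: one with an event at t = 5, one censored at t = 7
subjects <- data.frame(id = 1:2, time = c(5, 7), status = c(1, 0), x = c(0.3, -1.2))
breaks <- c(0, 2, 4, 6, 8)  # interval grid (0,2], (2,4], (4,6], (6,8]
## expand to one row per interval in which the subject is still at risk
pseudo <- do.call(rbind, lapply(seq_len(nrow(subjects)), function(i) {
  s <- subjects[i, ]
  ends <- breaks[-1][breaks[-length(breaks)] < s$time]  # intervals entered while at risk
  data.frame(id = s$id, interval_end = ends, x = s$x,
             ped_status = as.integer(s$status == 1 & ends == max(ends)))
}))
## fit any classifier to the pseudo status, with the interval as a covariate
fit <- glm(ped_status ~ interval_end, data = pseudo, family = binomial())
## predicted event probability per interval = discrete-time hazard;
## survival probability = cumulative product of (1 - hazard)
h <- predict(fit, newdata = pseudo[pseudo$id == 2, ], type = "response")
S <- cumprod(1 - h)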
@RaphaelS1 Looks legit, however,
It's better to create a new dataset with all intervals for each id (the covariates of each id remain constant across all intervals), then predict the hazard for each interval and calculate the survival probability accordingly.
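Continuing the toy sketch above (same hypothetical subjects, breaks, and fit objects), the all-intervals prediction data and the hazard-to-survival conversion could look like this:
## prediction data: every id gets a row for every interval, covariates held constant
grid <- merge(subjects[, c("id", "x")], data.frame(interval_end = breaks[-1]))
grid <- grid[order(grid$id, grid$interval_end), ]
## predicted discrete-time hazard per interval
grid$hazard <- predict(fit, newdata = grid, type = "response")
## per-id survival curve: S(t_j) = prod over k <= j of (1 - hazard_k)
grid$surv <- ave(1 - grid$hazard, grid$id, FUN = cumprod)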
https://www.springer.com/gp/book/9783319281568
Suggestions for a better default?
Not sure I understand this, could you provide a code/pseudo-code/math example?
Here is a proof of concept (the programming is terrible). Note that I use pammtools (as_ped) for the data transformation rather than discSurv.
library(mlr3)
library(mlr3proba)
library(discSurv)
library(mlr3learners)
library(survival)
library(pammtools)
discSurv_redux <- function(task, cut = NULL, learner_id = "classif.ranger") {
out <- list()
data <- as.data.frame(task$data())
time_var <- task$target_names[[1]]
status_var <- task$target_names[[2]]
features <- task$feature_names
## convert data to discrete time (long format: one row per subject and interval at risk)
ped_form <- as.formula(paste0("Surv(", time_var, ", ", status_var, ") ~ ."))
longData <- as_ped(data = data, formula = ped_form, cut = cut)
## keep interval end point, discrete status and the original features
long_data_sim <- longData[, c("tend", "ped_status", features)]
long_data_sim$ped_status <- factor(long_data_sim$ped_status)
## create classif task on the pseudo data
classif_task <- TaskClassif$new("disc", long_data_sim, target = "ped_status")
## prediction data: every subject gets a row for every interval
data2 <- data
data2[[time_var]] <- max(data[[time_var]])
new_data <- as_ped(longData, newdata = data2)
new_data <- new_data[, c("id", "tend", "ped_status", features)]
new_data$ped_status <- factor(new_data$ped_status, levels = levels(long_data_sim$ped_status))
## train classifier and predict the discrete-time hazard (probability of class "1") per interval
p <- lrn(learner_id, predict_type = "prob")$train(classif_task)$predict_newdata(new_data)
pred <- cbind(new_data, pred = p$prob[, "1"])
## convert hazards to survival: S(t_j) = prod over k <= j of (1 - h(t_k))
surv <- t(sapply(unique(pred$id), function(i) {
  cumprod(1 - pred[pred$id == i, "pred"])
}))
times <- sort(unique(new_data$tend))
## coerce to distribution and crank
r <- .surv_return(times, surv = surv)
## create prediction object
p <- PredictionSurv$new(
  row_ids = seq(nrow(data)),
  crank = r$crank, distr = r$distr,
  truth = Surv(data[[time_var]], data[[status_var]]))
## evaluate with Harrell's C and the integrated Graf score
out$H_C <- as.numeric(p$score())
out$IGS <- as.numeric(p$score(msr("surv.graf", proper = TRUE)))
out
}
set.seed(18452505)
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 10), "classif.featureless")
# $H_C
# [1] 0.5
# $IGS
# [1] 0.05894565
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 100), "classif.featureless")
# $H_C
# [1] 0.5
# $IGS
# [1] 0.05894565
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 10), "classif.ranger")
# $H_C
# [1] 0.9465666
# $IGS
# [1] 0.07711993
discSurv_redux(tsk("rats"), cut = seq(0, max(rats$time), length.out = 100), "classif.ranger")
# $H_C
# [1] 0.9563922
# $IGS
# [1] 0.0356116
discSurv_redux(tsk("rats"), cut = NULL, "classif.ranger") # cuts at unique event times
# $H_C
# [1] 0.9568337
# $IGS
# [1] 0.04004635
Results make sense to me, although we might think about what "featureless" means in this context. The way it is used now, you estimate one constant baseline hazard. When you include the variable that indicates the intervals, you get a piece-wise constant baseline hazard (still without traditional features).
Maybe the number of unique event times (or the square root of that number, if it is large).
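One possible reading of that suggestion, sketched on the rats data used above (the threshold of 100 and the quantile-based grid are my own assumptions):
library(survival)  # for the rats data
event_times <- sort(unique(rats$time[rats$status == 1]))
n_cut <- if (length(event_times) > 100) ceiling(sqrt(length(event_times))) else length(event_times)
## either the event times themselves or a coarser quantile-based grid
cut <- if (n_cut == length(event_times)) {
  event_times
} else {
  unique(quantile(rats$time, probs = seq(0, 1, length.out = n_cut)))
}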
Note: I'm not aware of cases where bins/intervals overlap; see the example below for the usual workflow. For each subject, only the bins/intervals in which the subject was at risk are included. If the subject was still alive (censored) in the last interval in which they were at risk, all pseudo status values (ped_status) are 0; if they experienced the event, ped_status is 1 in that last interval and 0 in all earlier ones.
It's definitely more sensible to estimate a more flexible baseline hazard, but we would have to think about how this model would be called from within mlr3.
library(pammtools)
#>
#> Attaching package: 'pammtools'
#> The following object is masked from 'package:stats':
#>
#> filter
library(mgcv)
#> Loading required package: nlme
#> This is mgcv 1.8-33. For overview type 'help("mgcv-package")'.
library(ggplot2)
theme_set(theme_bw())
library(survival)
set.seed(128)
data <- tumor[sample.int(nrow(tumor), 300, replace = FALSE), ]
data$id <- seq_len(nrow(data))
data <- data[, c("id", "days", "status", "age", "sex", "complications")]
data[c(1,3), ]
#> # A tibble: 2 x 6
#> id days status age sex complications
#> <int> <dbl> <int> <int> <fct> <fct>
#> 1 1 1402 1 64 male no
#> 2 3 645 0 45 male yes
### Data preparation
# discretize follow up: use unique event times as cut points
discretized_data <- data %>% as_ped(Surv(days, status)~., cut = NULL)
nrow(discretized_data)
#> [1] 25513
# show first and last two observations for subjects 1 and 3
discretized_data %>%
filter(id %in% c(1, 3)) %>%
group_by(id) %>%
slice(1:2, (dplyr::n()-1):dplyr::n())
#> # A tibble: 8 x 9
#> # Groups: id [2]
#> id tstart tend interval offset ped_status age sex complications
#> <int> <dbl> <dbl> <fct> <dbl> <dbl> <int> <fct> <fct>
#> 1 1 0 2 (0,2] 0.693 0 64 male no
#> 2 1 2 3 (2,3] 0 0 64 male no
#> 3 1 1383 1393 (1383,1393] 2.30 0 64 male no
#> 4 1 1393 1402 (1393,1402] 2.20 1 64 male no
#> 5 3 0 2 (0,2] 0.693 0 45 male yes
#> 6 3 2 3 (2,3] 0 0 45 male yes
#> 7 3 586 613 (586,613] 3.30 0 45 male yes
#> 8 3 613 646 (613,646] 3.47 0 45 male yes
length(unique(discretized_data$interval))
#> [1] 142
# -> subject 1 experienced event at 1402 -> ped_status = 1 in this interval, otherwise 0
# -> subject 3 was censored at 645 -> last interval in risk set is (613, 646] -> ped_status = 0 always
### Model estimation
# constant baseline hazard
m0 <- glm(ped_status ~ 1, data = discretized_data, family = binomial())
# interval specific baseline hazard
# not so good, hazards volatile, many parameters to estimate
m1 <- glm(ped_status ~ interval, data = discretized_data, family = binomial())
# better: penalize differences between neighboring hazards
# tend is a representation of time in the j-th interval, here interval end point
m2 <- gam(ped_status ~ s(tend), data = discretized_data, family = binomial())
### Visualization
ndf <- discretized_data %>% make_newdata(tend = unique(tend))
head(ndf)
#> tstart tend intlen interval id offset ped_status age sex
#> 1 0 2 2 (0,2] 151.1807 0.6931472 0 61.06683 male
#> 2 2 3 1 (2,3] 151.1807 0.0000000 0 61.06683 male
#> 3 3 5 2 (3,5] 151.1807 0.6931472 0 61.06683 male
#> 4 5 6 1 (5,6] 151.1807 0.0000000 0 61.06683 male
#> 5 6 7 1 (6,7] 151.1807 0.0000000 0 61.06683 male
#> 6 7 8 1 (7,8] 151.1807 0.0000000 0 61.06683 male
#> complications
#> 1 no
#> 2 no
#> 3 no
#> 4 no
#> 5 no
#> 6 no
ndf$hazard0 <- predict(m0, newdata = ndf, type = "response")
ndf$hazard1 <- predict(m1, newdata = ndf, type = "response")
ndf$hazard2 <- predict(m2, newdata = ndf, type = "response")
ggplot(ndf, aes(x = tend)) +
geom_step(aes(y = hazard0, col = "m0")) +
geom_step(aes(y = hazard1, col = "m1")) +
geom_step(aes(y = hazard2, col = "m2"))
### Survival Probability
ndf <- ndf %>%
mutate(
surv0 = cumprod(1 - hazard0),
surv1 = cumprod(1 - hazard1),
surv2 = cumprod(1 - hazard2),
)
# cox for comparison
cox <- coxph(Surv(days, status)~1, data = data)
bh <- basehaz(cox)
bh$surv_cox <- as.numeric(exp(-bh$hazard))
ggplot(ndf, aes(x = tend)) +
geom_step(aes(y = surv0, col = "m0"))+
geom_step(aes(y = surv1, col = "m1"))+
geom_step(aes(y = surv2, col = "m2")) +
geom_step(data = bh, aes(x=time, y = surv_cox, col = "cox")) +
ylim(c(0, 1)) # cox, m1 and m2 basically identical
## Stratified baseline hazard
cox_strata <- coxph(Surv(days, status)~strata(complications), data = data)
bh_strata <- basehaz(cox_strata)
bh_strata <- bh_strata %>%
group_by(strata) %>%
mutate(surv = exp(-hazard))
discrete_strata <- gam(ped_status ~ complications + s(tend, by = complications),
data = discretized_data, family = binomial())
ndf <- discretized_data %>%
make_newdata(tend = unique(tend), complications = unique(complications))
ndf$hazard_strata <- predict(discrete_strata, ndf, type = "response")
ndf <- ndf %>%
group_by(complications) %>%
mutate(surv_strata = cumprod(1 - hazard_strata))
ggplot(ndf, aes(x = tend)) +
geom_step(aes(lty = complications, y = surv_strata, col = "discrete")) +
geom_line(data = bh_strata, aes(x = time, y = surv, lty = strata, col = "cox"))
Created on 2021-06-01 by the reprex package (v0.3.0)
This is great, let's use pammtools rather than discSurv directly in the reduction.
Especially with the examples you show above, it's very promising!
So, the available reductions / the ones we have implemented:
Comparing these would make a very nice paper.
Ideally, we would extract some of the functionality regarding the data transformation into a separate package; otherwise you get a lot of unnecessary dependencies. The data trafo in pammtools is itself a wrapper around
Unless I'm missing something, I think Zhong/Tibshirani is discrete time-to-event analysis (transformation of the survival task into a binomial likelihood optimization/classification task), just framed differently, so there is no difference to discSurv/pammtools; the latter are just implementations of the data trafo before applying some classification algorithm. For pammtools this is essentially a by-product: originally the data trafo was intended for piece-wise exponential models (i.e. transformation of the survival task into a Poisson likelihood optimization task), but the data trafo for discrete time-to-event and piece-wise exponential models is essentially the same.
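A minimal sketch of that equivalence, assuming the tumor data shipped with pammtools: the same as_ped() output can feed either a discrete-time (binomial) hazard model or a piece-wise exponential (Poisson) model.
library(pammtools)
library(mgcv)
library(survival)
ped <- as_ped(tumor, Surv(days, status) ~ age + complications)
## discrete-time hazard: binomial likelihood on the pseudo status (classification view)
m_discrete <- gam(ped_status ~ s(tend) + age + complications,
                  data = ped, family = binomial())
## piece-wise exponential model: Poisson likelihood with log-exposure offset
m_pem <- gam(ped_status ~ s(tend) + age + complications,
             data = ped, family = poisson(), offset = offset)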
This could go into a different package, but we still want the interface to work with mlr easily; otherwise it's still a lot of glue code.
You're the expert here, but my understanding is that their method is different because of how the predictions are made (not necessarily in the transformation) and is therefore sufficiently different that it's fair to include it in a comparison. So in discSurv we would fit a single model at all unique times with T as a covariate, whereas in the stacking approach it's one model per unique event time? Obviously these may come to the same thing, but analytically it's still a different approach.
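To make the contrast concrete, here is a hedged base-R sketch of the "one model per unique event time" reading, using the rats data; whether this matches what Zhong/Tibshirani actually do is exactly the open question above, so treat it only as an illustration of the distinction being drawn:
library(survival)  # for the rats data
event_times <- sort(unique(rats$time[rats$status == 1]))
## one classifier per unique event time, each fitted on the risk set at that time
per_time_models <- lapply(event_times, function(t_j) {
  at_risk <- rats[rats$time >= t_j, ]                                  # risk set at t_j
  at_risk$y <- as.integer(at_risk$time == t_j & at_risk$status == 1)   # event exactly at t_j
  glm(y ~ rx, data = at_risk, family = binomial())
})
length(per_time_models)  # as many fitted hazard models as unique event times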
What do we need to do for this reduction to be fully integrated in mlr3?
Finished in #391
This will require some careful thought as initial exploration (below) indicates sub-optimal results. We may want to consider discrete-survival representations, as well as discrete hazards, or think about sampling methods to solve the imbalance problem. (@adibender)
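On the imbalance point, a hedged sketch of one sampling-based option (not an implemented solution; the sampling rate and the reweighting scheme are my own assumptions), again using pammtools' tumor data:
library(pammtools)
library(mgcv)
library(survival)
set.seed(1)
ped <- as_ped(tumor, Surv(days, status) ~ age + complications)
## the pseudo data is dominated by non-event rows: keep all event rows,
## subsample the non-event rows and reweight them to keep the likelihood comparable
rate  <- 0.1
ones  <- which(ped$ped_status == 1)
zeros <- which(ped$ped_status == 0)
ped_sub <- ped[c(ones, sample(zeros, ceiling(rate * length(zeros)))), ]
ped_sub$w <- ifelse(ped_sub$ped_status == 1, 1, 1 / rate)
m <- gam(ped_status ~ s(tend) + age + complications,
         data = ped_sub, family = binomial(), weights = w)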