Skip to content

Commit

Permalink
Merge pull request #411 from mlr-org/surv-assert
Browse files Browse the repository at this point in the history
perf: optimize Rcpp
  • Loading branch information
bblodfon authored Sep 6, 2024
2 parents 737498e + 31b0a5b commit 0e5c80b
Show file tree
Hide file tree
Showing 13 changed files with 195 additions and 174 deletions.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ indent_size = 2
indent_size = 4

[*.{cpp,hpp}]
indent_size = 4
indent_size = 2

[{NEWS.md,DESCRIPTION,LICENSE}]
max_line_length = 80
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.7
Version: 0.6.8
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# mlr3proba 0.6.8

- `Rcpp` code optimizations
- Fixed ERV scoring to comply with the `mlr3` dev version (compatibility change only; there was no bug in earlier versions)
- Skip the `survtoregr` pipelines for now due to bugs (they will be refactored in the future)

# mlr3proba 0.6.7

- Deprecate `crank` to `distr` composition in the `distrcompose` pipeop (only the `lp` => `distr` composition works now)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerDens.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ LearnerDens = R6::R6Class("LearnerDens",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(),
predict_types = "cdf", feature_types = character(),
properties = character(), data_formats = "data.table",
properties = character(),
packages = character(),
label = NA_character_,
man = NA_character_) {
super$initialize(
id = id, task_type = "dens", param_set = param_set,
predict_types = predict_types, feature_types = feature_types, properties = properties,
data_formats = data_formats, packages = c("mlr3proba", packages), label = label, man = man)
packages = c("mlr3proba", packages), label = label, man = man)
}
)
)
16 changes: 9 additions & 7 deletions R/scoring_rule_erv.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
if (!is.null(ps$se) && ps$se) {
stop("Only one of `ERV` and `se` can be TRUE")
}
measure$param_set$values$ERV = FALSE

measure$param_set$set_values(ERV = FALSE)
# compute score for the learner
learner_score = measure$score(prediction, task = task,
train_set = train_set)
learner_score = measure$score(prediction, task = task, train_set = train_set)

# compute score for the baseline (Kaplan-Meier)
# train KM
km = lrn("surv.kaplan")$train(task = task, row_ids = train_set)
km = km$predict(as_task_surv(data.frame(
as.matrix(prediction$truth)), event = "status"))
base_score = measure$score(km, task = task, train_set = train_set)
# predict KM on the test set (= not train ids)
test_set = setdiff(task$row_ids, train_set)
km_pred = km$predict(task, row_ids = test_set)
base_score = measure$score(km_pred, task = task, train_set = train_set)

measure$param_set$values$ERV = TRUE
measure$param_set$set_values(ERV = TRUE)

# return percentage decrease
1 - (learner_score / base_score)
Expand Down
48 changes: 24 additions & 24 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// c_assert_surv
bool c_assert_surv(NumericMatrix mat);
bool c_assert_surv(const NumericMatrix& mat);
RcppExport SEXP _mlr3proba_c_assert_surv(SEXP matSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix >::type mat(matSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type mat(matSEXP);
rcpp_result_gen = Rcpp::wrap(c_assert_surv(mat));
return rcpp_result_gen;
END_RCPP
Expand All @@ -34,74 +34,74 @@ BEGIN_RCPP
END_RCPP
}
// c_score_intslogloss
NumericMatrix c_score_intslogloss(NumericVector truth, NumericVector unique_times, NumericMatrix cdf, double eps);
NumericMatrix c_score_intslogloss(const NumericVector& truth, const NumericVector& unique_times, const NumericMatrix& cdf, double eps);
RcppExport SEXP _mlr3proba_c_score_intslogloss(SEXP truthSEXP, SEXP unique_timesSEXP, SEXP cdfSEXP, SEXP epsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
rcpp_result_gen = Rcpp::wrap(c_score_intslogloss(truth, unique_times, cdf, eps));
return rcpp_result_gen;
END_RCPP
}
// c_score_graf_schmid
NumericMatrix c_score_graf_schmid(NumericVector truth, NumericVector unique_times, NumericMatrix cdf, int power);
NumericMatrix c_score_graf_schmid(const NumericVector& truth, const NumericVector& unique_times, const NumericMatrix& cdf, int power);
RcppExport SEXP _mlr3proba_c_score_graf_schmid(SEXP truthSEXP, SEXP unique_timesSEXP, SEXP cdfSEXP, SEXP powerSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cdf(cdfSEXP);
Rcpp::traits::input_parameter< int >::type power(powerSEXP);
rcpp_result_gen = Rcpp::wrap(c_score_graf_schmid(truth, unique_times, cdf, power));
return rcpp_result_gen;
END_RCPP
}
// c_weight_survival_score
NumericMatrix c_weight_survival_score(NumericMatrix score, NumericMatrix truth, NumericVector unique_times, NumericMatrix cens, bool proper, double eps);
NumericMatrix c_weight_survival_score(const NumericMatrix& score, const NumericMatrix& truth, const NumericVector& unique_times, const NumericMatrix& cens, bool proper, double eps);
RcppExport SEXP _mlr3proba_c_weight_survival_score(SEXP scoreSEXP, SEXP truthSEXP, SEXP unique_timesSEXP, SEXP censSEXP, SEXP properSEXP, SEXP epsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix >::type score(scoreSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type truth(truthSEXP);
Rcpp::traits::input_parameter< NumericVector >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cens(censSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type score(scoreSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type truth(truthSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type unique_times(unique_timesSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cens(censSEXP);
Rcpp::traits::input_parameter< bool >::type proper(properSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
rcpp_result_gen = Rcpp::wrap(c_weight_survival_score(score, truth, unique_times, cens, proper, eps));
return rcpp_result_gen;
END_RCPP
}
// c_concordance
float c_concordance(NumericVector time, NumericVector status, NumericVector crank, double t_max, std::string weight_meth, NumericMatrix cens, NumericMatrix surv, float tiex);
float c_concordance(const NumericVector& time, const NumericVector& status, const NumericVector& crank, double t_max, const std::string& weight_meth, const NumericMatrix& cens, const NumericMatrix& surv, float tiex);
RcppExport SEXP _mlr3proba_c_concordance(SEXP timeSEXP, SEXP statusSEXP, SEXP crankSEXP, SEXP t_maxSEXP, SEXP weight_methSEXP, SEXP censSEXP, SEXP survSEXP, SEXP tiexSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type time(timeSEXP);
Rcpp::traits::input_parameter< NumericVector >::type status(statusSEXP);
Rcpp::traits::input_parameter< NumericVector >::type crank(crankSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type time(timeSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type status(statusSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type crank(crankSEXP);
Rcpp::traits::input_parameter< double >::type t_max(t_maxSEXP);
Rcpp::traits::input_parameter< std::string >::type weight_meth(weight_methSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type cens(censSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type surv(survSEXP);
Rcpp::traits::input_parameter< const std::string& >::type weight_meth(weight_methSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type cens(censSEXP);
Rcpp::traits::input_parameter< const NumericMatrix& >::type surv(survSEXP);
Rcpp::traits::input_parameter< float >::type tiex(tiexSEXP);
rcpp_result_gen = Rcpp::wrap(c_concordance(time, status, crank, t_max, weight_meth, cens, surv, tiex));
return rcpp_result_gen;
END_RCPP
}
// c_gonen
double c_gonen(NumericVector crank, float tiex);
double c_gonen(const NumericVector& crank, float tiex);
RcppExport SEXP _mlr3proba_c_gonen(SEXP crankSEXP, SEXP tiexSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type crank(crankSEXP);
Rcpp::traits::input_parameter< const NumericVector& >::type crank(crankSEXP);
Rcpp::traits::input_parameter< float >::type tiex(tiexSEXP);
rcpp_result_gen = Rcpp::wrap(c_gonen(crank, tiex));
return rcpp_result_gen;
Expand Down
5 changes: 2 additions & 3 deletions src/survival_assert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
using namespace Rcpp;

// [[Rcpp::export]]
bool c_assert_surv(NumericMatrix mat) {
bool c_assert_surv(const NumericMatrix& mat) {
for (int i = 0; i < mat.nrow(); i++) {
// check first element
if (mat(i, 0) < 0 || mat(i, 0) > 1) {
// check first element
return false;
}

Expand All @@ -24,4 +24,3 @@ bool c_assert_surv(NumericMatrix mat) {

return true;
}

Loading

0 comments on commit 0e5c80b

Please sign in to comment.