Skip to content

Commit

Permalink
Merge pull request #834 from mlr-org/po_subsample_use_groups
Browse files Browse the repository at this point in the history
New param `use_groups` for `PipeOpSubsample` and rework for `task_filter_ex()`
  • Loading branch information
mb706 authored Nov 26, 2024
2 parents 9d4cced + e14d020 commit 40f7a6f
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 43 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* New parameter `no_collapse_above_absolute` in `PipeOpCollapseFactors` / `po("collapse_factors")`.
* Fix: `PipeOpCollapseFactors` now correctly collapses levels of ordered factors.
* Fix: `LearnerClassifAvg` and `LearnerRegrAvg` hyperparameters get the `"required"` tag.
* New parameter `use_groups` (default `TRUE`) for `PipeOpSubsample` / `po("subsample")` to respect grouping (changed default behaviour for grouped data).

# mlr3pipelines 0.7.1

Expand Down
69 changes: 57 additions & 12 deletions R/PipeOpSubsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
#' * `frac` :: `numeric(1)`\cr
#' Fraction of rows in the [`Task`][mlr3::Task] to keep. May only be greater than 1 if `replace` is `TRUE`. Initialized to `(1 - exp(-1)) == 0.6321`.
#' * `stratify` :: `logical(1)`\cr
#' Should the subsamples be stratified by target? Initialized to `FALSE`. May only be `TRUE` for [`TaskClassif`][mlr3::TaskClassif] input.
#' Should the subsamples be stratified by target? Initialized to `FALSE`. May only be `TRUE` for [`TaskClassif`][mlr3::TaskClassif] input and if `use_groups = FALSE`.
#' * `use_groups` :: `logical(1)`\cr
#' If `TRUE` and if the [`Task`][mlr3::Task] has a column with role `group`, grouped observations are kept together during subsampling. In case of sampling with
#' `replace = TRUE`, the group entry of duplicate samples is suffixed (`_1`, `_2`, ...). May only be `TRUE` if `stratify = FALSE`. Initialized to `TRUE`.
#' * `replace` :: `logical(1)`\cr
#' Sample with replacement? Initialized to `FALSE`.
#'
Expand All @@ -52,9 +55,22 @@
#' @examples
#' library("mlr3")
#'
#' pos = mlr_pipeops$get("subsample", param_vals = list(frac = 0.7, stratify = TRUE))
#' # Subsample with stratification
#' pop = po("subsample", frac = 0.7, stratify = TRUE, use_groups = FALSE)
#' pop$train(list(tsk("iris")))
#'
#' pos$train(list(tsk("iris")))
#' # Subsample, respecting grouping
#' df = data.frame(
#' target = runif(3000),
#' x1 = runif(3000),
#' x2 = runif(3000),
#' grp = sample(paste0("g", 1:100), 3000, replace = TRUE)
#' )
#' task = TaskRegr$new(id = "example", backend = df, target = "target")
#' task$set_col_roles("grp", "group")
#'
#' pop = po("subsample", frac = 0.7, use_groups = TRUE)
#' pop$train(list(task))
#'
#' @family PipeOps
#' @template seealso_pipeopslist
Expand All @@ -67,30 +83,59 @@ PipeOpSubsample = R6Class("PipeOpSubsample",
ps = ps(
frac = p_dbl(lower = 0, upper = Inf, tags = "train"),
stratify = p_lgl(tags = "train"),
use_groups = p_lgl(tags = "train"),
replace = p_lgl(tags = "train")
)
ps$values = list(frac = 1 - exp(-1), stratify = FALSE, replace = FALSE)
ps$values = list(frac = 1 - exp(-1), stratify = FALSE, use_groups = TRUE, replace = FALSE)
super$initialize(id, param_set = ps, param_vals = param_vals, can_subset_cols = FALSE)
}
),
private = list(

.train_task = function(task) {
if (!self$param_set$values$stratify) {
keep = shuffle(task$row_roles$use,
ceiling(self$param_set$values$frac * task$nrow),
replace = self$param_set$values$replace)
} else {
pv = self$param_set$get_values(tags = "train")

if (pv$frac > 1 && pv$replace == FALSE) {
stop("Can't subsample task up to a fraction larger than 1 if parameter 'replace' is FALSE")
}

if (pv$stratify && pv$use_groups) {
stop("Cannot combine stratification with grouping")
} else if (pv$use_groups && !is.null(task$groups)) {
# task$groups automatically removes rows not in task$row_roles$use and allows rows to be included multiple times
grp_sizes = table(task$groups$group)
ngrps = length(grp_sizes)
nrows = task$nrow

if (pv$replace) {
# Draw groups until the fraction of sampled rows is no longer below the desired fraction.
# We then keep all up to the entry for which the fraction of sampled rows is closest to the desired fraction.
shuffled = numeric(0)
while (sum(shuffled) / nrows < pv$frac) {
shuffled = c(shuffled, shuffle(grp_sizes, ceiling(max(1, pv$frac) * ngrps), replace = TRUE))
}
cutoff_index = which.min(abs(cumsum(shuffled) / nrows - pv$frac))
keep_grps = names(shuffled[seq_len(cutoff_index)])
} else {
# We randomly shuffle the groups and keep all up to the group for
# which the fraction of sampled rows is closest to the desired fraction.
shuffled = shuffle(grp_sizes)
cutoff_index = which.min(abs(cumsum(shuffled) / nrows - pv$frac))
keep_grps = names(shuffled[seq_len(cutoff_index)])
}
keep = task$groups[list(keep_grps), on = "group", allow.cartesian = TRUE]$row_id
} else if (pv$stratify) {
if (!inherits(task, "TaskClassif")) {
stopf("Stratification not supported for %s", class(task))
}
splt = split(task$row_roles$use, task$data(cols = task$target_names))
keep = unlist(map(splt, function(x) {
shuffle(x,
ceiling(self$param_set$values$frac * length(x)),
replace = self$param_set$values$replace)
shuffle(x, ceiling(pv$frac * length(x)), replace = pv$replace)
}))
} else {
keep = shuffle(task$row_roles$use, ceiling(pv$frac * task$nrow), replace = pv$replace)
}

self$state = list()
task_filter_ex(task, keep)
},
Expand Down
72 changes: 61 additions & 11 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,76 @@ calculate_collimit = function(colwidths, outwidth) {
collimit - 3 # subtracting 3 here because data.table adds "..." whenever it truncates a string
}

# same as task$filter(), but allows duplicate row IDs
# Same as task$filter(), but allows duplicate row IDs.
# Handles duplicate rows in tasks with col_role "group" by renaming groups
# following the pattern grp_name_1, grp_name_2, ... for each duplication of a group.
# @param task [Task] the task
# @param row_ids [numeric] the row IDs to select
# @return [Task] the modified task
task_filter_ex = function(task, row_ids) {
# Get vector of duplicate row IDs
dup_ids = row_ids[duplicated(row_ids)]
# Generate vector with new row IDs
newrows = task$nrow + seq_along(dup_ids)
# Get all columns of task for subsetting using task$data()
cols = unique(unlist(task$col_roles, use.names = FALSE))

# Rbind duplicated rows to task
if (length(dup_ids)) {
# First, get a data.table with all duplicated rows.
new_data = task$data(rows = dup_ids, cols = cols)

# Second, if task has a column with role "group", create new groups for duplicate rows by adding a suffix to the group entry.
if (!is.null(task$groups)) {
group = NULL # for binding
row_id = NULL # for binding

row_counts = table(task$row_roles$use)
grps = unique(task$groups$group)

# by = "row_id" for faster computation since groups are implied by row_id in task$groups
new_groups = unique(task$groups, by = "row_id")[list(dup_ids), on = "row_id"]
new_groups[, group := {
# Number of how often the same group name should occur for this row ID
target_count = row_counts[as.character(row_id)]
# Get default group name target_count - 1 times since default group already exists in task once
groups = group[seq_len(target_count - 1)]
# Initialize suffix to be appended to group name if it is otherwise already taken
suffix = 0

while (length(groups) < .N) {
suffix = suffix + 1
new_group = paste0(group[[1]], "_", suffix)
# If new_group already exists, skip to the next iteration.
if (new_group %in% grps) {
next
}
# Otherwise, add new_group to groups
groups[length(groups) + seq_len(target_count)] = new_group
}
# This can happen if, for some ID, row_ids
# - has more occurrences of that ID than task$row_roles$use, and the count is not an exact multiple
# - has fewer occurrences of that ID than task$row_roles$use
if (length(groups) != .N) {
stopf("Called task_filter_ex() but constructed incomplete group '%s'. Try removing column with role 'group'.", group[[1]])
}
groups
}, by = row_id]

# Use "new_groups" to update the group entries.
new_data[, (task$col_roles$group) := new_groups$group]
}

addedrows = row_ids[duplicated(row_ids)]

newrows = task$nrow + seq_along(addedrows)

if (length(addedrows)) {
task$rbind(task$data(rows = addedrows))
# Lastly, new data is rbinded to the original task.
task$rbind(new_data)
}

# row ids can be anything, we just take what mlr3 happens to assign.
# row_ids can be anything, we just take what mlr3 happens to assign to filter the task.
row_ids[duplicated(row_ids)] = task$row_ids[newrows]

task$filter(row_ids)
# Update row_ids, effectively filtering the task
task$row_roles$use = row_ids
task
}

# these must be at the root and can not be anonymous functions because all.equal fails otherwise.
Expand All @@ -66,7 +118,6 @@ curry = function(fn, ..., varname = "x") {
}
}


# 'and' operator for checkmate check_*-functions
# example:
# check_numeric(x) %check&&% check_true(all(x < 0))
Expand All @@ -80,7 +131,6 @@ curry = function(fn, ..., varname = "x") {
TRUE
}


# perform gsub on names of list
# `...` are given to `gsub()`
rename_list = function(x, ...) {
Expand Down
23 changes: 19 additions & 4 deletions man/mlr_pipeops_subsample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 40f7a6f

Please sign in to comment.