Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PipeOpFilterRows #410

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ Collate:
'PipeOpEncodeLmer.R'
'PipeOpFeatureUnion.R'
'PipeOpFilter.R'
'PipeOpFilterRows.R'
'PipeOpFixFactors.R'
'PipeOpHistBin.R'
'PipeOpICA.R'
Expand All @@ -122,6 +123,7 @@ Collate:
'PipeOpMutate.R'
'PipeOpNOP.R'
'PipeOpPCA.R'
'PipeOpPredictionUnion.R'
'PipeOpProxy.R'
'PipeOpQuantileBin.R'
'PipeOpRegrAvg.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export(PipeOpEncodeLmer)
export(PipeOpEnsemble)
export(PipeOpFeatureUnion)
export(PipeOpFilter)
export(PipeOpFilterRows)
export(PipeOpFixFactors)
export(PipeOpHistBin)
export(PipeOpICA)
Expand All @@ -56,6 +57,7 @@ export(PipeOpModelMatrix)
export(PipeOpMutate)
export(PipeOpNOP)
export(PipeOpPCA)
export(PipeOpPredictionUnion)
export(PipeOpProxy)
export(PipeOpQuantileBin)
export(PipeOpRegrAvg)
Expand Down
171 changes: 171 additions & 0 deletions R/PipeOpFilterRows.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#' @title PipeOpFilterRows
#'
#' @usage NULL
#' @name mlr_pipeops_filterrows
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreproc`].
#'
#' @description
#' Filter rows of the data of a task. Also directly allows for the removal of rows holding missing
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' values.
#'
#' @section Construction:
#' ```
#' PipeOpFilterRows$new(id = "filterrows", param_vals = list())
#' ```
#'
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"filterrows"`.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise
#' be set during construction. Default `list()`.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`].
#'
#' The output during training is the input [`Task`][mlr3::Task] with rows kept according to the
#' filtering (see Parameters) and (possible) rows with missing values removed.
#'
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' The output during prediction is the unchanged input [`Task`][mlr3::Task] if the parameter
#' `skip_during_predict` is `TRUE`. Otherwise it is analogously handled as the output during
#' training.
#'
#' @section State:
#' The `$state` is a named `list` with the `$state` elements inherited from [`PipeOpTaskPreproc`],
#' as well as the following elements:
#' * `na_ids` :: `integer`\cr
#' The row identifiers that had missing values during training and therefore were removed. See the
#' parameter `na_column`.
#' * `row_ids` :: `integer`\cr
#' The row identifiers that were kept during training according to the parameters `filter`,
#' `na_column` and `invert`.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
#' * `filter` :: `NULL` | `character(1)` | `expression` | `integer`\cr
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' How the rows of the data of the input [`Task`][mlr3::Task] should be filtered. This can be a
#' character vector of length 1 indicating a feature column of logicals in the data of the input
#' [`Task`][mlr3::Task] which forms the basis of the filtering, i.e., all rows that are `TRUE`
#' with respect to this column are kept in the data of the output [`Task`][mlr3::Task]. Moreover,
#' this can be an expression that will result in a logical vector of length `$nrow` of the data of
#' the input [`Task`][mlr3::Task] when evaluated withing the environment of the `$data()` of the
#' input [`Task`][mlr3::Task]. Finally, this can also be an integerish vector that directly
#' specifies the row identifiers of the rows of the data of the input [`Task`][mlr3::Task] that
#' should be kept. Default is `NULL`, i.e., no filtering is done.
#' * `na_column` :: `NULL` | `character`\cr
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' A character vector that specifies the columns of the data of the input [`Task`][mlr3::Task]
#' that should be checked for missing values. If set to `all`, all columns of the data are used. A
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' row is removed if at least one missing value is found with respect to the columns specified.
#' Default is `NULL`, i.e., no removal of missing values is done.
#' * `invert` :: `logical(1)`\cr
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' Should the filtering rule be set-theoretically inverted? Note that this happens after
#' (possible) missing values were removed if `na_column` is specified. Default is `FALSE`.
#' * `skip_during_predict` :: `logical(1)`\cr
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' Should the filtering and missing value removal steps be skipped during prediction? If `TRUE`,
sumny marked this conversation as resolved.
Show resolved Hide resolved
#' the input [`Task`][mlr3::Task] is returned unaltered during prediction. Default is `FALSE`.
#'
#' @section Internals:
#' Uses the [`is.na()`][base::is.na] function for the checking of missing values.
#'
#' @section Methods:
#' Only methods inherited from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @examples
#' library("mlr3")
#' task = tsk("pima")
#' po = PipeOpFilterRows$new(param_vals = list(
#' filter = expression(age < median(age) & mass > 30),
#' na_column = "all")
#' )
#' po$train(list(task))
#' po$state
#' @family PipeOps
#' @include PipeOpTaskPreproc.R
#' @export
PipeOpFilterRows = R6Class("PipeOpFilterRows",
inherit = PipeOpTaskPreproc,
public = list(
initialize = function(id = "filterrows", param_vals = list()) {
ps = ParamSet$new(params = list(
ParamUty$new("filter", default = NULL, tags = c("train", "predict"), custom_check = function(x) {
ok = test_character(x, any.missing = FALSE, len = 1L) ||
is.expression(x) ||
test_integerish(x, lower = 1, min.len = 1L) ||
is.null(x)
if (!ok) return("Must either be a character vector of length 1, an expression, or an integerish object of row ids")
TRUE
}),
ParamUty$new("na_column", default = NULL, tags = c("train", "predict"), custom_check = function(x) {
check_character(x, any.missing = FALSE, min.len = 1L, null.ok = TRUE)
}),
ParamLgl$new("invert", default = FALSE, tags = c("train", "predict")),
ParamLgl$new("skip_during_predict", default = FALSE, tags = "predict"))
)
ps$values = list(filter = NULL, na_column = NULL, invert = FALSE, skip_during_predict = FALSE)
super$initialize(id, param_set = ps, param_vals = param_vals)
}
),
private = list(
.na_and_filter = function(task, skip, set_state) {
if (skip) {
sumny marked this conversation as resolved.
Show resolved Hide resolved
return(task) # early exit if skipped (if skip_during_predict)
}

row_ids = task$row_ids

# NA column(s) handling
na = self$param_set$values$na_column
if (!is.null(na)) {
assert_subset(na, choices = c("all", colnames(task$data())))
sumny marked this conversation as resolved.
Show resolved Hide resolved
if (na == "all") na = colnames(task$data())
na_ids = which(rowSums(is.na(task$data(cols = na))) > 0L)
sumny marked this conversation as resolved.
Show resolved Hide resolved
row_ids = setdiff(row_ids, na_ids)
} else {
na_ids = integer(0L)
}

# filtering
filter = self$param_set$values$filter
filter_ids =
if (is.null(filter)) {
row_ids
} else if (is.character(filter)) {
assert_subset(filter, choices = task$feature_names)
filter_column = task$data(cols = filter)[[1L]]
assert_logical(filter_column)
which(filter_column)
} else if(is.expression(filter)) {
filter_expression = eval(filter, envir = task$data())
assert_logical(filter_expression, len = task$nrow)
which(filter_expression)
} else {
filter = as.integer(filter)
assert_subset(filter, choices = task$row_ids)
filter
}

row_ids = if (self$param_set$values$invert) {
setdiff(row_ids, filter_ids)
} else {
intersect(row_ids, filter_ids)
}

# only set the state if required (during training)
if (set_state) {
self$state$na_ids = na_ids
sumny marked this conversation as resolved.
Show resolved Hide resolved
self$state$row_ids = row_ids
}

task$filter(row_ids)
},

.train_task = function(task) {
private$.na_and_filter(task, skip = FALSE, set_state = TRUE)
},

.predict_task = function(task) {
private$.na_and_filter(task, skip = self$param_set$values$skip_during_predict, set_state = FALSE)
}
)
)

mlr_pipeops$add("filterrows", PipeOpFilterRows)
136 changes: 136 additions & 0 deletions R/PipeOpPredictionUnion.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#' @title PipeOpPredictionUnion
sumny marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @usage NULL
#' @name mlr_pipeops_predictionunion
#' @format [`R6Class`] object inheriting from [`PipeOp`].
#'
#' @description
#' Unite predictions from all input predictions into a single
#' [`Prediction`][mlr3::Prediction].
#'
#' `task_type`s and `predict_types` must be equal across all input predictions.
#'
#' Note that predictions are combined as is, i.e., no checks for duplicated row
#' identifiers etc. are performed.
#'
#' Currently only supports task types `classif` and `regr` by constructing a new
#' [`PredictionClassif`][mlr3::PredictionClassif] and respectively
#' [`PredictionRegr`][mlr3::PredictionRegr].
#'
#' @section Construction:
#' ```
#' PipeOpPredictionUnion$new(innum = 0, id = "predictionunion", param_vals = list())
#' ```
#'
#' * `innum` :: `numeric(1)` | `character`\cr
#' Determines the number of input channels. If `innum` is 0 (default), a vararg input channel is
#' created that can take an arbitrary number of inputs. If `innum` is a `character` vector, the
#' number of input channels is the length of `innum`.
#' * `id` :: `character(1)`\cr
#' Identifier of the resulting object, default `"predictionunion"`.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise
#' be set during construction. Default `list()`.
#'
#' @section Input and Output Channels:
#' [`PipeOpPredictionUnion`] has multiple input channels depending on the `innum` construction
#' argument, named `"input1"`, `"input2"`, ... if `innum` is nonzero; if `innum` is 0, there is only
#' one *vararg* input channel named `"..."`. All input channels take `NULL` during training and a
#' [`Prediction`][mlr3::Prediction] during prediction.
#'
#' [`PipeOpPredictionUnion`] has one output channel named `"output"`, producing `NULL` during
#' training and a [`Prediction`][mlr3::Prediction] during prediction.
#'
#' The output during prediction is a [`Prediction`][mlr3::Prediction] constructed by combining all
#' input [`Prediction`][mlr3::Prediction]s.
#'
#' @section State:
#' The `$state` is left empty (`list()`).
#'
#' @section Parameters:
#' [`PipeOpPredictionUnion`] has no Parameters.
#'
#' @section Internals:
#' Only sets the fields `row_ids`, `truth`, `response` and if applicable `prob` and `se` during
#' construction of the output [`Prediction`][mlr3::Prediction].
#'
#' @section Fields:
#' Only fields inherited from [`PipeOp`].
#'
#' @section Methods:
#' Only methods inherited from [`PipeOp`].
#'
#' @family PipeOps
#' @include PipeOp.R
#' @export
#' @examples
#' library("mlr3")
#'
#' task = tsk("iris")
#' filter = expression(Sepal.Length < median(Sepal.Length))
#' gr = po("copy", outnum = 2) %>>% gunion(list(
#' po("filterrows", id = "filter1",
#' param_vals = list(filter = filter)) %>>%
#' lrn("classif.rpart", id = "learner1"),
#' po("filterrows", id = "filter2",
#' param_vals = list(filter = filter, invert = TRUE)) %>>%
#' lrn("classif.rpart", id = "learner2")
#' )) %>>% po("predictionunion")
#'
#' gr$train(task)
#' gr$predict(task)
PipeOpPredictionUnion = R6Class("PipeOpPredictionUnion",
inherit = PipeOp,
public = list(
initialize = function(innum = 0L, id = "predictionunion", param_vals = list()) {
assert(
check_int(innum, lower = 0L),
check_character(innum, min.len = 1L, any.missing = FALSE)
)
if (!is.numeric(innum)) {
innum = length(innum)
}
inname = if (innum) rep_suffix("input", innum) else "..."
super$initialize(id, param_vals = param_vals,
input = data.table(name = inname, train = "NULL", predict = "Prediction"),
output = data.table(name = "output", train = "NULL", predict = "Prediction"))
}
),
private = list(
.train = function(inputs) {
self$state = list()
list(NULL)
},
.predict = function(inputs) {
# currently only works for task_type "classif" or "regr"
check = all((unlist(map(inputs[-1L], .f = `[[`, "task_type")) == inputs[[1L]]$task_type) &
sumny marked this conversation as resolved.
Show resolved Hide resolved
unlist(map(inputs[-1L], .f = `[[`, "predict_types")) == inputs[[1L]]$predict_types)
if (!check) {
stopf("Can only unite predictions of the same task type and predict types.")
}

type = inputs[[1L]]$task_type
if (type %nin% c("classif", "regr")) {
stopf("Currently only supports task types `classif` and `regr`.")
}

row_ids = unlist(map(inputs, .f = `[[`, "row_ids"), use.names = FALSE)
truth = unlist(map(inputs, .f = `[[`, "truth"), use.names = FALSE)
response = unlist(map(inputs, .f = `[[`, "response"), use.names = FALSE)

prediction =
if(type == "classif") {
prob = do.call(rbind, map(inputs, .f = `[[`, "prob"))
PredictionClassif$new(row_ids = row_ids, truth = truth, response = response, prob = prob)
} else {
se = unlist(map(inputs, .f = `[[`, "se"), use.names = FALSE)
if (length(se) == 0L) se = NULL
PredictionRegr$new(row_ids = row_ids, truth = truth, response = response, se = se)
}

list(prediction)
}
)
)

mlr_pipeops$add("predictionunion", PipeOpPredictionUnion)
2 changes: 2 additions & 0 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/PipeOpEnsemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/PipeOpImpute.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading