Skip to content

Commit

Permalink
Merge pull request #837 from mlr-org/dict_sugar_suggests
Browse files Browse the repository at this point in the history
Use `.dicts_suggest` in `po.R` and `ppl.R`
  • Loading branch information
advieser authored Dec 1, 2024
2 parents 40f7a6f + 2184038 commit 086646b
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Imports:
digest,
lgr,
mlr3 (>= 0.20.0),
mlr3misc (>= 0.9.0),
mlr3misc (>= 0.16.0),
paradox,
R6,
withr
Expand Down
4 changes: 2 additions & 2 deletions R/po.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ po.PipeOp = function(.obj, ...) {

#' @export
po.character = function(.obj, ...) {
dictionary_sugar_inc_get(dict = mlr_pipeops, .key = .obj, ...)
dictionary_sugar_inc_get(dict = mlr_pipeops, .key = .obj, ..., .dicts_suggest = list("ppl()" = mlr_graphs))
}

#' @export
Expand Down Expand Up @@ -111,7 +111,7 @@ pos.NULL = function(.objs, ...) {

#' @export
pos.character = function(.objs, ...) {
dictionary_sugar_inc_mget(dict = mlr_pipeops, .keys = .objs, ...)
dictionary_sugar_inc_mget(dict = mlr_pipeops, .keys = .objs, ..., .dicts_suggest = list("ppls()" = mlr_graphs))
}

#' @export
Expand Down
4 changes: 2 additions & 2 deletions R/ppl.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
#' gr = ppl("bagging", graph = po(lrn("regr.rpart")),
#' averager = po("regravg", collect_multiplicity = TRUE))
ppl = function(.key, ...) {
dictionary_sugar_get(dict = mlr_graphs, .key = .key, ...)
dictionary_sugar_get(dict = mlr_graphs, .key = .key, ..., .dicts_suggest = list("po()" = mlr_pipeops))
}

#' @export
#' @rdname ppl
ppls = function(.keys, ...) {
if (missing(.keys)) return(mlr_graphs)
map(.x = .keys, .f = dictionary_sugar_get, dict = mlr_graphs, ...)
map(.x = .keys, .f = dictionary_sugar_get, dict = mlr_graphs, ..., .dicts_suggest = list("pos()" = mlr_pipeops))
}
8 changes: 8 additions & 0 deletions tests/testthat/test_po.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,11 @@ test_that("Incrementing ids works", {
xs = pos(c("pca_1", "pca_2"))
assert_true(all(names(xs) == c("pca_1", "pca_2")))
})

test_that("po - dictionary suggest works", {

# test that correct dictionary is checked against
expect_error(po("robustify"), "ppl\\(\\): 'robustify'")
expect_error(pos("robustify"), "ppls\\(\\): 'robustify'")

})
10 changes: 9 additions & 1 deletion tests/testthat/test_ppl.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ test_that("mlr_graphs access works", {
})


test_that("mlr_pipeops multi-access works", {
test_that("mlr_graphs multi-access works", {

expect_equal(
ppls("robustify"),
Expand Down Expand Up @@ -73,3 +73,11 @@ test_that("mlr3book authors don't sleepwalk through life", {
bmr = benchmark(benchmark_grid(tasks, learners, rsmp("cv", folds = 2)))

})

test_that("ppl - dictionary suggest works", {

# test that correct dictionary is checked against
expect_error(ppl("adas"), "po\\(\\): 'adas'")
expect_error(ppls("adas"), "pos\\(\\): 'adas'")

})

0 comments on commit 086646b

Please sign in to comment.