diff --git a/tests/testthat/test_partition.R b/tests/testthat/test_partition.R index 1bd99339a..4415bb29b 100644 --- a/tests/testthat/test_partition.R +++ b/tests/testthat/test_partition.R @@ -1,17 +1,18 @@ test_that("partition w/ stratification works", { - task = tsk("rats") - sets = partition(task) + with_seed(42, { + task = tsk("rats") + part = partition(task, ratio = 0.8) - ratio = function(status) { - tab = table(status) - unname(tab[1L] / tab[2L]) - } + ratio = function(status) { + tab = table(status) + unname(tab[1L] / tab[2L]) + } - all = ratio(task$status()) - train = ratio(task$status(sets$train)) - test = ratio(task$status(sets$test)) + all = task$cens_prop() + train = task$cens_prop(rows = part$train) + test = task$cens_prop(rows = part$test) - expect_numeric(all, lower = 6, upper = 6.2) - expect_numeric(train, lower = 6, upper = 6.2) - expect_numeric(test, lower = 6, upper = 6.2) + expect_equal(all, train, tolerance = 0.01) + expect_equal(all, test, tolerance = 0.01) + }) })