diff --git a/R/kfold-helpers.R b/R/kfold-helpers.R index ec2c4513..f29b509c 100644 --- a/R/kfold-helpers.R +++ b/R/kfold-helpers.R @@ -84,7 +84,12 @@ kfold_split_stratified <- function(K = 10, x = NULL) { N <- length(x) xids <- numeric() for (l in 1:Nlev) { - xids <- c(xids, sample(which(x==l))) + idx <- which(x == l) + if (length(idx) > 1) { + xids <- c(xids, sample(idx)) + } else { + xids <- c(xids, idx) + } } bins <- rep(NA, N) bins[xids] <- rep(1:K, ceiling(N/K))[1:N] diff --git a/tests/testthat/test_kfold_helpers.R b/tests/testthat/test_kfold_helpers.R index 2e1a63dd..c613cafa 100644 --- a/tests/testthat/test_kfold_helpers.R +++ b/tests/testthat/test_kfold_helpers.R @@ -22,6 +22,14 @@ test_that("kfold_split_stratified works", { y <- mtcars$cyl fold_strat <- kfold_split_stratified(10, y) expect_equal(range(table(fold_strat)), c(3, 4)) + + # test when a group has 1 observation + # https://github.com/stan-dev/loo/issues/277 + y <- rep(c(1, 2, 3), times = c(20, 40, 1)) + expect_silent(fold_strat <- kfold_split_stratified(5, y)) # used to be a warning before fixing issue #277 + tab <- table(fold_strat, y) + expect_equal(tab[1, ], c("1" = 4, "2" = 8, "3" = 1)) + for (i in 2:nrow(tab)) expect_equal(tab[i, ], c("1" = 4, "2" = 8, "3" = 0)) }) test_that("kfold_split_grouped works", {