diff --git a/R/helpers.R b/R/helpers.R index 58424f83..38b401dd 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -57,8 +57,8 @@ validate_ll <- function(x) { stop("List not allowed as input.") } else if (anyNA(x)) { stop("NAs not allowed in input.") - } else if (!all(is.finite(x))) { - stop("All input values must be finite.") + } else if (any(x == Inf)) { + stop("All input values must be finite or -Inf.") } invisible(x) } diff --git a/R/loo_moment_matching.R b/R/loo_moment_matching.R index 9ba55cad..b0c2cff5 100644 --- a/R/loo_moment_matching.R +++ b/R/loo_moment_matching.R @@ -386,7 +386,6 @@ loo_moment_match_i <- function(i, dim(log_liki) <- NULL } - # pointwise estimates elpd_loo_i <- matrixStats::logSumExp(log_liki + lwi) mcse_elpd_loo <- mcse_elpd( @@ -439,17 +438,18 @@ update_quantities_i <- function(x, upars, i, orig_log_prob, log_liki_new <- log_lik_i_upars(x, upars = upars, i = i, ...) # compute new log importance weights - is_obj_new <- suppressWarnings(importance_sampling.default(-log_liki_new + - log_prob_new - - orig_log_prob, + # If log_liki_new and log_prob_new both have same element as Inf, + # replace the log ratio with -Inf + lr <- -log_liki_new + log_prob_new - orig_log_prob + lr[is.na(lr)] <- -Inf + is_obj_new <- suppressWarnings(importance_sampling.default(lr, method = is_method, r_eff = r_eff_i, cores = 1)) lwi_new <- as.vector(weights(is_obj_new)) ki_new <- is_obj_new$diagnostics$pareto_k - is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new - - orig_log_prob, + is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new - orig_log_prob, method = is_method, r_eff = r_eff_i, cores = 1)) diff --git a/R/split_moment_matching.R b/R/split_moment_matching.R index a8ad8be9..b7bdb111 100644 --- a/R/split_moment_matching.R +++ b/R/split_moment_matching.R @@ -98,14 +98,21 @@ loo_moment_match_split <- function(x, upars, cov, total_shift, total_scaling, log1p(exp(log_prob_half_trans[!stable_S] - log_prob_half_trans_inv[!stable_S]))) - is_obj_half <- suppressWarnings(importance_sampling.default(lwi_half, + # lwi_half may have NaNs if computation involves -Inf + Inf + # replace NaN log ratios with -Inf + lr <- lwi_half + lr[is.na(lr)] <- -Inf + is_obj_half <- suppressWarnings(importance_sampling.default(lr, method = is_method, r_eff = r_eff_i, cores = cores)) lwi_half <- as.vector(weights(is_obj_half)) - is_obj_f_half <- suppressWarnings(importance_sampling.default(lwi_half + - log_liki_half, + # lwi_half may have NaNs if computation involves -Inf + Inf + # replace NaN log ratios with -Inf + lr <- lwi_half + log_liki_half + lr[is.na(lr)] <- -Inf + is_obj_f_half <- suppressWarnings(importance_sampling.default(lr, method = is_method, r_eff = r_eff_i, cores = cores)) diff --git a/tests/testthat/test_loo_moment_matching.R b/tests/testthat/test_loo_moment_matching.R index 16a9a426..ae67b20b 100644 --- a/tests/testthat/test_loo_moment_matching.R +++ b/tests/testthat/test_loo_moment_matching.R @@ -177,6 +177,11 @@ test_that("loo_moment_match.default warnings work", { test_that("loo_moment_match.default works", { + # allow -Inf + lwi_x <- lwi_1 + lwi_x[which.min(lwi_1)] <- -Inf + expect_no_error(suppressWarnings(importance_sampling.default(lwi_1, method = "psis", r_eff = 1, cores = 1))) + # loo object loo_manual <- suppressWarnings(loo(loglik)) @@ -288,7 +293,6 @@ test_that("loo_moment_match.default works with multiple cores", { }) - test_that("loo_moment_match_split works", { # skip on M1 Mac until we figure out why this test fails only on M1 Mac skip_if(Sys.info()[["sysname"]] == "Darwin" && R.version$arch == "aarch64") diff --git a/tests/testthat/test_psis.R b/tests/testthat/test_psis.R index 95a4a524..ef4b14cd 100644 --- a/tests/testthat/test_psis.R +++ b/tests/testthat/test_psis.R @@ -88,8 +88,11 @@ test_that("psis throws correct errors and warnings", { expect_error(psis(-LLmat), "NAs not allowed in input") LLmat[1,1] <- 1 + LLmat[10, 2] <- -Inf + expect_error(psis(-LLmat), "All input values must be finite or -Inf") + # log ratio of -Inf is allowed LLmat[10, 2] <- Inf - expect_error(psis(-LLmat), "All input values must be finite") + expect_no_error(psis(-LLmat)) # no lists allowed expect_error(expect_warning(psis(as.list(-LLvec))), "List not allowed as input") diff --git a/tests/testthat/test_tisis.R b/tests/testthat/test_tisis.R index bd80bdd2..5d693a37 100644 --- a/tests/testthat/test_tisis.R +++ b/tests/testthat/test_tisis.R @@ -121,8 +121,10 @@ test_that("tis throws correct errors and warnings", { expect_error(tis(-LLmat), "NAs not allowed in input") LLmat[1,1] <- 1 + LLmat[10, 2] <- -Inf + expect_error(tis(-LLmat), "All input values must be finite or -Inf") LLmat[10, 2] <- Inf - expect_error(tis(-LLmat), "All input values must be finite") + expect_no_error(tis(-LLmat)) # no lists allowed expect_error(expect_warning(tis(as.list(-LLvec)), "List not allowed as input"))