Skip to content

Commit

Permalink
Merge pull request #259 from stan-dev/fix_moment_matching_NA_issue
Browse files Browse the repository at this point in the history
fix loo_moment_matching NaN issue
  • Loading branch information
jgabry authored Mar 5, 2024
2 parents bb3abb2 + 5e4cf37 commit 52bc270
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
4 changes: 2 additions & 2 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ validate_ll <- function(x) {
stop("List not allowed as input.")
} else if (anyNA(x)) {
stop("NAs not allowed in input.")
} else if (!all(is.finite(x))) {
stop("All input values must be finite.")
} else if (any(x == Inf)) {
stop("All input values must be finite or -Inf.")
}
invisible(x)
}
Expand Down
12 changes: 6 additions & 6 deletions R/loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ loo_moment_match_i <- function(i,
dim(log_liki) <- NULL
}


# pointwise estimates
elpd_loo_i <- matrixStats::logSumExp(log_liki + lwi)
mcse_elpd_loo <- mcse_elpd(
Expand Down Expand Up @@ -439,17 +438,18 @@ update_quantities_i <- function(x, upars, i, orig_log_prob,
log_liki_new <- log_lik_i_upars(x, upars = upars, i = i, ...)
# compute new log importance weights

is_obj_new <- suppressWarnings(importance_sampling.default(-log_liki_new +
log_prob_new -
orig_log_prob,
# If log_liki_new and log_prob_new both have same element as Inf,
# replace the log ratio with -Inf
lr <- -log_liki_new + log_prob_new - orig_log_prob
lr[is.na(lr)] <- -Inf
is_obj_new <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = 1))
lwi_new <- as.vector(weights(is_obj_new))
ki_new <- is_obj_new$diagnostics$pareto_k

is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new -
orig_log_prob,
is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new - orig_log_prob,
method = is_method,
r_eff = r_eff_i,
cores = 1))
Expand Down
13 changes: 10 additions & 3 deletions R/split_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,21 @@ loo_moment_match_split <- function(x, upars, cov, total_shift, total_scaling,
log1p(exp(log_prob_half_trans[!stable_S] -
log_prob_half_trans_inv[!stable_S])))

is_obj_half <- suppressWarnings(importance_sampling.default(lwi_half,
# lwi_half may have NaNs if computation involves -Inf + Inf
# replace NaN log ratios with -Inf
lr <- lwi_half
lr[is.na(lr)] <- -Inf
is_obj_half <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = cores))
lwi_half <- as.vector(weights(is_obj_half))

is_obj_f_half <- suppressWarnings(importance_sampling.default(lwi_half +
log_liki_half,
# lwi_half may have NaNs if computation involves -Inf + Inf
# replace NaN log ratios with -Inf
lr <- lwi_half + log_liki_half
lr[is.na(lr)] <- -Inf
is_obj_f_half <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = cores))
Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/test_loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ test_that("loo_moment_match.default warnings work", {

test_that("loo_moment_match.default works", {

# allow -Inf
lwi_x <- lwi_1
lwi_x[which.min(lwi_1)] <- -Inf
expect_no_error(suppressWarnings(importance_sampling.default(lwi_1, method = "psis", r_eff = 1, cores = 1)))

# loo object
loo_manual <- suppressWarnings(loo(loglik))

Expand Down Expand Up @@ -288,7 +293,6 @@ test_that("loo_moment_match.default works with multiple cores", {
})



test_that("loo_moment_match_split works", {
# skip on M1 Mac until we figure out why this test fails only on M1 Mac
skip_if(Sys.info()[["sysname"]] == "Darwin" && R.version$arch == "aarch64")
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ test_that("psis throws correct errors and warnings", {
expect_error(psis(-LLmat), "NAs not allowed in input")

LLmat[1,1] <- 1
LLmat[10, 2] <- -Inf
expect_error(psis(-LLmat), "All input values must be finite or -Inf")
# log ratio of -Inf is allowed
LLmat[10, 2] <- Inf
expect_error(psis(-LLmat), "All input values must be finite")
expect_no_error(psis(-LLmat))

# no lists allowed
expect_error(expect_warning(psis(as.list(-LLvec))), "List not allowed as input")
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test_tisis.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ test_that("tis throws correct errors and warnings", {
expect_error(tis(-LLmat), "NAs not allowed in input")

LLmat[1,1] <- 1
LLmat[10, 2] <- -Inf
expect_error(tis(-LLmat), "All input values must be finite or -Inf")
LLmat[10, 2] <- Inf
expect_error(tis(-LLmat), "All input values must be finite")
expect_no_error(tis(-LLmat))

# no lists allowed
expect_error(expect_warning(tis(as.list(-LLvec)), "List not allowed as input"))
Expand Down

0 comments on commit 52bc270

Please sign in to comment.