Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix loo_moment_matching NaN issue #259

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ validate_ll <- function(x) {
stop("List not allowed as input.")
} else if (anyNA(x)) {
stop("NAs not allowed in input.")
} else if (!all(is.finite(x))) {
stop("All input values must be finite.")
} else if (any(x==Inf)) {
stop("All input values must be finite or -Inf.")
}
invisible(x)
}
Expand Down
12 changes: 6 additions & 6 deletions R/loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ loo_moment_match_i <- function(i,
dim(log_liki) <- NULL
}


# pointwise estimates
elpd_loo_i <- matrixStats::logSumExp(log_liki + lwi)
lpd <- matrixStats::logSumExp(log_liki) - log(length(log_liki))
Expand Down Expand Up @@ -439,17 +438,18 @@ update_quantities_i <- function(x, upars, i, orig_log_prob,
log_liki_new <- log_lik_i_upars(x, upars = upars, i = i, ...)
# compute new log importance weights

is_obj_new <- suppressWarnings(importance_sampling.default(-log_liki_new +
log_prob_new -
orig_log_prob,
# If log_liki_new and log_prob_new both have same element as Inf,
# replace the log ratio with -Inf
lr <- -log_liki_new + log_prob_new - orig_log_prob
lr[is.na(lr)] <- -Inf
is_obj_new <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = 1))
lwi_new <- as.vector(weights(is_obj_new))
ki_new <- is_obj_new$diagnostics$pareto_k

is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new -
orig_log_prob,
is_obj_f_new <- suppressWarnings(importance_sampling.default(log_prob_new - orig_log_prob,
method = is_method,
r_eff = r_eff_i,
cores = 1))
Expand Down
13 changes: 10 additions & 3 deletions R/split_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,21 @@ loo_moment_match_split <- function(x, upars, cov, total_shift, total_scaling,
log1p(exp(log_prob_half_trans[!stable_S] -
log_prob_half_trans_inv[!stable_S])))

is_obj_half <- suppressWarnings(importance_sampling.default(lwi_half,
# lwi_half may have NaNs if computation involves -Inf + Inf
# replace NaN log ratios with -Inf
lr<-lwi_half
lr[is.na(lr)] <- -Inf
is_obj_half <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = cores))
lwi_half <- as.vector(weights(is_obj_half))

is_obj_f_half <- suppressWarnings(importance_sampling.default(lwi_half +
log_liki_half,
# lwi_half may have NaNs if computation involves -Inf + Inf
# replace NaN log ratios with -Inf
lr<-lwi_half + log_liki_half
lr[is.na(lr)] <- -Inf
is_obj_f_half <- suppressWarnings(importance_sampling.default(lr,
method = is_method,
r_eff = r_eff_i,
cores = cores))
Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/test_loo_moment_matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ test_that("loo_moment_match.default warnings work", {

test_that("loo_moment_match.default works", {

# allow -Inf
lwi_x <- lwi_1
lwi_x[which.min(lwi_1)] <- -Inf
expect_no_error(suppressWarnings(importance_sampling.default(lwi_1, method = "psis", r_eff = 1, cores = 1)))

# loo object
loo_manual <- suppressWarnings(loo(loglik))

Expand Down Expand Up @@ -288,7 +293,6 @@ test_that("loo_moment_match.default works with multiple cores", {
})



test_that("loo_moment_match_split works", {
# skip on M1 Mac until we figure out why this test fails only on M1 Mac
skip_if(Sys.info()[["sysname"]] == "Darwin" && R.version$arch == "aarch64")
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test_psis.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ test_that("psis throws correct errors and warnings", {
expect_error(psis(-LLmat), "NAs not allowed in input")

LLmat[1,1] <- 1
LLmat[10, 2] <- -Inf
expect_error(psis(-LLmat), "All input values must be finite or -Inf")
# log ratio of -Inf is allowed
LLmat[10, 2] <- Inf
expect_error(psis(-LLmat), "All input values must be finite")
expect_no_error(psis(-LLmat))

# no lists allowed
expect_error(expect_warning(psis(as.list(-LLvec))), "List not allowed as input")
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test_tisis.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ test_that("tis throws correct errors and warnings", {
expect_error(tis(-LLmat), "NAs not allowed in input")

LLmat[1,1] <- 1
LLmat[10, 2] <- -Inf
expect_error(tis(-LLmat), "All input values must be finite or -Inf")
LLmat[10, 2] <- Inf
expect_error(tis(-LLmat), "All input values must be finite")
expect_no_error(tis(-LLmat))

# no lists allowed
expect_error(expect_warning(tis(as.list(-LLvec)), "List not allowed as input"))
Expand Down
Loading