Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add order statistic warning #230

Merged
merged 6 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion R/loo_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@
#' distribution, a practice derived for Gaussian linear models or
#' asymptotically, and which only applies to nested models in any case.
#'
#' If more than \eqn{11} models are compared, we internally recompute the model
#' differences using the median model by ELPD as the baseline model. We then
#' estimate whether the differences in predictive performance are potentially
#' due to chance as described by McLatchie and Vehtari (2023). This will flag
#' a warning if it is deemed that there is a risk of over-fitting due to the
#' selection process. In that case users are recommended to avoid model
#' selection based on LOO-CV, and instead to favor model averaging/stacking or
#' projection predictive inference.
#' @seealso
#' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on
#' the __loo__ website for answers to frequently asked questions.
#' @template loo-and-psis-references
#' @template loo-and-compare-references
#'
#' @examples
#' # very artificial example, just for demonstration!
Expand Down Expand Up @@ -108,6 +116,9 @@ loo_compare.default <- function(x, ...) {
comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp)
rownames(comp) <- rnms

# run order statistics-based checks on models
loo_order_stat_check(loos, ord)

class(comp) <- c("compare.loo", class(comp))
return(comp)
}
Expand Down Expand Up @@ -270,3 +281,63 @@ loo_compare_order <- function(loos){
ord <- order(tmp[grep("^elpd", rnms), ], decreasing = TRUE)
ord
}

#' Perform checks on `"loo"` objects __after__ comparison
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @param ord List of `"loo"` object orderings.
#' @return Nothing, just possibly throws errors/warnings.
loo_order_stat_check <- function(loos, ord) {

## breaks

if (length(loos) <= 11L) {
# procedure cannot be diagnosed for fewer than ten candidate models
# (total models = worst model + ten candidates)
# break from function
return(NULL)
}

## warnings

# compute the elpd differences from the median model
baseline_idx <- middle_idx(ord)
diffs <- mapply(FUN = elpd_diffs, loos[ord[baseline_idx]], loos[ord])
elpd_diff <- apply(diffs, 2, sum)

# estimate the standard deviation of the upper-half-normal
diff_median <- stats::median(elpd_diff)
elpd_diff_trunc <- elpd_diff[elpd_diff >= diff_median]
n_models <- sum(!is.na(elpd_diff_trunc))
candidate_sd <- sqrt(1 / n_models * sum(elpd_diff_trunc^2, na.rm = TRUE))

# estimate expected best diff under null hypothesis
K <- length(loos) - 1
order_stat <- order_stat_heuristic(K, candidate_sd)

if (max(elpd_diff) <= order_stat) {
# flag warning if we suspect no model is theoretically better than the baseline
warning("Difference in performance potentially due to chance.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we point users to the paper in the warning message (e.g., "See McLatchie and Vehtari (2023) for details".)? @avehtari what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now added this

"See McLatchie and Vehtari (2023) for details.",
call. = FALSE)
}
}

#' Returns the middle index of a vector
#' @noRd
#' @keywords internal
#' @param vec A vector.
#' @return Integer index value.
middle_idx <- function(vec) floor(length(vec) / 2)

#' Computes maximum order statistic from K Gaussians
#' @noRd
#' @keywords internal
#' @param K Number of Gaussians.
#' @param c Scaling of the order statistic.
#' @return Numeric expected maximum from K samples from a Gaussian with mean
#' zero and scale `"c"`
order_stat_heuristic <- function(K, c) {
qnorm(p = 1 - 1 / (K * 2), mean = 0, sd = c)
}
14 changes: 14 additions & 0 deletions man-roxygen/loo-and-compare-references.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#' @references
#' Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC.
#' *Statistics and Computing*. 27(5), 1413--1432. doi:10.1007/s11222-016-9696-4
#' ([journal version](https://link.springer.com/article/10.1007/s11222-016-9696-4),
#' [preprint arXiv:1507.04544](https://arxiv.org/abs/1507.04544)).
#'
#' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2019).
#' Pareto smoothed importance sampling.
#' [preprint arXiv:1507.02646](https://arxiv.org/abs/1507.02646)
#'
#' McLatchie, Y., and Vehtari, A. (2023).
#' Efficient estimation and correction of selection-induced bias with order statistics.
#' [preprint arXiv:2309.03742](https://arxiv.org/abs/2309.03742)
14 changes: 14 additions & 0 deletions man/loo_compare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/test_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ test_that("loo_compare throws appropriate warnings", {
attr(w3, "yhash") <- "a"
attr(w4, "yhash") <- "b"
expect_warning(loo_compare(w3, w4), "Not all models have the same y variable")

set.seed(123)
w_list <- lapply(1:25, function(x) SW(waic(LLarr + rnorm(1, 0, 0.1))))
expect_warning(loo_compare(w_list),
"Difference in performance potentially due to chance")

w_list_short <- lapply(1:4, function(x) SW(waic(LLarr + rnorm(1, 0, 0.1))))
expect_no_warning(loo_compare(w_list_short))
})


Expand Down