Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Threshold and discrimination functionality #47

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
linters: linters_with_defaults(
indentation_linter = NULL
indentation_linter = NULL,
return_linter = NULL
)
exclusions: list(
"R/stanmodels.R",
Expand Down
5 changes: 3 additions & 2 deletions R/model-evaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ add_criterion <- function(x, criterion = c("loo", "waic"), overwrite = FALSE,

#' @export
#' @rdname model_evaluation
add_reliability <- function(x, overwrite = FALSE, save = TRUE) {
add_reliability <- function(x, threshold = 0.5, overwrite = FALSE,
save = TRUE) {
model <- check_model(x, required_class = "measrfit", name = "x")
overwrite <- check_logical(overwrite, name = "overwrite")
save <- check_logical(save, name = "force_save")
Expand All @@ -168,7 +169,7 @@ add_reliability <- function(x, overwrite = FALSE, save = TRUE) {
run_reli <- length(model$reliability) == 0 || overwrite

if (run_reli) {
model$reliability <- reliability(model)
model$reliability <- reliability(model, threshold = threshold, force = TRUE)
}

# re-save model object (if applicable)
Expand Down
34 changes: 30 additions & 4 deletions R/reliability.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ reliability <- function(model, ...) {
#'
#' @param threshold For `map_reliability`, the threshold applied to the
#' attribute-level probabilities for determining the binary attribute
#' classifications.
#' classifications. Should be a numeric vector of length 1 (the same threshold
#' is applied to all attributes), or length equal to the number of attributes.
#' If a named vector is supplied, names should match the attribute names in the
#' Q-matrix used to estimate the model. If unnamed, thresholds should be in the
#' order the attributes were defined in the Q-matrix.
#'
#' @details The pattern-level reliability (`pattern_reliability`) statistics are
#' described in Cui et al. (2012). Attribute-level classification reliability
Expand Down Expand Up @@ -71,16 +75,38 @@ reliability <- function(model, ...) {
#'
#' reliability(rstn_mdm_lcdm)
reliability.measrdcm <- function(model, ..., threshold = 0.5, force = FALSE) {
threshold <- check_double(threshold, lb = 0, ub = 1, inclusive = FALSE,
name = "threshold")
force <- check_logical(force, name = "force")

att_names <- colnames(dplyr::select(model$data$qmatrix, -"item_id"))
if (length(threshold) == 1) {
threshold <- rep(threshold, times = length(att_names)) %>%
rlang::set_names(att_names)
} else if (length(threshold) == length(att_names)) {
if (is.null(names(threshold))) {
threshold <- rlang::set_names(threshold, att_names)
} else if (!all(names(threshold) %in% att_names)) {
bad_names <- setdiff(names(threshold), att_names)
rlang::abort(
message = glue::glue("Unknown attribute names provided: ",
"{paste(bad_names, collapse = ', ')}")
)
}
} else {
rlang::abort(
message = glue::glue("`threshold` must be of length 1 or length ",
"{length(att_names)} (the number of attributes).")
)
}

if ((!is.null(model$reliability) && length(model$reliability) > 0) &&
!force) {
return(model$reliability)
}

# coerce model into a list of values required for reliability
obj <- reli_list(model, threshold = threshold)
att_names <- model$data$qmatrix %>%
dplyr::select(-"item_id") %>%
colnames()

tbl <- obj$acc
p <- obj$prev
Expand Down
16 changes: 13 additions & 3 deletions R/utils-reliability.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,19 @@ reli_list <- function(model, threshold) {

# map estimates
binary_att <- attr_probs %>%
dplyr::mutate(dplyr::across(dplyr::everything(),
~dplyr::case_when(.x >= threshold ~ 1L,
TRUE ~ 0L))) %>%
tibble::rowid_to_column(var = "resp_id") %>%
tidyr::pivot_longer(cols = -"resp_id",
names_to = "attribute", values_to = "probability") %>%
dplyr::left_join(tibble::enframe(threshold, name = "attribute",
value = "threshold"),
by = "attribute",
relationship = "many-to-one") %>%
dplyr::mutate(class = dplyr::case_when(.data$probability >=
.data$threshold ~ 1L,
.default = 0L)) %>%
dplyr::select("resp_id", "attribute", "class") %>%
tidyr::pivot_wider(names_from = "attribute", values_from = "class") %>%
dplyr::select(dplyr::all_of(names(threshold))) %>%
as.matrix() %>%
unname() %>%
as.vector()
Expand Down
2 changes: 1 addition & 1 deletion man/model_evaluation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/reliability.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 15 additions & 2 deletions tests/testthat/test-reliability.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
test_that("dino reliability", {
dino_reli <- reliability(rstn_dino)
dino_reli <- reliability(rstn_dino, threshold = 0.5)

# threshold errors
err <- rlang::catch_cnd(reliability(rstn_dino,
threshold = c("att1" = 0.5,
"asdf" = 0.8,
"test" = 0.7,
"att4" = 0.5,
"att5" = 0.3)))
expect_match(err$message, "Unknown attribute names")
err <- rlang::catch_cnd(reliability(rstn_dino,
threshold = c("att1" = 0.5,
"asdf" = 0.8)))
expect_match(err$message, "must be of length 1 or length 5")

# list naming
expect_equal(names(dino_reli), c("pattern_reliability", "map_reliability",
Expand Down Expand Up @@ -35,7 +48,7 @@ test_that("reliability can be added to model object", {
err <- rlang::catch_cnd(measr_extract(dina_mod, "probability_reliability"))
expect_match(err$message, "Reliability information must be added to a model")

dina_mod <- add_reliability(dina_mod)
dina_mod <- add_reliability(dina_mod, threshold = rep(0.5, 5))
expect_equal(names(dina_mod$reliability),
c("pattern_reliability", "map_reliability", "eap_reliability"))

Expand Down
Loading