From ec5ec992cc74834bb4fb04aaabacad0e9e9580d0 Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Wed, 4 Sep 2024 21:27:01 +0300 Subject: [PATCH] fix docs, imports etc --- DESCRIPTION | 1 + NAMESPACE | 6 ++++ R/average_marginal_prediction.R | 50 +++++++++++++++++------------- R/build_lcm.R | 3 +- R/get_probs.R | 8 ++--- R/{plot.ame.R => plot.amp.R} | 10 +++--- R/predict.R | 16 +++------- R/sort_sequences.R | 8 ++++- R/summary.nhmm.R | 8 ++--- R/utilities.R | 2 +- man/average_marginal_prediction.Rd | 3 ++ man/build_lcm.Rd | 3 +- man/get_probs.Rd | 2 +- man/{plot.ame.Rd => plot.amp.Rd} | 10 +++--- man/sort_sequences.Rd | 21 +++++++++++++ 15 files changed, 96 insertions(+), 55 deletions(-) rename R/{plot.ame.R => plot.amp.R} (79%) rename man/{plot.ame.Rd => plot.amp.Rd} (58%) create mode 100644 man/sort_sequences.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 021f9cbd..11792d15 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -40,6 +40,7 @@ Depends: Imports: checkmate, cli, + dplyr, future, future.apply, ggplot2, diff --git a/NAMESPACE b/NAMESPACE index c931c134..4987f334 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -107,6 +107,12 @@ importFrom(future,multisession) importFrom(future,plan) importFrom(future,sequential) importFrom(future.apply,future_lapply) +importFrom(ggplot2,aes) +importFrom(ggplot2,facet_wrap) +importFrom(ggplot2,geom_line) +importFrom(ggplot2,geom_pointrange) +importFrom(ggplot2,geom_ribbon) +importFrom(ggplot2,ggplot) importFrom(grDevices,col2rgb) importFrom(grDevices,rainbow) importFrom(graphics,barplot) diff --git a/R/average_marginal_prediction.R b/R/average_marginal_prediction.R index d9d599d6..d9a0b66f 100644 --- a/R/average_marginal_prediction.R +++ b/R/average_marginal_prediction.R @@ -3,6 +3,8 @@ #' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`. #' @param variable Name of the variable of interest. #' @param values Vector containing one or two values for `variable`. +#' @param marginalize_B_over Character string defining the dimensions over which +#' emission probabilities are marginalized. Default is `"sequences"`. #' @param newdata Optional data frame which is used for marginalization. #' @param nsim Non-negative integer defining the number of samples from the #' normal approximation of the model parameters used in @@ -14,8 +16,16 @@ average_marginal_prediction <- function( model, variable, values, marginalize_B_over = "sequences", newdata = NULL, nsim = 0, probs = c(0.025, 0.5, 0.975)) { - marginalize_over <- match.arg( - marginalize_over, c("sequences", "states", "clusters"), several.ok = TRUE) + stopifnot_( + inherits(model, "nhmm") || inherits(model, "mnhmm"), + "Argument {.arg model} must be a {.cls nhmm} or {.cls mnhmm} object." + ) + marginalize_B_over <- match.arg( + marginalize_B_over, c("sequences", "states", "clusters"), several.ok = TRUE) + stopifnot_( + marginalize_B_over != "clusters" || model$n_clusters > 1, + "Cannot marginalize over clusters as {.arg model} is not a {.cls mnhmm} object." + ) stopifnot_( checkmate::test_count(nsim), "Argument {.arg nsim} should be a single non-negative integer." @@ -81,13 +91,13 @@ average_marginal_prediction <- function( times <- colnames(model$observations[[1]]) symbol_names <- model$symbol_names } - marginalize <- dplyr::recode( - marginalize_B_over, - "clusters" = "cluster", - "states" = "state", - "sequences" = "id") - id_var <- model$id_variable - time_var <- model$time_variable + marginalize <- c( + switch( + marginalize_B_over, + "clusters" = c("cluster", "state", "id"), + "states" = c("state", "id"), + "sequences" = "id"), + "time", "channel", "observation") pi <- data.frame( cluster = rep(model$cluster_names, each = S * N), @@ -95,9 +105,8 @@ average_marginal_prediction <- function( state = model$state_names, estimate = unlist(pred$pi) ) |> - dplyr::group_by(.data[[marginalize]]) |> + dplyr::group_by(cluster, state) |> summarise(estimate = mean(estimate)) - colnames(pi)[2] <- id_var A <- data.frame( cluster = rep(model$cluster_names, each = S^2 * T * N), @@ -106,9 +115,10 @@ average_marginal_prediction <- function( state_from = model$state_names, state_to = rep(model$state_names, each = S), estimate = unlist(pred$A) - ) - colnames(A)[2] <- id_var - colnames(A)[3] <- time_var + ) |> + dplyr::group_by(cluster, time, state_from, state_to) |> + dplyr::summarise(estimate = mean(estimate)) |> + dplyr::rename(!!time_var := time) B <- data.frame( cluster = rep(model$cluster_names, each = S * sum(M) * T * N), @@ -123,13 +133,9 @@ average_marginal_prediction <- function( })), estimate = unlist(pred$B) ) |> - dplyr::group_by(.data[[marginalize]]) |> - summarise(estimate = mean(estimate)) - idx <- which(names(B) == "id") - if(length(idx) == 1) names(B)[idx] <- id_var - idx <- which(names(B) == "time") - if(length(idx) == 1) names(B)[idx] <- time_var - if (C == 1) emission_probs$channel <- NULL + dplyr::group_by(across(all_of(marginalize))) |> + dplyr::summarise(estimate = mean(estimate)) |> + dplyr::rename(!!time_var := time) if (D > 1) { omega <- data.frame( @@ -142,5 +148,7 @@ average_marginal_prediction <- function( out <- list(pi = pi, A = A, B = B) } class(out) <- "amp" + attr(out, "seed") <- seed + attr(out, "marginalize_B_over") <- marginalize_B_over out } diff --git a/R/build_lcm.R b/R/build_lcm.R index 984dc387..cef84d1f 100644 --- a/R/build_lcm.R +++ b/R/build_lcm.R @@ -65,7 +65,8 @@ #' fit <- fit_model(model) #' #' # How many of the observations were correctly classified: -#' sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"), times = c(500, 200))) +#' sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"), +#' times = c(500, 200))) #' #' ############################################################ #' \dontrun{ diff --git a/R/get_probs.R b/R/get_probs.R index 679f996f..4c4ae4eb 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -23,7 +23,7 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0, "Argument {.arg nsim} should be a single non-negative integer." ) if (!is.null(newdata)) { - model <- update(model, newdata = ne) + model <- update(model, newdata = newdata) } S <- model$n_states M <- model$n_symbols @@ -93,7 +93,7 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0, } #' @rdname get_probs #' @export -get_probs.mnhmm <- function(model, ne = NULL, nsim = 0, +get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) { stopifnot_( @@ -101,8 +101,8 @@ get_probs.mnhmm <- function(model, ne = NULL, nsim = 0, "Argument {.arg nsim} should be a single non-negative integer." ) - if (!is.null(ne)) { - model <- update(model, newdata = ne) + if (!is.null(newdata)) { + model <- update(model, newdata = newdata) } T <- model$length_of_sequences diff --git a/R/plot.ame.R b/R/plot.amp.R similarity index 79% rename from R/plot.ame.R rename to R/plot.amp.R index 457a6f10..a3fb9390 100644 --- a/R/plot.ame.R +++ b/R/plot.amp.R @@ -1,10 +1,12 @@ #' Visualize Average Marginal Effects #' -#' @param x Output from `ame`. +#' @importFrom ggplot2 ggplot aes geom_pointrange geom_ribbon geom_line facet_wrap +#' @param x Output from [amp()]. #' @param alpha Transparency level for [ggplot2::geom_ribbon()]. -plot.ame <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) { +plot.amp <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) { type <- match.arg(type, c("initial", "transition", "emission", "cluster")) + cluster <- time <- state <- state_from <- state_to <- observation <- NULL stopifnot_( checkmate::test_numeric( x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 2L, @@ -37,10 +39,10 @@ plot.ame <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) { if (type == "emission") { p <- ggplot(x$emission, aes(estimate, time)) + geom_ribbon( - aes(ymin = .data[[lwr]], ymax = .data[[upr]], fill = symbol), + aes(ymin = .data[[lwr]], ymax = .data[[upr]], fill = observation), alpha = alpha ) + - geom_line(aes(colour = symbol)) + + geom_line(aes(colour = observation)) + facet_wrap(~ state) if (!is.null(cluster)) { p <- p + facet_wrap(~ cluster) diff --git a/R/predict.R b/R/predict.R index b7cc2994..6b25ad5d 100644 --- a/R/predict.R +++ b/R/predict.R @@ -31,10 +31,6 @@ predict.nhmm <- function( !is.null(newdata[[time]]), "Can't find time index variable {.var {time}} in {.arg newdata}." ) - stopifnot_( - !is.null(newdata[[variable]]), - "Can't find time variable {.var {variable}} in {.arg newdata}." - ) } else { stopifnot( !is.null(model$data), @@ -53,7 +49,7 @@ predict.nhmm <- function( beta_o_raw <- stan_to_cpp_emission( model$estimation_results$parameters$beta_o_raw, 1, - C > 1 + model$n_channels > 1 ) X_initial <- t(model$X_initial) X_transition <- aperm(model$X_transition, c(3, 1, 2)) @@ -66,7 +62,7 @@ predict.nhmm <- function( } else { get_multichannel_B( beta_o_raw, - X_emission1, + X_emission, model$n_states, model$n_channels, model$n_symbols, @@ -106,10 +102,6 @@ predict.mnhmm <- function( !is.null(newdata[[time]]), "Can't find time index variable {.var {time}} in {.arg newdata}." ) - stopifnot_( - !is.null(newdata[[variable]]), - "Can't find time variable {.var {variable}} in {.arg newdata}." - ) } else { stopifnot( !is.null(model$data), @@ -128,7 +120,7 @@ predict.mnhmm <- function( beta_o_raw <- stan_to_cpp_emission( model$estimation_results$parameters$beta_o_raw, 1, - C > 1 + model$n_channels > 1 ) X_initial <- t(model$X_initial) X_transition <- aperm(model$X_transition, c(3, 1, 2)) @@ -142,7 +134,7 @@ predict.mnhmm <- function( } else { get_multichannel_B( beta_o_raw, - X_emission1, + X_emission, model$n_states, model$n_channels, model$n_symbols, diff --git a/R/sort_sequences.R b/R/sort_sequences.R index 64e02338..2185bece 100644 --- a/R/sort_sequences.R +++ b/R/sort_sequences.R @@ -1,4 +1,10 @@ - +#' Sort sequences in a sequence object +#' +#' @param x A sequence object or a list of sequence objects +#' @param sort_by A character string specifying the sorting criterion. +#' @param sort_channel An integer or character string specifying the channel to +#' sort by. +#' @param dist_method A character string specifying the distance method to use. #' @export sort_sequences <- function( x, sort_by = "start", sort_channel = 1, dist_method = "OM") { diff --git a/R/summary.nhmm.R b/R/summary.nhmm.R index 6fe5a98a..a38e16dc 100644 --- a/R/summary.nhmm.R +++ b/R/summary.nhmm.R @@ -51,18 +51,18 @@ summary.mnhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) { } #' @export print.summary.nhmm <- function(x, digits = 3, ...) { - print(object) + print(x) cat("\nCoefficients:\n") - print.listof(out, digits = digits, ...) + print.listof(x$coefficients, digits = digits, ...) cat("Log-likelihood:", x$logLik, " BIC:", x$BIC, "\n\n") invisible(x) } #' @export print.summary.mnhmm <- function(x, digits = 3, ...) { - print(object) + print(x) cat("\nCoefficients:\n") - print.listof(out, digits = digits, ...) + print.listof(x$coefficients, digits = digits, ...) cat("Log-likelihood:", x$logLik, " BIC:", x$BIC, "\n\n") diff --git a/R/utilities.R b/R/utilities.R index 665627fb..1a2f1349 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -185,7 +185,7 @@ stop_ <- function(message, ..., call = rlang::caller_env()) { #' @param f A formula object. #' @noRd intercept_only <- function(f) { - identical(deparse(update(f, 0 ~ .)), "0 ~ 1") + identical(deparse(stats::update(f, 0 ~ .)), "0 ~ 1") } #' Create obsArray for Various C++ functions #' diff --git a/man/average_marginal_prediction.Rd b/man/average_marginal_prediction.Rd index 48ef3f42..88128629 100644 --- a/man/average_marginal_prediction.Rd +++ b/man/average_marginal_prediction.Rd @@ -21,6 +21,9 @@ average_marginal_prediction( \item{values}{Vector containing one or two values for \code{variable}.} +\item{marginalize_B_over}{Character string defining the dimensions over which +emission probabilities are marginalized. Default is \code{"sequences"}.} + \item{newdata}{Optional data frame which is used for marginalization.} \item{nsim}{Non-negative integer defining the number of samples from the diff --git a/man/build_lcm.Rd b/man/build_lcm.Rd index aa97cc46..f88fcc51 100644 --- a/man/build_lcm.Rd +++ b/man/build_lcm.Rd @@ -86,7 +86,8 @@ model <- build_lcm(obs, n_clusters = 2) fit <- fit_model(model) # How many of the observations were correctly classified: -sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"), times = c(500, 200))) +sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"), + times = c(500, 200))) ############################################################ \dontrun{ diff --git a/man/get_probs.Rd b/man/get_probs.Rd index 42440daa..2f44bbdb 100644 --- a/man/get_probs.Rd +++ b/man/get_probs.Rd @@ -11,7 +11,7 @@ get_probs(model, ...) \method{get_probs}{nhmm}(model, newdata = NULL, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) -\method{get_probs}{mnhmm}(model, ne = NULL, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) +\method{get_probs}{mnhmm}(model, newdata = NULL, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) } \arguments{ \item{model}{An object of class \code{nhmm} or \code{mnhmm}.} diff --git a/man/plot.ame.Rd b/man/plot.amp.Rd similarity index 58% rename from man/plot.ame.Rd rename to man/plot.amp.Rd index 48dba80e..7d847f07 100644 --- a/man/plot.ame.Rd +++ b/man/plot.amp.Rd @@ -1,13 +1,13 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.ame.R -\name{plot.ame} -\alias{plot.ame} +% Please edit documentation in R/plot.amp.R +\name{plot.amp} +\alias{plot.amp} \title{Visualize Average Marginal Effects} \usage{ -\method{plot}{ame}(x, type, probs = c(0.025, 0.975), alpha = 0.25) +\method{plot}{amp}(x, type, probs = c(0.025, 0.975), alpha = 0.25) } \arguments{ -\item{x}{Output from \code{ame}.} +\item{x}{Output from \code{\link[=amp]{amp()}}.} \item{alpha}{Transparency level for \code{\link[ggplot2:geom_ribbon]{ggplot2::geom_ribbon()}}.} } diff --git a/man/sort_sequences.Rd b/man/sort_sequences.Rd new file mode 100644 index 00000000..276021b1 --- /dev/null +++ b/man/sort_sequences.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sort_sequences.R +\name{sort_sequences} +\alias{sort_sequences} +\title{Sort sequences in a sequence object} +\usage{ +sort_sequences(x, sort_by = "start", sort_channel = 1, dist_method = "OM") +} +\arguments{ +\item{x}{A sequence object or a list of sequence objects} + +\item{sort_by}{A character string specifying the sorting criterion.} + +\item{sort_channel}{An integer or character string specifying the channel to +sort by.} + +\item{dist_method}{A character string specifying the distance method to use.} +} +\description{ +Sort sequences in a sequence object +}