Skip to content

Commit

Permalink
fix docs, imports etc
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 4, 2024
1 parent 78a2e21 commit ec5ec99
Show file tree
Hide file tree
Showing 15 changed files with 96 additions and 55 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Depends:
Imports:
checkmate,
cli,
dplyr,
future,
future.apply,
ggplot2,
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ importFrom(future,multisession)
importFrom(future,plan)
importFrom(future,sequential)
importFrom(future.apply,future_lapply)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_pointrange)
importFrom(ggplot2,geom_ribbon)
importFrom(ggplot2,ggplot)
importFrom(grDevices,col2rgb)
importFrom(grDevices,rainbow)
importFrom(graphics,barplot)
Expand Down
50 changes: 29 additions & 21 deletions R/average_marginal_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#' @param model A Hidden Markov Model of class `nhmm` or `mnhmm`.
#' @param variable Name of the variable of interest.
#' @param values Vector containing one or two values for `variable`.
#' @param marginalize_B_over Character string defining the dimensions over which
#' emission probabilities are marginalized: `"sequences"`, `"states"`, or
#' `"clusters"`. Default is `"sequences"`.
#' @param newdata Optional data frame which is used for marginalization.
#' @param nsim Non-negative integer defining the number of samples from the
#' normal approximation of the model parameters used in
Expand All @@ -14,8 +16,16 @@
average_marginal_prediction <- function(
model, variable, values, marginalize_B_over = "sequences", newdata = NULL,
nsim = 0, probs = c(0.025, 0.5, 0.975)) {
marginalize_over <- match.arg(
marginalize_over, c("sequences", "states", "clusters"), several.ok = TRUE)
stopifnot_(
inherits(model, "nhmm") || inherits(model, "mnhmm"),
"Argument {.arg model} must be a {.cls nhmm} or {.cls mnhmm} object."
)
marginalize_B_over <- match.arg(
marginalize_B_over, c("sequences", "states", "clusters"), several.ok = TRUE)
stopifnot_(
marginalize_B_over != "clusters" || model$n_clusters > 1,
"Cannot marginalize over clusters as {.arg model} is not a {.cls mnhmm} object."
)
stopifnot_(
checkmate::test_count(nsim),
"Argument {.arg nsim} should be a single non-negative integer."
Expand Down Expand Up @@ -81,23 +91,22 @@ average_marginal_prediction <- function(
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names
}
marginalize <- dplyr::recode(
marginalize_B_over,
"clusters" = "cluster",
"states" = "state",
"sequences" = "id")
id_var <- model$id_variable
time_var <- model$time_variable
marginalize <- c(
switch(
marginalize_B_over,
"clusters" = c("cluster", "state", "id"),
"states" = c("state", "id"),
"sequences" = "id"),
"time", "channel", "observation")

pi <- data.frame(
cluster = rep(model$cluster_names, each = S * N),
id = rep(ids, each = S),
state = model$state_names,
estimate = unlist(pred$pi)
) |>
dplyr::group_by(.data[[marginalize]]) |>
dplyr::group_by(cluster, state) |>
summarise(estimate = mean(estimate))
colnames(pi)[2] <- id_var

A <- data.frame(
cluster = rep(model$cluster_names, each = S^2 * T * N),
Expand All @@ -106,9 +115,10 @@ average_marginal_prediction <- function(
state_from = model$state_names,
state_to = rep(model$state_names, each = S),
estimate = unlist(pred$A)
)
colnames(A)[2] <- id_var
colnames(A)[3] <- time_var
) |>
dplyr::group_by(cluster, time, state_from, state_to) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::rename(!!time_var := time)

B <- data.frame(
cluster = rep(model$cluster_names, each = S * sum(M) * T * N),
Expand All @@ -123,13 +133,9 @@ average_marginal_prediction <- function(
})),
estimate = unlist(pred$B)
) |>
dplyr::group_by(.data[[marginalize]]) |>
summarise(estimate = mean(estimate))
idx <- which(names(B) == "id")
if(length(idx) == 1) names(B)[idx] <- id_var
idx <- which(names(B) == "time")
if(length(idx) == 1) names(B)[idx] <- time_var
if (C == 1) emission_probs$channel <- NULL
dplyr::group_by(across(all_of(marginalize))) |>
dplyr::summarise(estimate = mean(estimate)) |>
dplyr::rename(!!time_var := time)

if (D > 1) {
omega <- data.frame(
Expand All @@ -142,5 +148,7 @@ average_marginal_prediction <- function(
out <- list(pi = pi, A = A, B = B)
}
class(out) <- "amp"
attr(out, "seed") <- seed
attr(out, "marginalize_B_over") <- marginalize_B_over
out
}
3 changes: 2 additions & 1 deletion R/build_lcm.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@
#' fit <- fit_model(model)
#'
#' # How many of the observations were correctly classified:
#' sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"), times = c(500, 200)))
#' sum(summary(fit$model)$most_probable_cluster == rep(c("Class 2", "Class 1"),
#' times = c(500, 200)))
#'
#' ############################################################
#' \dontrun{
Expand Down
8 changes: 4 additions & 4 deletions R/get_probs.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0,
"Argument {.arg nsim} should be a single non-negative integer."
)
if (!is.null(newdata)) {
model <- update(model, newdata = ne)
model <- update(model, newdata = newdata)
}
S <- model$n_states
M <- model$n_symbols
Expand Down Expand Up @@ -93,16 +93,16 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0,
}
#' @rdname get_probs
#' @export
get_probs.mnhmm <- function(model, ne = NULL, nsim = 0,
get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0,
probs = c(0.025, 0.5, 0.975), ...) {

stopifnot_(
checkmate::test_count(nsim),
"Argument {.arg nsim} should be a single non-negative integer."
)

if (!is.null(ne)) {
model <- update(model, newdata = ne)
if (!is.null(newdata)) {
model <- update(model, newdata = newdata)
}

T <- model$length_of_sequences
Expand Down
10 changes: 6 additions & 4 deletions R/plot.ame.R → R/plot.amp.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#' Visualize Average Marginal Predictions
#'
#' @param x Output from `ame`.
#' @importFrom ggplot2 ggplot aes geom_pointrange geom_ribbon geom_line facet_wrap
#' @param x Output from [amp()].
#' @param alpha Transparency level for [ggplot2::geom_ribbon()].
plot.ame <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) {
plot.amp <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) {
type <- match.arg(type, c("initial", "transition", "emission", "cluster"))

cluster <- time <- state <- state_from <- state_to <- observation <- NULL
stopifnot_(
checkmate::test_numeric(
x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 2L,
Expand Down Expand Up @@ -37,10 +39,10 @@ plot.ame <- function(x, type, probs = c(0.025, 0.975), alpha = 0.25) {
if (type == "emission") {
p <- ggplot(x$emission, aes(estimate, time)) +
geom_ribbon(
aes(ymin = .data[[lwr]], ymax = .data[[upr]], fill = symbol),
aes(ymin = .data[[lwr]], ymax = .data[[upr]], fill = observation),
alpha = alpha
) +
geom_line(aes(colour = symbol)) +
geom_line(aes(colour = observation)) +
facet_wrap(~ state)
if (!is.null(cluster)) {
p <- p + facet_wrap(~ cluster)
Expand Down
16 changes: 4 additions & 12 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ predict.nhmm <- function(
!is.null(newdata[[time]]),
"Can't find time index variable {.var {time}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[variable]]),
"Can't find time variable {.var {variable}} in {.arg newdata}."
)
} else {
stopifnot(
!is.null(model$data),
Expand All @@ -53,7 +49,7 @@ predict.nhmm <- function(
beta_o_raw <- stan_to_cpp_emission(
model$estimation_results$parameters$beta_o_raw,
1,
C > 1
model$n_channels > 1
)
X_initial <- t(model$X_initial)
X_transition <- aperm(model$X_transition, c(3, 1, 2))
Expand All @@ -66,7 +62,7 @@ predict.nhmm <- function(
} else {
get_multichannel_B(
beta_o_raw,
X_emission1,
X_emission,
model$n_states,
model$n_channels,
model$n_symbols,
Expand Down Expand Up @@ -106,10 +102,6 @@ predict.mnhmm <- function(
!is.null(newdata[[time]]),
"Can't find time index variable {.var {time}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[variable]]),
"Can't find time variable {.var {variable}} in {.arg newdata}."
)
} else {
stopifnot(
!is.null(model$data),
Expand All @@ -128,7 +120,7 @@ predict.mnhmm <- function(
beta_o_raw <- stan_to_cpp_emission(
model$estimation_results$parameters$beta_o_raw,
1,
C > 1
model$n_channels > 1
)
X_initial <- t(model$X_initial)
X_transition <- aperm(model$X_transition, c(3, 1, 2))
Expand All @@ -142,7 +134,7 @@ predict.mnhmm <- function(
} else {
get_multichannel_B(
beta_o_raw,
X_emission1,
X_emission,
model$n_states,
model$n_channels,
model$n_symbols,
Expand Down
8 changes: 7 additions & 1 deletion R/sort_sequences.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@

#' Sort sequences in a sequence object
#'
#' @param x A sequence object or a list of sequence objects.
#' @param sort_by A character string specifying the sorting criterion.
#' @param sort_channel An integer or character string specifying the channel to
#' sort by.
#' @param dist_method A character string specifying the distance method to use.
#' @export
sort_sequences <- function(
x, sort_by = "start", sort_channel = 1, dist_method = "OM") {
Expand Down
8 changes: 4 additions & 4 deletions R/summary.nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,18 @@ summary.mnhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) {
}
#' @export
print.summary.nhmm <- function(x, digits = 3, ...) {
print(object)
print(x)
cat("\nCoefficients:\n")
print.listof(out, digits = digits, ...)
print.listof(x$coefficients, digits = digits, ...)

cat("Log-likelihood:", x$logLik, " BIC:", x$BIC, "\n\n")
invisible(x)
}
#' @export
print.summary.mnhmm <- function(x, digits = 3, ...) {
print(object)
print(x)
cat("\nCoefficients:\n")
print.listof(out, digits = digits, ...)
print.listof(x$coefficients, digits = digits, ...)

cat("Log-likelihood:", x$logLik, " BIC:", x$BIC, "\n\n")

Expand Down
2 changes: 1 addition & 1 deletion R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ stop_ <- function(message, ..., call = rlang::caller_env()) {
#' @param f A formula object.
#' @noRd
intercept_only <- function(f) {
identical(deparse(update(f, 0 ~ .)), "0 ~ 1")
identical(deparse(stats::update(f, 0 ~ .)), "0 ~ 1")
}
#' Create obsArray for Various C++ functions
#'
Expand Down
3 changes: 3 additions & 0 deletions man/average_marginal_prediction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/build_lcm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/get_probs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/plot.ame.Rd → man/plot.amp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions man/sort_sequences.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ec5ec99

Please sign in to comment.