Skip to content

Commit

Permalink
handle nonfinite values in optimization better
Browse files (browse the repository at this point in the history)
  • Loading branch information
helske committed Nov 29, 2024
1 parent 21e1e11 commit 611ab5c
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 118 deletions.
26 changes: 13 additions & 13 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ permute_clusters <- function(model, pcp_mle) {
#' bootstrapping.
#'
#' @param model An `nhmm` or `mnhmm` object.
#' @param B number of bootstrap samples.
#' @param nsim Number of bootstrap samples.
#' @param type Either `"nonparametric"` or `"parametric"`, to define whether
#' nonparametric or parametric bootstrap should be used. The former samples
#' sequences with replacement, whereas the latter simulates new datasets based
Expand All @@ -85,24 +85,24 @@ bootstrap_coefs <- function(model, ...) {
}
#' @rdname bootstrap
#' @export
bootstrap_coefs.nhmm <- function(model, B = 1000,
bootstrap_coefs.nhmm <- function(model, nsim = 1000,
type = c("nonparametric", "parametric"),
method = "DNM", ...) {
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
"Argument {.arg B} must be a single positive integer."
checkmate::test_int(x = nsim, lower = 0L),
"Argument {.arg nsim} must be a single positive integer."
)
init <- model$etas
gammas_mle <- model$gammas
lambda <- model$estimation_results$lambda
pseudocount <- model$estimation_results$pseudocount
p <- progressr::progressor(along = seq_len(B))
p <- progressr::progressor(along = seq_len(nsim))
original_options <- options(future.globals.maxSize = Inf)
on.exit(options(original_options))
if (type == "nonparametric") {
out <- future.apply::future_lapply(
seq_len(B), function(i) {
seq_len(nsim), function(i) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount, ...)
Expand All @@ -123,7 +123,7 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
time <- model$time_variable
id <- model$id_variable
out <- future.apply::future_lapply(
seq_len(B), function(i) {
seq_len(nsim), function(i) {
mod <- simulate_nhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
d, time, id, init, 0)$model
Expand All @@ -144,26 +144,26 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
}
#' @rdname bootstrap
#' @export
bootstrap_coefs.mnhmm <- function(model, B = 1000,
bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
type = c("nonparametric", "parametric"),
method = "DNM", ...) {
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = B, lower = 0L),
"Argument {.arg B} must be a single positive integer."
checkmate::test_int(x = nsim, lower = 0L),
"Argument {.arg nsim} must be a single positive integer."
)
init <- model$etas
gammas_mle <- model$gammas
pcp_mle <- posterior_cluster_probabilities(model)
lambda <- model$estimation_results$lambda
pseudocount <- model$estimation_results$pseudocount
D <- model$n_clusters
p <- progressr::progressor(along = seq_len(B))
p <- progressr::progressor(along = seq_len(nsim))
original_options <- options(future.globals.maxSize = Inf)
on.exit(options(original_options))
if (type == "nonparametric") {
out <- future.apply::future_lapply(
seq_len(B), function(i) {
seq_len(nsim), function(i) {
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, pseudocount = pseudocount, ...)
Expand Down Expand Up @@ -194,7 +194,7 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
time <- model$time_variable
id <- model$id_variable
out <- future.apply::future_lapply(
seq_len(B), function(i) {
seq_len(nsim), function(i) {
mod <- simulate_mnhmm(
N, T_, M, S, D, formula_pi, formula_A, formula_B, formula_omega,
d, time, id, init, 0)$model
Expand Down
2 changes: 1 addition & 1 deletion R/create_base_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ create_base_nhmm <- function(observations, data, time, id, n_states,

stopifnot_(
!missing(n_states) && checkmate::test_int(x = n_states, lower = 2L),
"Argument {.arg n_states} must be a single integer larger than one."
"Argument {.arg n_states} must be a single integer larger than 1."
)
stopifnot_(
inherits(initial_formula, "formula"),
Expand Down
6 changes: 3 additions & 3 deletions man/bootstrap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/estimate_mnhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 13 additions & 13 deletions man/estimate_nhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 23 additions & 23 deletions src/mnhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ double mnhmm_base::objective_omega(const arma::vec& x, arma::vec& grad) {
double val = arma::dot(E_omega.col(i), log_omega);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
grad.fill(maxval);
}
return std::numeric_limits<double>::max();
return n_obs * maxval;
}
value -= val;
// Only update grad if it's non-empty (i.e., for gradient-based optimization)
if (!grad.is_empty()) {
tmpgrad -= sum_eo * (E_omega.col(i) / sum_eo - omega) * X_omega.col(i).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
grad.fill(maxval);
return n_obs * maxval;
}
}
}
Expand Down Expand Up @@ -115,18 +115,18 @@ double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
double val = arma::dot(E_Pi(current_d).col(i), log_pi(current_d));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
grad.fill(maxval);
}
return std::numeric_limits<double>::max();
return n_obs * maxval;
}
value -= val;
// Only update grad if it's non-empty (i.e., for gradient-based optimization)
if (!grad.is_empty()) {
tmpgrad -= sum_epi * (E_Pi(current_d).col(i) / sum_epi -
pi(current_d)) * X_pi.col(i).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
grad.fill(maxval);
return n_obs * maxval;
}
}
}
Expand Down Expand Up @@ -216,25 +216,25 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
}
for (arma::uword t = 0; t < (Ti(i) - 1); t++) {
double sum_ea = arma::accu(E_A(current_s, current_d).slice(t).col(i));
if (sum_ea > arma::datum::eps) {
if (sum_ea > std::sqrt(arma::datum::eps)) {
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
}
double val = arma::dot(E_A(current_s, current_d).slice(t).col(i), log_A1);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
grad.fill(maxval);
}
return std::numeric_limits<double>::max();
return n_obs * maxval;
}
value -= val;

if (!grad.is_empty()) {
tmpgrad -= sum_ea * (E_A(current_s, current_d).slice(t).col(i) / sum_ea - A1) * X_A.slice(i).col(t).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
grad.fill(maxval);
return n_obs * maxval;
}
}
}
Expand Down Expand Up @@ -340,7 +340,7 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
for (arma::uword t = 0; t < Ti(i); t++) {
if (obs(t, i) < M) {
double e_b = E_B(current_d)(t, i, current_s);
if (e_b > arma::datum::eps) {
if (e_b > std::sqrt(arma::datum::eps)) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
Expand All @@ -349,17 +349,17 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
grad.fill(maxval);
}
return std::numeric_limits<double>::max();
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
tmpgrad -=
e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
grad.fill(maxval);
return n_obs * maxval;
}
}
}
Expand Down Expand Up @@ -468,25 +468,25 @@ double mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) {

if (obs(current_c, t, i) < Mc) {
double e_b = E_B(current_c, current_d)(t, i, current_s);
if (e_b > arma::datum::eps) {
if (e_b > std::sqrt(arma::datum::eps)) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
}
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
grad.fill(maxval);
}
return std::numeric_limits<double>::max();
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
tmpgrad -= e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
grad.fill(maxval);
return n_obs * maxval;
}
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/mnhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ struct mnhmm_base {
arma::uword current_d;
const arma::uword n_obs;
double lambda;

int mstep_iter;
int mstep_error_code;
const double maxval;
int mstep_iter = 0;
int mstep_error_code = 0;

mnhmm_base(
const arma::uword S_,
Expand All @@ -76,9 +76,8 @@ struct mnhmm_base {
arma::field<arma::mat>& eta_pi_,
arma::field<arma::cube>& eta_A_,
const arma::uword n_obs_ = 0,
const double lambda = 0,
int mstep_iter = 0,
int mstep_error_code = 0)
const double lambda_ = 0,
const double maxval_ = 1e8)
: S(S_),
D(D_),
X_omega(X_d_),
Expand Down Expand Up @@ -121,7 +120,8 @@ struct mnhmm_base {
current_s(0),
current_d(0),
n_obs(n_obs_),
lambda(lambda){
lambda(lambda_),
maxval(maxval_) {
for (arma::uword d = 0; d < D; d++) {
pi(d) = arma::vec(S);
log_pi(d) = arma::vec(S);
Expand Down
Loading

0 comments on commit 611ab5c

Please sign in to comment.