Skip to content

Commit

Permalink
remove e > eps check
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 30, 2024
1 parent fca69e8 commit 121d47b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 112 deletions.
2 changes: 1 addition & 1 deletion R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
#' then switches to L-BFGS (but other algorithms of NLopt can be used).
#' @param pseudocount A positive scalar to be added for the expected counts of
#' E-step. Only used in EM and EM-DNM algorithms. Default is 0. Larger values
#' can be used to avoid zero probabilities in initial, transition, and emission
#' can be used to avoid extreme initial, transition, and emission
#' probabilities, i.e. these have similar role as `lambda`.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
#' is stored to the model object. For large datasets, this can be set to
Expand Down
108 changes: 51 additions & 57 deletions src/mnhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,27 +214,25 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
}
for (arma::uword t = 0; t < (Ti(i) - 1); t++) {
double sum_ea = arma::accu(E_A(current_s, current_d).slice(t).col(i));
if (sum_ea > 100 * arma::datum::eps) {
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
}
double val = arma::dot(E_A(current_s, current_d).slice(t).col(i), log_A1);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
double val = arma::dot(E_A(current_s, current_d).slice(t).col(i), log_A1);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;

if (!grad.is_empty()) {
grad -= arma::vectorise(tQs * (E_A(current_s, current_d).slice(t).col(i) - sum_ea * A1) * X_A.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;

if (!grad.is_empty()) {
grad -= arma::vectorise(tQs * (E_A(current_s, current_d).slice(t).col(i) - sum_ea * A1) * X_A.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down Expand Up @@ -336,28 +334,26 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
for (arma::uword t = 0; t < Ti(i); t++) {
if (obs(t, i) < M) {
double e_b = E_B(current_d)(t, i, current_s);
if (e_b > 100 * arma::datum::eps) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
}

double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}

double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm *
e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm *
e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down Expand Up @@ -465,27 +461,25 @@ double mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) {

if (obs(current_c, t, i) < Mc) {
double e_b = E_B(current_c, current_d)(t, i, current_s);
if (e_b > 100 * arma::datum::eps) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
}
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down Expand Up @@ -560,7 +554,7 @@ void mnhmm_mc::mstep_B(const double ftol_abs, const double ftol_rel,
Rcpp::Rcout<<"M-step of emission probabilities of state "<<s + 1<<
" and channel "<<c<<" ended with return code "<<
return_code<<" after "<<mstep_iter + 1<<
" iterations."<<std::endl;
" iterations."<<std::endl;
}
if (return_code < 0) {
mstep_return_code = return_code - 310;
Expand Down Expand Up @@ -673,7 +667,7 @@ Rcpp::List EM_LBFGS_mnhmm_singlechannel(
// check for user interrupt every two seconds
auto start_time = std::chrono::steady_clock::now();
const std::chrono::seconds check_interval(2);

while (relative_change > ftol_rel && absolute_change > ftol_abs &&
absolute_x_change > xtol_abs &&
relative_x_change > xtol_rel && iter < maxeval) {
Expand Down
102 changes: 48 additions & 54 deletions src/nhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,26 @@ double nhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
}
for (arma::uword t = 0; t < (Ti(i) - 1); t++) {
double sum_ea = arma::accu(E_A(current_s).slice(t).col(i));
if (sum_ea > 100 * arma::datum::eps) {
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
if (tv_A) {
A1 = softmax(gamma_Arow * X_A.slice(i).col(t));
log_A1 = log(A1);
}
double val = arma::dot(E_A(current_s).slice(t).col(i), log_A1);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
double val = arma::dot(E_A(current_s).slice(t).col(i), log_A1);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;

if (!grad.is_empty()) {
grad -= arma::vectorise(tQs * (E_A(current_s).slice(t).col(i) -
sum_ea * A1) * X_A.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;

if (!grad.is_empty()) {
grad -= arma::vectorise(tQs * (E_A(current_s).slice(t).col(i) - sum_ea * A1) * X_A.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down Expand Up @@ -236,26 +235,24 @@ double nhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
for (arma::uword t = 0; t < Ti(i); t++) {
if (obs(t, i) < M) {
double e_b = E_B(t, i, current_s);
if (e_b > 100 * arma::datum::eps) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
}
double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down Expand Up @@ -353,30 +350,27 @@ double nhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) {
log_B1 = log(B1);
}
for (arma::uword t = 0; t < Ti(i); t++) {

if (obs(current_c, t, i) < Mc) {
double e_b = E_B(current_c)(t, i, current_s);
if (e_b > 100 * arma::datum::eps) {
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
if (tv_B) {
B1 = softmax(gamma_Brow * X_B.slice(i).col(t));
log_B1 = log(B1);
}
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
}
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
value -= val;
if (!grad.is_empty()) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
return n_obs * maxval;
}
}
}
}
}
Expand Down

0 comments on commit 121d47b

Please sign in to comment.