Skip to content

Commit

Permalink
const etas, always use lbfgs in EM, return inf and zero for nonfinite…
Browse files Browse the repository at this point in the history
… likelihood
  • Loading branch information
helske committed Dec 2, 2024
1 parent 992208c commit 39a16d2
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 149 deletions.
16 changes: 8 additions & 8 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,16 +564,16 @@ BEGIN_RCPP
END_RCPP
}
// EM_LBFGS_nhmm_singlechannel
Rcpp::List EM_LBFGS_nhmm_singlechannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
Rcpp::List EM_LBFGS_nhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::cube& eta_B, const arma::cube& X_B, const arma::umat& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type eta_pi(eta_piSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP);
Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type eta_A(eta_ASEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP);
Rcpp::traits::input_parameter< arma::cube& >::type eta_B(eta_BSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type eta_B(eta_BSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP);
Rcpp::traits::input_parameter< const arma::umat& >::type obs(obsSEXP);
Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP);
Expand Down Expand Up @@ -604,16 +604,16 @@ BEGIN_RCPP
END_RCPP
}
// EM_LBFGS_nhmm_multichannel
Rcpp::List EM_LBFGS_nhmm_multichannel(arma::mat& eta_pi, const arma::mat& X_pi, arma::cube& eta_A, const arma::cube& X_A, arma::field<arma::cube>& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
Rcpp::List EM_LBFGS_nhmm_multichannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::cube& X_A, const arma::field<arma::cube>& eta_B, const arma::cube& X_B, const arma::ucube& obs, const arma::uvec& Ti, const bool icpt_only_pi, const bool icpt_only_A, const bool icpt_only_B, const bool iv_A, const bool iv_B, const bool tv_A, const bool tv_B, const arma::uword n_obs, const arma::uword maxeval, const double ftol_abs, const double ftol_rel, const double xtol_abs, const double xtol_rel, const arma::uword print_level, const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m, const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m, const double lambda, const double pseudocount);
RcppExport SEXP _seqHMM_EM_LBFGS_nhmm_multichannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obsSEXP, SEXP TiSEXP, SEXP icpt_only_piSEXP, SEXP icpt_only_ASEXP, SEXP icpt_only_BSEXP, SEXP iv_ASEXP, SEXP iv_BSEXP, SEXP tv_ASEXP, SEXP tv_BSEXP, SEXP n_obsSEXP, SEXP maxevalSEXP, SEXP ftol_absSEXP, SEXP ftol_relSEXP, SEXP xtol_absSEXP, SEXP xtol_relSEXP, SEXP print_levelSEXP, SEXP maxeval_mSEXP, SEXP ftol_abs_mSEXP, SEXP ftol_rel_mSEXP, SEXP xtol_abs_mSEXP, SEXP xtol_rel_mSEXP, SEXP print_level_mSEXP, SEXP lambdaSEXP, SEXP pseudocountSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type eta_pi(eta_piSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type eta_pi(eta_piSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type X_pi(X_piSEXP);
Rcpp::traits::input_parameter< arma::cube& >::type eta_A(eta_ASEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type eta_A(eta_ASEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type X_A(X_ASEXP);
Rcpp::traits::input_parameter< arma::field<arma::cube>& >::type eta_B(eta_BSEXP);
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type eta_B(eta_BSEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type X_B(X_BSEXP);
Rcpp::traits::input_parameter< const arma::ucube& >::type obs(obsSEXP);
Rcpp::traits::input_parameter< const arma::uvec& >::type Ti(TiSEXP);
Expand Down
22 changes: 11 additions & 11 deletions src/mnhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ double mnhmm_base::objective_omega(const arma::vec& x, arma::vec& grad) {
double val = arma::dot(counts.rows(idx), log_omega.rows(idx));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
grad.zeros();
}
return maxval;
}
Expand All @@ -43,7 +43,7 @@ double mnhmm_base::objective_omega(const arma::vec& x, arma::vec& grad) {
diff.rows(idx) = counts(idx) - sum_eo * omega.rows(idx);
grad -= arma::vectorise(tQd * diff * X_omega.col(i).t());
if (!grad.is_finite()) {
grad.fill(maxval);
grad.zeros();
return maxval;
}
}
Expand Down Expand Up @@ -127,18 +127,18 @@ double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
double val = arma::dot(counts.rows(idx), log_pi(current_d).rows(idx));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
grad.zeros();
}
return maxval;
}
value -= val;
// Only update grad if it's non-empty (i.e., for gradient-based optimization)
if (!grad.is_empty()) {
diff.zeros();
diff = counts.rows(idx) - sum_epi * pi(current_d).rows(idx);
diff.rows(idx) = counts.rows(idx) - sum_epi * pi(current_d).rows(idx);
grad -= arma::vectorise(tQs * diff * X_pi.col(i).t());
if (!grad.is_finite()) {
grad.fill(maxval);
grad.zeros();
return maxval;
}
}
Expand Down Expand Up @@ -241,7 +241,7 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
double val = arma::dot(counts.rows(idx), log_A1.rows(idx));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
grad.zeros();
}
return maxval;
}
Expand All @@ -251,7 +251,7 @@ double mnhmm_base::objective_A(const arma::vec& x, arma::vec& grad) {
diff.rows(idx) = counts.rows(idx) - sum_ea * A1.rows(idx);
grad -= arma::vectorise(tQs * diff * X_A.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
grad.zeros();
return maxval;
}
}
Expand Down Expand Up @@ -365,7 +365,7 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
double val = e_b * log_B1(obs(t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
grad.zeros();
}
return maxval;
}
Expand All @@ -374,7 +374,7 @@ double mnhmm_sc::objective_B(const arma::vec& x, arma::vec& grad) {
grad -= arma::vectorise(tQm *
e_b * (I.col(obs(t, i)) - B1) * X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
grad.zeros();
return maxval;
}
}
Expand Down Expand Up @@ -493,7 +493,7 @@ double mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) {
double val = e_b * log_B1(obs(current_c, t, i));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(maxval);
grad.zeros();
}
return maxval;
}
Expand All @@ -502,7 +502,7 @@ double mnhmm_mc::objective_B(const arma::vec& x, arma::vec& grad) {
grad -= arma::vectorise(tQm * e_b * (I.col(obs(current_c, t, i)) - B1) *
X_B.slice(i).col(t).t());
if (!grad.is_finite()) {
grad.fill(maxval);
grad.zeros();
return maxval;
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/mnhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ struct mnhmm_base {
const bool iv_B_,
const bool tv_A_,
const bool tv_B_,
arma::mat& eta_omega_,
arma::field<arma::mat>& eta_pi_,
arma::field<arma::cube>& eta_A_,
const arma::mat& eta_omega_,
const arma::field<arma::mat>& eta_pi_,
const arma::field<arma::cube>& eta_A_,
const arma::uword n_obs_ = 0,
const double lambda_ = 0,
double maxval_ = 1e6)
Expand Down
8 changes: 4 additions & 4 deletions src/mnhmm_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ struct mnhmm_mc : public mnhmm_base {
const bool tv_A_,
const bool tv_B_,
const arma::ucube& obs_,
arma::mat& eta_omega_,
arma::field<arma::mat>& eta_pi_,
arma::field<arma::cube>& eta_A_,
arma::field<arma::cube>& eta_B_,
const arma::mat& eta_omega_,
const arma::field<arma::mat>& eta_pi_,
const arma::field<arma::cube>& eta_A_,
const arma::field<arma::cube>& eta_B_,
const arma::uword n_obs_ = 0,
const double lambda_ = 0)
: mnhmm_base(S_, D_, X_d_, X_pi_, X_s_, X_o_, Ti_, icpt_only_omega_,
Expand Down
8 changes: 4 additions & 4 deletions src/mnhmm_sc.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ struct mnhmm_sc : public mnhmm_base {
const bool tv_A_,
const bool tv_B_,
const arma::umat& obs_,
arma::mat& eta_omega_,
arma::field<arma::mat>& eta_pi_,
arma::field<arma::cube>& eta_A_,
arma::field<arma::cube>& eta_B_,
const arma::mat& eta_omega_,
const arma::field<arma::mat>& eta_pi_,
const arma::field<arma::cube>& eta_A_,
const arma::field<arma::cube>& eta_B_,
const arma::uword n_obs_ = 0,
const double lambda_ = 0)
: mnhmm_base(S_, D_, X_d_, X_pi_, X_s_, X_o_, Ti_, icpt_only_omega_,
Expand Down
Loading

0 comments on commit 39a16d2

Please sign in to comment.