Skip to content

Commit

Permalink
rewrite error handling, fix gradient of eta_pi with pseudocounts
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 10, 2024
1 parent 0964ebc commit a21bc03
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 86 deletions.
2 changes: 1 addition & 1 deletion R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, lambda = 0, method = "EM", pseudocount = 0,
restarts = 0L, lambda = 0, method = "EM", pseudocount = 1e-4,
store_data = TRUE, ...) {

call <- match.call()
Expand Down
4 changes: 2 additions & 2 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#' algorithm with L-BFGS in the M-step. Another option is `"DNM"` which uses
#' direct maximization of the log-likelihood using [nloptr::nloptr()].
#' @param pseudocount A positive scalar to be added for the expected counts of
#' E-step. Only used in EM algorithm. Default is 0. Larger values can be used
#' E-step. Only used in EM algorithm. Default is 1e-4. Larger values can be used
#' to avoid zero probabilities in initial, transition, and emission
#' probabilities, i.e., it has a similar role to `lambda`.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
Expand Down Expand Up @@ -92,7 +92,7 @@ estimate_nhmm <- function(
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 0L, lambda = 0, method = "EM",
pseudocount = 0, store_data = TRUE, ...) {
pseudocount = 1e-4, store_data = TRUE, ...) {

call <- match.call()

Expand Down
14 changes: 10 additions & 4 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
pseudocount = 0, save_all_solutions = FALSE,
pseudocount, save_all_solutions = FALSE,
control_restart = list(), control_mstep = list(), ...) {
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
Expand Down Expand Up @@ -215,6 +215,11 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes > 0)
stopifnot_(
length(successful) > 0,
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
if (save_all_solutions) {
Expand All @@ -230,9 +235,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
opts = control
)
end_time <- proc.time()
if (out$status < 0) {
warning_(paste("Optimization terminated due to error:", out$message))
}
stopifnot_(
out$status >= 0,
paste("Optimization terminated due to error:", error_msg(out$status))
)
pars <- out$solution
model$etas$pi <- create_eta_pi_mnhmm(
pars[seq_len(n_i)], S, K_pi, D
Expand Down
31 changes: 23 additions & 8 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Estimate a Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocount = 0,
fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocount,
save_all_solutions = FALSE, control_restart = list(),
control_mstep = list(), ...) {

Expand Down Expand Up @@ -167,6 +167,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes > 0)
stopifnot_(
length(successful) > 0,
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
if (save_all_solutions) {
Expand All @@ -183,9 +188,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
opts = control
)
end_time <- proc.time()
if (out$status < 0) {
warning_(paste("Optimization terminated due to error:", out$message))
}
stopifnot_(
out$status >= 0,
paste("Optimization terminated due to error:", error_msg(out$status))
)

pars <- out$solution
model$etas$pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_pi)
model$gammas$pi <- eta_to_gamma_mat(model$etas$pi)
Expand Down Expand Up @@ -247,7 +254,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
}
},
future.seed = TRUE)

return_codes <- unlist(lapply(out, "[[", "return_code"))
stopifnot_(
any(return_codes == 0),
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
logliks <- unlist(lapply(out, "[[", "penalized_logLik")) * n_obs
optimum <- out[[which.max(logliks)]]
init <- stats::setNames(
Expand Down Expand Up @@ -285,9 +297,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
control_mstep$print_level, lambda, pseudocount)
}
end_time <- proc.time()
# if (out$status < 0) {
# warning_(paste("Optimization terminated due to error:", out$message))
# }
stopifnot_(
out$return_code == 0,
paste("Optimization terminated due to error:", error_msg(out$return_code))
)

model$etas$pi[] <- out$eta_pi
model$gammas$pi <- eta_to_gamma_mat(model$etas$pi)
Expand All @@ -305,7 +318,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
loglik = out$penalized_logLik,
penalty = out$penalty_term,
iterations = out$iterations,
return_code = out$return_code,
logliks_of_restarts = if(restarts > 0L) logliks else NULL,
return_codes_of_restarts = if(restarts > 0L) return_codes else NULL,
all_solutions = all_solutions,
time = end_time - start_time,
f_rel_change = out$relative_f_change,
Expand Down
40 changes: 40 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,43 @@ create_emissionArray <- function(model) {
}
emissionArray
}
#' Convert an optimization error code to a human-readable message
#'
#' Codes `1`, `2`, and `3` signal nonfinite values in `gamma_pi`, `gamma_A`,
#' and `gamma_B`, respectively. Codes `-1` to `-4` are reported as the
#' corresponding NLopt failure codes. Codes `-(100 * k + j)` with `k` in `1:3`
#' and `j` in `1:4` are reported as NLopt failure `j` in the M-step of the
#' `k`th parameter block (`gamma_pi`, `gamma_A`, `gamma_B`).
#'
#' @param error A single numeric error code.
#' @return A character string describing the error. Codes that are not
#'   recognized (including `NA` and non-scalar input) give a generic
#'   "Unknown error code" message instead of `NA`.
#' @noRd
error_msg <- function(error) {
  # Fallback so that callers never paste a bare NA into their error message.
  unknown_msg <- paste0(
    "Unknown error code: ", paste(error, collapse = ", "), "."
  )
  if (length(error) != 1L || is.na(error)) {
    return(unknown_msg)
  }

  gammas <- c("gamma_pi", "gamma_A", "gamma_B")

  # Positive codes 1-3: nonfinite values in the corresponding gamma.
  if (error %in% 1:3) {
    return(
      paste0(
        "Error: Some of the values in ", gammas[error],
        " are nonfinite, likely due to ",
        "zero expected counts. Try increasing the penalty lambda or ",
        "pseudocounts to avoid extreme probabilities."
      )
    )
  }

  nlopt_msgs <- c(
    "NLOPT_FAILURE: Generic failure code.",
    paste0(
      "NLOPT_INVALID_ARGS: Invalid arguments (e.g., lower bounds are ",
      "bigger than upper bounds, an unknown algorithm was specified)."
    ),
    "NLOPT_OUT_OF_MEMORY: Ran out of memory.",
    "NLOPT_ROUNDOFF_LIMITED: Halted because roundoff errors limited progress."
  )

  # Negative codes: the hundreds digit selects the M-step block (0 means the
  # failure is not tied to a specific M-step), the remainder selects the
  # NLopt failure type.
  code <- -error
  block <- code %/% 100
  failure <- code %% 100
  if (!(block %in% 0:3 && failure %in% 1:4)) {
    return(unknown_msg)
  }

  mstep <- if (block > 0) {
    paste0("Error in M-step of ", gammas[block], ". ")
  } else {
    ""
  }
  paste0(mstep, nlopt_msgs[failure])
}
4 changes: 2 additions & 2 deletions man/estimate_mnhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/estimate_nhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion src/mnhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ struct mnhmm_base {
arma::cube log_py;
arma::uword n_obs;
double lambda;

int mstep_iter;
int mstep_error_code;

mnhmm_base(
const arma::uword S_,
const arma::uword D_,
Expand All @@ -65,7 +69,9 @@ struct mnhmm_base {
arma::mat& eta_omega_,
arma::field<arma::mat>& eta_pi_,
arma::field<arma::cube>& eta_A_,
const double lambda = 0)
const double lambda = 0,
int imstep_ter = 0,
int mstep_error_code = 0)
: S(S_),
D(D_),
X_omega(X_d_),
Expand Down
Loading

0 comments on commit a21bc03

Please sign in to comment.