Skip to content

Commit

Permalink
fanhmm bootstrap wip
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Jan 6, 2025
1 parent 69dcb83 commit a3bd3b5
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 1 deletion.
109 changes: 109 additions & 0 deletions R/ame_obs.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,112 @@ ame_obs.mnhmm <- function(
attr(out, "model") <- "mnhmm"
out
}

#' @rdname ame_obs
#' @export
ame_obs.fanhmm <- function(
model, variable, values, start_time, newdata = NULL, probs, ...) {
stopifnot_(
attr(model, "intercept_only") == FALSE,
"Model does not contain any covariates."
)
stopifnot_(
checkmate::test_string(x = variable),
"Argument {.arg variable} must be a single character string."
)
stopifnot_(
length(values) == 2,
"Argument {.arg values} should contain two values for
variable {.var variable}.")
if (!missing(probs)) {
return_quantiles <- TRUE
stopifnot_(
checkmate::test_numeric(
x = probs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L
),
"Argument {.arg probs} must be a {.cls numeric} vector with values
between 0 and 1."
)
stopifnot_(
!is.null(model$boot),
paste0(
"Model does not contain bootstrap samples of coefficients. ",
"Run {.fn bootstrap_coefs} first in order to compute quantiles."
)
)
} else {
return_quantiles <- FALSE
# dummy stuff for C++
model$boot <- list(
gamma_pi = list(model$gammas$pi),
gamma_A = list(model$gammas$A),
gamma_B = list(model$gammas$B)
)
probs <- 0.5
model$boot$idx <- matrix(seq_len(model$n_sequences), model$n_sequences, 1)
}

time <- model$time_variable
id <- model$id_variable
if (!is.null(newdata)) {
stopifnot_(
is.data.frame(newdata),
"Argument {.arg newdata} must be a {.cls data.frame} object."
)
stopifnot_(
!is.null(newdata[[id]]),
"Can't find grouping variable {.var {id}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[time]]),
"Can't find time index variable {.var {time}} in {.arg newdata}."
)
stopifnot_(
!is.null(newdata[[variable]]),
"Can't find time variable {.var {variable}} in {.arg newdata}."
)
} else {
stopifnot_(
!is.null(model$data),
"Model does not contain original data and argument {.arg newdata} is
{.var NULL}."
)
newdata <- model$data
}
newdata[[variable]][newdata[[time]] >= start_time] <- values[1]
X1 <- update(model, newdata)[c("X_pi", "X_A", "X_B", "W_A", "W_B")]
newdata[[variable]][newdata[[time]] >= start_time] <- values[2]
X2 <- update(model, newdata)[c("X_pi", "X_A", "X_B", "W_A", "W_B")]
C <- model$n_channels
start <- which(sort(unique(newdata[[time]])) == start_time)
times <- as.numeric(colnames(model$observations))
symbol_names <- list(model$symbol_names)
obs <- create_obsArray(model)[1L, , ]
out <- ame_obs_fanhmm_singlechannel(
model$etas$pi, model$etas$A, model$etas$B, model$rho_A, model$rho_B,
obs, model$sequence_lengths,
attr(X1$X_pi, "icpt_only"), attr(X1$X_A, "icpt_only"),
attr(X1$X_B, "icpt_only"), attr(X1$X_A, "iv"),
attr(X1$X_B, "iv"), attr(X1$X_A, "tv"), attr(X1$X_B, "tv"),
X1$X_pi, X1$X_A, X1$X_B, X1$W_A, X1$W_B,
attr(X2$X_pi, "icpt_only"), attr(X2$X_A, "icpt_only"),
attr(X2$X_B, "icpt_only"), attr(X2$X_A, "iv"),
attr(X2$X_B, "iv"), attr(X2$X_A, "tv"), attr(X2$X_B, "tv"),
X2$X_pi, X2$X_A, X2$X_B, X2$W_A, X2$W_B,
model$boot$gamma_pi, model$boot$gamma_A, model$boot$gamma_B,
model$boot$rho_A, model$boot_rho_B,
start, probs, model$boot$idx - 1L
)
d <- data.frame(
observation = model$symbol_names,
time = rep(as.numeric(colnames(model$observations)), each = model$n_symbols),
estimate = c(out$point_estimate)
)
if (return_quantiles) {
for(i in seq_along(probs)) {
d[paste0("q", 100 * probs[i])] <- c(out$quantiles[, , i])
}
}
colnames(d)[2] <- time
d[d[[time]] >= start_time, ]
}
116 changes: 115 additions & 1 deletion R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ bootstrap_model <- function(model) {
model$X_pi[] <- model$X_pi[, idx]
model$X_A[] <- model$X_A[, , idx]
model$X_B[] <- model$X_B[, , idx]
if (!is.null(model$X_omega)) {
if (inherits(model, "mnhmm")) {
model$X_omega[] <- model$X_omega[, idx]
}
if (inherits(model, "fanhmm")) {
model$W_A[] <- model$W_A[, idx]
model$W_B[] <- model$W_B[, idx]
}
model$sequence_lengths <- model$sequence_lengths[idx]
list(model = model, idx = idx)
}
Expand Down Expand Up @@ -316,3 +320,113 @@ bootstrap_coefs.mnhmm <- function(model, nsim = 1000,
}
model
}
#' @rdname bootstrap
#' @export
bootstrap_coefs.fanhmm <- function(model, nsim = 1000,
type = c("nonparametric", "parametric"),
method = "EM-DNM", append = FALSE, ...) {
type <- match.arg(type)
stopifnot_(
checkmate::test_int(x = nsim, lower = 0L),
"Argument {.arg nsim} must be a single positive integer."
)
init <- setNames(model$etas, c("eta_pi", "eta_A", "eta_B", "rho_A", "rho_B"))
gammas_mle <- model$gammas
lambda <- model$estimation_results$lambda
bound <- model$estimation_results$bound
p <- progressr::progressor(along = seq_len(nsim))
original_options <- options(future.globals.maxSize = Inf)
on.exit(options(original_options))
control <- model$controls$control
control$print_level <- 0
control_mstep <- model$controls$mstep
control_mstep$print_level <- 0
if (type == "nonparametric") {
out <- future.apply::future_lapply(
seq_len(nsim), function(i) {
boot_mod <- bootstrap_model(model)
fit <- fit_fanhmm(
boot_mod$model, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, bound = bound, control = control,
control_restart = list(), control_mstep = control_mstep
)
if (fit$estimation_results$return_code >= 0) {
fit$gammas <- permute_states(fit$gammas, gammas_mle)
} else {
fit$gammas <- NULL
}
p()
list(gammas = fit$gammas, idx = boot_mod$idx)
}, future.seed = TRUE
)
idx <- do.call(cbind, lapply(out, "[[", "idx"))
out <- lapply(out, "[[", "gammas")
} else {
N <- model$n_sequences
T_ <- model$sequence_lengths
M <- model$n_symbols
S <- model$n_states
formula_pi <- model$initial_formula
formula_A <- model$transition_formula
formula_B <- model$emission_formula
formula_rho_A <- model$feedback_formula
formula_rho_B <- model$autoregression_formula
d <- model$data
time <- model$time_variable
id <- model$id_variable
out <- future.apply::future_lapply(
seq_len(nsim), function(i) {
mod <- simulate_fanhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
formula_rho_B, formula_rho_A,
d, time, id, init, 0)$model
fit <- fit_fanhmm(
mod, init, init_sd = 0, restarts = 0, lambda = lambda,
method = method, bound = bound, control = control,
control_restart = list(), control_mstep = control_mstep
)
if (fit$estimation_results$return_code >= 0) {
fit$gammas <- permute_states(fit$gammas, gammas_mle)
} else {
fit$gammas <- NULL
}
p()
fit$gammas
}, future.seed = TRUE
)
}
boot <- list(
gamma_pi = lapply(out, "[[", "pi"),
gamma_A = lapply(out, "[[", "A"),
gamma_B = lapply(out, "[[", "B"),
phi_A = lapply(out, "[[", "phi_A"),
phi_B = lapply(out, "[[", "phi_B")
)
boot <- lapply(boot, function(x) x[lengths(x) > 0])
if (length(boot[[1]]) < nsim) {
warning_(
paste0(
"Estimation in some of the bootstrap samples failed. ",
"Returning samples from {length(boot[[1]])} successes out of {nsim} ",
"bootstrap samples."
)
)
}
if (type == "nonparametric") {
boot$idx <- idx
} else {
boot$idx <- matrix(seq_len(model$n_sequences), model$n_sequences, nsim)
}
if (append && !is.null(model$boot)) {
model$boot$gamma_pi <- c(model$boot$gamma_pi, boot$gamma_pi)
model$boot$gamma_A <- c(model$boot$gamma_A, boot$gamma_A)
model$boot$gamma_B <- c(model$boot$gamma_B, boot$gamma_B)
model$boot$phi_A <- c(model$boot$phi_A, boot$phi_A)
model$boot$phi_B <- c(model$boot$phi_B, boot$phi_B)
model$boot$idx <- cbind(model$boot$idx, idx)
} else {
model$boot <- boot
}

model
}
40 changes: 40 additions & 0 deletions R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,43 @@ update.mnhmm <- function(object, newdata, ...) {
)$X
object
}
#' @rdname update_nhmm
#' @export
update.fanhmm <- function(object, newdata, ...) {
newdata <- .check_data(newdata, object$time_variable, object$id_variable)
if (!is.null(object$data)) object$data <- newdata
object$X_pi <- model_matrix_initial_formula(
object$initial_formula, newdata, object$n_sequences,
object$length_of_sequences, object$n_states, object$time_variable,
object$id_variable,
attr(object$X_pi, "X_mean"), attr(object$X_pi, "X_sd"), FALSE
)$X
object$X_A <- model_matrix_transition_formula(
object$transition_formula, newdata, object$n_sequences,
object$length_of_sequences, object$n_states, object$time_variable,
object$id_variable, object$sequence_lengths,
attr(object$X_A, "X_mean"), attr(object$X_A, "X_sd"), FALSE
)$X
object$X_B <- model_matrix_emission_formula(
object$emission_formula, newdata, object$n_sequences,
object$length_of_sequences, object$n_states, object$n_symbols,
object$time_variable, object$id_variable,
object$sequence_lengths,
attr(object$X_B, "X_mean"), attr(object$X_B, "X_sd"), FALSE
)$X
object$W_A <- model_matrix_feedback_formula(
object$feedback_formula, newdata, object$n_sequences,
object$length_of_sequences, object$n_states, object$n_symbols,
object$time_variable, object$id_variable,
object$sequence_lengths,
attr(object$W_A, "X_mean"), attr(object$W_A, "X_sd"), FALSE
)$X
object$W_B <- model_matrix_autoregression_formula(
object$autoregression_formula, newdata, object$n_sequences,
object$length_of_sequences, object$n_states, object$n_symbols,
object$time_variable, object$id_variable,
object$sequence_lengths,
attr(object$W_B, "X_mean"), attr(object$W_B, "X_sd"), FALSE
)$X
object
}

0 comments on commit a3bd3b5

Please sign in to comment.