Skip to content

Commit

Permalink
add stan utils
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasjclark committed Jun 10, 2022
1 parent 1d56920 commit 68adf05
Show file tree
Hide file tree
Showing 5 changed files with 1,766 additions and 0 deletions.
308 changes: 308 additions & 0 deletions R/add_stan_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
#' Add remaining data, model and parameter blocks to a Stan model
#'
#' Fills the placeholders of an incomplete Stan model template
#' (`##insert data`, `##insert smooths` and the line two below
#' `// basis coefficients`) using text translated from the JAGS model file,
#' and builds the data list that matches the inserted Stan data block.
#'
#' @export
#' @param jags_file Prepared JAGS mvgam model file
#' @param stan_file Incomplete Stan model file to be edited
#' @param jags_data Prepared mvgam data for JAGS modelling
#' @param family `character`. If `'nb'`, the Poisson likelihood and posterior
#' prediction statements in `stan_file` are rewritten as negative binomial
#' statements; any other value (default `'poisson'`) leaves them unchanged
#' @return A `list` containing the updated Stan model (`stan_file`) and
#' model data (`model_data`)
add_stan_data <- function(jags_file, stan_file, jags_data, family = 'poisson') {

  # Re-split a character vector so every element holds exactly one line.
  # The text connection is closed on exit; leaving it open leaks one
  # connection per call and R only allows a limited number of them.
  split_lines <- function(x) {
    con <- textConnection(x)
    on.exit(close(con), add = TRUE)
    readLines(con, n = -1)
  }

  # Bounds of a 'for (i in a:b)' line. The whole run of digits on each side
  # of the colon is captured so that indices above 9 are not truncated.
  loop_lower <- function(x) {
    as.numeric(sub('^.*?([[:digit:]]+)[[:space:]]*:.*$', '\\1', x,
                   perl = TRUE))
  }
  loop_upper <- function(x) {
    as.numeric(sub('^.*:[[:space:]]*([[:digit:]]+).*$', '\\1', x,
                   perl = TRUE))
  }

  #### Modify the Stan file ####
  # Update lines associated with particular family; nothing needs to change
  # for 'poisson' because the template already uses Poisson statements
  if (family == 'nb') {
    # Declare the overdispersion parameters and give them a prior
    stan_file[grep('// raw basis', stan_file) + 2] <-
      '\n// negative binomial overdispersion\nvector<lower=0>[n_series] r_inv;\n'

    stan_file[grep('// priors for smoothing', stan_file) + 2] <-
      '\n// priors for overdispersion parameters\nr_inv ~ normal(0, 10);\n'

    # Likelihood: poisson_log(...) -> neg_binomial_2(exp(...), inv(r_inv[s]))
    to_negbin <- gsub('poisson_log', 'neg_binomial_2',
                      stan_file[grep('y[i, s] ~ poisson', stan_file,
                                     fixed = TRUE)])
    stan_file[grep('y[i, s] ~ poisson', stan_file, fixed = TRUE)] <-
      gsub(');', ', inv(r_inv[s]));', to_negbin)

    add_exp_open <- gsub('\\(eta', '(exp(eta',
                         stan_file[grep('y[i, s] ~ neg_binomial', stan_file,
                                        fixed = TRUE)])
    add_exp_cl <- gsub('],', ']),', add_exp_open)
    stan_file[grep('y[i, s] ~ neg_binomial', stan_file, fixed = TRUE)] <-
      add_exp_cl

    # Quantities needed by the posterior prediction statement; these replace
    # the line directly above the '// posterior predictions' comment
    stan_file[grep('// posterior predictions', stan_file, fixed = TRUE) - 1] <-
      paste0('matrix[n, n_series] r_vec;\n',
             'vector[n_series] r;\n',
             'r = inv(r_inv);\n',
             'for (s in 1:n_series) {\n',
             'r_vec[1:n,s] = rep_vector(inv(r_inv[s]), n);\n}\n')

    # Posterior predictions: poisson_log_rng(...) -> neg_binomial_2_rng(...)
    to_negbin <- gsub('poisson_log_rng', 'neg_binomial_2_rng',
                      stan_file[grep('ypred = poisson_log_rng', stan_file,
                                     fixed = TRUE)])
    stan_file[grep('ypred = poisson_log_rng', stan_file, fixed = TRUE)] <-
      gsub(');', ', to_vector(r_vec));', to_negbin)

    add_exp_open <- gsub('\\(eta', '(exp(eta',
                         stan_file[grep('ypred = neg_binomial', stan_file,
                                        fixed = TRUE)])

    if (any(grepl('to_vector(trend)', stan_file, fixed = TRUE))) {
      add_exp_cl <- gsub('to_vector(trend)', 'to_vector(trend))',
                         add_exp_open, fixed = TRUE)
    } else {
      # NOTE(review): the replacement text 'eta,),' is kept exactly as found;
      # it looks unusual for Stan syntax -- confirm against the template
      add_exp_cl <- gsub('eta)', 'eta,),',
                         add_exp_open)
    }

    stan_file[grep('ypred = neg_binomial', stan_file, fixed = TRUE)] <-
      add_exp_cl

    stan_file <- split_lines(stan_file)
  }

  # Get dimensions and numbers of smooth terms; every element of jags_data
  # whose name matches 'S.*' is treated as a penalty matrix
  snames <- names(jags_data)[grep('S.*', names(jags_data))]
  smooth_dims <- matrix(NA, ncol = 2, nrow = length(snames))
  for (i in seq_along(snames)) {
    smooth_dims[i, ] <- dim(jags_data[[snames[i]]])
  }

  # One Stan declaration per penalty matrix (none if there are no smooths)
  smooth_penalty_data <- character(length(snames))
  for (i in seq_along(snames)) {
    smooth_penalty_data[i] <- paste0('matrix[', smooth_dims[i, 1],
                                     ',',
                                     smooth_dims[i, 2], '] ',
                                     snames[i],
                                     '; // mgcv smooth penalty matrix ',
                                     snames[i])
  }

  # Search for any non-contiguous indices that sometimes are used by mgcv
  idx_locations <- grep('in c\\(', jags_file)
  add_idxs <- length(idx_locations) > 0
  idx_lines <- ''
  if (add_idxs) {
    # Expand 'a:b' into the full integer sequence; a lone number is returned
    # unchanged
    seq_character <- function(x) {
      all_nums <- as.numeric(unlist(strsplit(x, ':')))
      if (length(all_nums) > 1) {
        out <- seq(all_nums[1], all_nums[2])
      } else {
        out <- all_nums
      }
      out
    }

    idx_vals <- list()
    idx_data <- vector()
    for (i in seq_along(idx_locations)) {
      list_vals <- unlist(strsplit(gsub('^.*c\\(*|\\s*).*$', '',
                                        jags_file[idx_locations[i]]), ','))
      idx_vals[[i]] <- unlist(lapply(list_vals, seq_character))
      idx_data[i] <- paste0('int idx', i, '[', length(idx_vals[[i]]),
                            ']; // discontiguous index values')
      # Point the JAGS loop at the named index vector instead of 'c(...)'
      jags_file[idx_locations[i]] <- sub("in.*\\)\\)",
                                         paste0("in idx", i, ')'),
                                         jags_file[idx_locations[i]])
    }
    idx_lines <- paste0(paste0(idx_data, collapse = '\n'), '\n')
  }

  # Insert the data block for the model; the index declarations (if any) come
  # first, everything else is the same with or without them
  stan_file[grep('##insert data', stan_file)] <- paste0(
    '//Stan code generated by package mvgam\n',
    'data {',
    '\n',
    idx_lines,
    'int<lower=0> total_obs; // total number of observations\n',
    'int<lower=0> n; // number of timepoints per series\n',
    'int<lower=0> n_sp; // number of smoothing parameters\n',
    'int<lower=0> n_series; // number of series\n',
    'int<lower=0> num_basis; // total number of basis coefficients\n',
    'vector[num_basis] zero; // prior locations for basis coefficients\n',
    'matrix[num_basis, total_obs] X; // transposed mgcv GAM design matrix\n',
    'int<lower=0> ytimes[n, n_series]; // time-ordered matrix (which col in X belongs to each [time, series] observation?)\n',
    paste0(smooth_penalty_data, collapse = '\n'), '\n',
    'int<lower=0, upper=1> y_observed[n, n_series]; // indices of missing vs observed\n',
    'int<lower=-1> y[n, n_series]; // time-ordered observations, with -1 indicating missing\n',
    '}\n')
  stan_file <- split_lines(stan_file)

  # Modify the model block to include each smooth term: take the JAGS prior
  # section and translate it line by line into Stan syntax
  smooths_start <- grep('## GAM-specific priors', jags_file) + 1
  smooths_end <- grep('## smoothing parameter priors...', jags_file) - 1
  jags_smooth_text <- jags_file[smooths_start:smooths_end]
  jags_smooth_text <- gsub('##', '//', jags_smooth_text)
  jags_smooth_text <- gsub('dexp', 'exponential', jags_smooth_text)

  # Each 'K <- <penalty>' line is folded into the dmnorm statement on the
  # following line, which becomes a multi_normal_prec statement on b_raw
  K_starts <- grep('K.* <- ', jags_smooth_text)
  for (k in K_starts) {
    penalty <- trimws(gsub('K.* <- ', '', jags_smooth_text[k]))
    prior <- paste0(gsub('K.*', penalty, jags_smooth_text[k + 1]), ')')
    prior <- gsub('dmnorm', 'multi_normal_prec', prior)
    jags_smooth_text[k + 1] <- gsub('\\bb\\b', 'b_raw', prior)
  }
  if (length(K_starts) > 0) {
    jags_smooth_text <- jags_smooth_text[-K_starts]
  }
  jags_smooth_text <-
    jags_smooth_text[!grepl('b\\[i\\] <- b_raw', jags_smooth_text)]
  jags_smooth_text <- gsub('dnorm', 'normal', jags_smooth_text)
  # NOTE(review): as written this swaps a single space for a single space
  # (no effect); possibly meant to collapse double spaces -- confirm
  jags_smooth_text <- gsub(' ', ' ', jags_smooth_text)

  # Terminate statements; comments and lines holding braces are left alone.
  # A logical mask is used because x[-grep(...)] selects nothing at all when
  # there is no match
  needs_semicolon <- !grepl('//|\\}|\\{', jags_smooth_text)
  jags_smooth_text[needs_semicolon] <-
    paste0(jags_smooth_text[needs_semicolon], ';')
  jags_smooth_text <- gsub(') }', '); }', jags_smooth_text)
  jags_smooth_text <- gsub('}', '}\n', jags_smooth_text)

  # Append a newline to the line preceding each comment after the first
  before_comments <- (grep('//', jags_smooth_text) - 1)[-1]
  jags_smooth_text[before_comments] <-
    paste0(jags_smooth_text[before_comments], '\n')
  stan_file[grep('##insert smooths', stan_file)] <-
    paste0(jags_smooth_text, collapse = '\n')
  stan_file <- split_lines(stan_file)

  # Deal with any random effect priors
  if (any(grepl('b_raw\\[i\\] ~', stan_file))) {
    # Largest integer on the lines directly above the 'b_raw[i] ~' statements
    b_raw_string <- paste0(stan_file[grep('b_raw\\[i\\] ~', stan_file) - 1],
                           collapse = ',')
    n_b_raw <- max(as.numeric(unlist(regmatches(b_raw_string,
                                                gregexpr("[[:digit:]]+",
                                                         b_raw_string)))))

    # Largest integer on any line that mentions sigma_raw
    sigma_lines <- grep('sigma_raw', stan_file, value = TRUE)
    n_sigma_raw <- max(as.numeric(unlist(regmatches(sigma_lines,
                                                    gregexpr("[[:digit:]]+",
                                                             sigma_lines)))))

    # Drop the translated JAGS statements for mu_raw / sigma_raw; vectorised
    # priors are written at the top of the model block instead
    stan_file <- stan_file[!grepl('mu_raw.* ~ ', stan_file)]
    stan_file <- stan_file[!grepl('<- mu_raw', stan_file)]
    stan_file <- stan_file[!grepl('sigma_raw.* ~ ', stan_file)]
    stan_file[grep('model \\{', stan_file)] <-
      paste0('model {\n// prior for random effect population variances\nsigma_raw ~ exponential(0.5);\n\n',
             '// prior for random effect population means\nmu_raw ~ normal(0, 1);\n')

    # NOTE(review): mu_raw is declared with lower=0 although its prior is
    # normal(0, 1) -- confirm the constraint is intended
    param_line <- grep('parameters \\{', stan_file)[1] + 2
    stan_file[param_line] <-
      paste0(stan_file[param_line],
             '\n',
             '\n// random effect variances\n',
             paste0('vector<lower=0>[', n_sigma_raw, '] sigma_raw', ';\n',
                    collapse = ''),
             '\n',
             paste0('vector<lower=0>[', n_sigma_raw, '] mu_raw', ';\n',
                    collapse = ''))

    # One loop per random effect term; its bounds are read from the 'for'
    # line directly above the matching 'b_raw[i] ~' statement
    b_raw_indices <- grep('b_raw\\[i\\] ~', stan_file)
    b_raw_text <- character(length(b_raw_indices))
    for (i in seq_along(b_raw_indices)) {
      loop_line <- stan_file[b_raw_indices[i] - 1]
      b_raw_text[i] <- paste0('for (i in ', loop_lower(loop_line),
                              ':', loop_upper(loop_line),
                              ') {\nb[i] <- mu_raw[', i,
                              '] + b_raw[i] * sigma_raw[', i,
                              '];\n}')
    }

    # If parametric coefficients are included, they'll come before random effects
    min_re_betas <- loop_lower(stan_file[b_raw_indices[1] - 1])
    remaining_betas <- paste0('\nfor (i in ', n_b_raw + 1,
                              ':num_basis) {\nb[i] <- b_raw[i];\n}\n')
    if (min_re_betas > 1) {
      b_raw_text <- c(paste0('\nfor (i in 1:',
                             min_re_betas - 1, ') {\nb[i] <- b_raw[i];\n}'),
                      b_raw_text,
                      remaining_betas)
    } else {
      b_raw_text <- c(b_raw_text, remaining_betas)
    }

    stan_file[grep('// basis coefficients', stan_file) + 2] <-
      paste0(b_raw_text, collapse = '\n')
    stan_file <- split_lines(stan_file)

  # If no random effects, betas are equal to beta_raws
  } else {
    stan_file[grep('// basis coefficients', stan_file) + 2] <-
      paste0('\nfor (i in ', '1:num_basis) {\nb[i] <- b_raw[i];\n}')
    stan_file <- split_lines(stan_file)
  }

  # Update parametric effect priors: the loop below the comment is rewritten
  # over the same index range with a normal(0, 1) prior on b_raw
  if (any(grepl('// parametric effect', stan_file))) {
    para_line <- grep('// parametric effect', stan_file) + 1
    stan_file[para_line] <-
      paste0('for (i in ',
             loop_lower(stan_file[para_line]),
             ':', loop_upper(stan_file[para_line]),
             ') {\nb_raw[i] ~ normal(0, 1);\n}')
    stan_file <- split_lines(stan_file)
  }
  # Removes 'base_gam_stan.txt' from the working directory if it exists
  unlink('base_gam_stan.txt')
  stan_file <- split_lines(stan_file)

  # Final tidying of the Stan model for readability: drop every blank line
  # that directly follows another blank line
  n_lines <- length(stan_file)
  if (n_lines > 1) {
    repeated_blank <- c(FALSE,
                        stan_file[-1] == "" & stan_file[-n_lines] == "")
    stan_file <- stan_file[!repeated_blank]
  }

  #### Modify the Stan data list ####
  # Create matrix representing whether an observation was missing or not
  # (1 = observed, 0 = missing)
  y_observed <- matrix(as.numeric(!is.na(jags_data$y)),
                       ncol = NCOL(jags_data$y),
                       nrow = NROW(jags_data$y))

  # Use -1 for any missing observations so Stan doesn't throw errors due to NAs
  y <- jags_data$y
  y[is.na(y)] <- -1

  # The data list for Stan
  stan_data <- jags_data
  stan_data$y <- y
  stan_data$y_observed <- y_observed
  stan_data$X <- t(stan_data$X)
  stan_data$total_obs <- NCOL(stan_data$X)
  stan_data$num_basis <- NROW(stan_data$X)
  # Number of smoothing parameters: upper bound of the JAGS loop that wraps
  # the 'lambda[i] ~' prior
  stan_data$n_sp <- as.numeric(sub('\\) \\{', '',
                                   sub('for \\(i in 1\\:', '',
                                       jags_file[grep('lambda\\[i\\] ~ ',
                                                      trimws(jags_file)) - 1])))

  # Add discontiguous index values if required
  if (add_idxs) {
    names(idx_vals) <- paste0('idx', seq_len(length(idx_vals)))
    stan_data <- append(stan_data, idx_vals)
  }

  list(stan_file = stan_file,
       model_data = stan_data)
}
Loading

0 comments on commit 68adf05

Please sign in to comment.