From a2423dacd3aac611fd273424866d3cc0b6d3f1d4 Mon Sep 17 00:00:00 2001 From: Isaac Gravestock Date: Tue, 3 Dec 2024 13:45:39 +0100 Subject: [PATCH] hash cached models --- R/utilities.R | 31 ++++++++++++++++++++++++++++++- tests/testthat/test-utilities.R | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/R/utilities.R b/R/utilities.R index b4b3f2e1e..fcbd428fe 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -532,6 +532,33 @@ ensure_rstan <- function() { } } +#' Get session hash +#' +#' Gets a unique string based on the current R version and relevant packages. +#' @keywords internal +get_session_hash <- function() { + pkg_versions <- vapply( + sessionInfo(c("rbmi", "rstan", "Rcpp", "RcppEigen", "BH"))[["otherPkgs"]], + function(x) x[["Version"]], + character(1L) + ) + version_string <- paste0(R.version.string, paste0(names(pkg_versions), pkg_versions, collapse = ":")) + temp_file <- tempfile() + writeLines(version_string, temp_file) + hash <- tools::md5sum(temp_file) + unlist(temp_file) + return(hash) +} + +tidy_up_models <- function(cache_dir, keep_hash = NULL) { + files <- list.files(cache_dir, pattern = "(MMRM_).*(\\.stan|\\.rds)", full.names = TRUE) + if (!is.null(keep_hash)) { + keep_pattern <- paste0("(MMRM_", keep_hash, "(\\.stan|\\.rds)") + files <- grep(keep_pattern, files, invert = TRUE, value = TRUE) + } + unlink(files) +} + #' Get Compiled Stan Object #' #' Gets a compiled Stan object that can be used with `rstan::sampling()` @@ -549,11 +576,13 @@ get_stan_model <- function() { } cache_dir <- getOption("rbmi.cache_dir") dir.create(cache_dir, showWarnings = FALSE, recursive = TRUE) - file_loc_cache <- file.path(cache_dir, "MMRM.stan") + session_hash <- get_session_hash() + file_loc_cache <- file.path(cache_dir, paste0("MMRM_", session_hash, ".stan")) if (!file.exists(file_loc_cache)) { message("Compiling Stan model please wait...") } file.copy(file_loc, file_loc_cache, overwrite = TRUE) + tidy_up_models(cache_dir, keep_hash = session_hash) rstan::stan_model( file = file_loc_cache, auto_write = TRUE, diff --git a/tests/testthat/test-utilities.R b/tests/testthat/test-utilities.R index 740617698..7534a4113 100644 --- a/tests/testthat/test-utilities.R +++ b/tests/testthat/test-utilities.R @@ -241,3 +241,27 @@ test_that("Stack", { expect_equal(mstack$pop(3), list(7)) expect_error(mstack$pop(1), "items to return") }) + + +test_that("tidy_up_models", { + td <- tempdir() + files <- c( + file.path(td, "MMRM_123.rds"), + file.path(td, "MMRM_123.stan"), + file.path(td, "MMRM_456.stan"), + file.path(td, "MMRM_456.rds"), + file.path(td, "MMRM_456.log") + ) + expect_equal(file.create(files), rep(TRUE, 5)) + tidy_up_models(td, keep_hash = "123") + expect_equal( + file.exists(files), + c(TRUE, TRUE, FALSE, FALSE, TRUE) + ) + tidy_up_models(td) + expect_equal( + file.exists(files), + c(FALSE, FALSE, FALSE, FALSE, TRUE) + ) + file.remove(files[5]) +}) \ No newline at end of file