Skip to content

Commit

Permalink
Recursive find of covariate settings for all modules
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonysena committed Nov 13, 2024
1 parent 074ced1 commit ad34028
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
41 changes: 41 additions & 0 deletions R/StrategusModule.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,28 @@ StrategusModule <- R6::R6Class(
}
private$jobContext$settings <- moduleSpecification$settings

# Make sure that the covariate settings for the analysis are updated
# to reflect the location of the cohort tables
private$jobContext$settings <- .replaceCovariateSettings(
moduleSettings = private$jobContext$settings,
executionSettings = executionSettings
)


# Assemble the job context from the analysis specification
# for the given module.
private$jobContext$sharedResources <- analysisSpecifications$sharedResources
private$jobContext$moduleExecutionSettings <- executionSettings
private$jobContext$moduleExecutionSettings$resultsSubFolder <- file.path(private$jobContext$moduleExecutionSettings$resultsFolder, self$moduleName)
if (!dir.exists(private$jobContext$moduleExecutionSettings$resultsSubFolder)) {
dir.create(private$jobContext$moduleExecutionSettings$resultsSubFolder, showWarnings = F, recursive = T)
}

if (is(private$jobContext$moduleExecutionSettings, "ExecutionSettings")) {
private$jobContext$moduleExecutionSettings$workSubFolder <- file.path(private$jobContext$moduleExecutionSettings$workFolder, self$moduleName)
if (!dir.exists(private$jobContext$moduleExecutionSettings$workSubFolder)) {
dir.create(private$jobContext$moduleExecutionSettings$workSubFolder, showWarnings = F, recursive = T)
}
}
},
.getModuleSpecification = function(analysisSpecifications, moduleName) {
Expand Down Expand Up @@ -324,3 +338,30 @@ StrategusModule <- R6::R6Class(
return(modifiedCovariateSettings)
}

.replaceCovariateSettings <- function(moduleSettings, executionSettings) {
errorMessages <- checkmate::makeAssertCollection()
checkmate::assertList(moduleSettings, min.len = 1, add = errorMessages)
checkmate::assertClass(executionSettings, "ExecutionSettings", add = errorMessages)
checkmate::reportAssertions(collection = errorMessages)

# A helper function to perform the replacement
for (i in seq_along(length(moduleSettings)))
replaceHelper <- function(x) {
if (is.list(x) && inherits(x, "covariateSettings")) {
# If the element is a list and of type covariate settings
# replace the cohort table names
return(.replaceCovariateSettingsCohortTableNames(x, executionSettings))
} else if (is.list(x)) {
# If the element is a list, recurse on each element
return(lapply(x, replaceHelper))
} else {
# If the element is not a list or "covariateSettings", return it as is
return(x)
}
}

# Call the helper function on the input list
return(replaceHelper(moduleSettings))
}


36 changes: 27 additions & 9 deletions tests/testthat/test-Settings.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ test_that("Create results data model settings", {
})

test_that("Test internal function for modifying covariate settings", {
# Create module settings that contain a combination of
# 1) covariate settings that do not contain cohort table settings
# 2) covariate settings that contain cohort table settings
# 3) a list of covariate setting that has 1 & 2 above
cov1 <- FeatureExtraction::createDefaultCovariateSettings()
cov2 <- FeatureExtraction::createCohortBasedCovariateSettings(
analysisId = 999,
Expand All @@ -419,6 +423,16 @@ test_that("Test internal function for modifying covariate settings", {
)
)
covariateSettings <- list(cov1, cov2)
moduleSettings <- list(
analysis = list(
something = covariateSettings,
somethingElse = list(
nested1 = cov1,
nested2 = cov2,
nested3 = covariateSettings
)
)
)
workDatabaseSchema <- "foo"
cohortTableNames <- CohortGenerator::getCohortTableNames(cohortTable = "unit_test")
executionSettings <- createCdmExecutionSettings(
Expand All @@ -429,16 +443,20 @@ test_that("Test internal function for modifying covariate settings", {
resultsFolder = "temp"
)

test1 <- .replaceCovariateSettingsCohortTableNames(covariateSettings, executionSettings)
expect_equal(test1[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema)
expect_equal(test1[[2]]$covariateCohortTable, cohortTableNames$cohortTable)
testReplacedModuleSettings <- .replaceCovariateSettings(moduleSettings, executionSettings)
expect_equal(testReplacedModuleSettings$analysis$something[[1]]$covariateCohortDatabaseSchema, NULL)
expect_equal(testReplacedModuleSettings$analysis$something[[1]]$covariateCohortTable, NULL)
expect_equal(testReplacedModuleSettings$analysis$something[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema)
expect_equal(testReplacedModuleSettings$analysis$something[[2]]$covariateCohortTable, cohortTableNames$cohortTable)

test2 <- .replaceCovariateSettingsCohortTableNames(cov1, executionSettings)
expect_equal(test2$covariateCohortDatabaseSchema, NULL)
expect_equal(test2$covariateCohortTable, NULL)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested1$covariateCohortDatabaseSchema, NULL)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested1$covariateCohortTable, NULL)

test3 <- .replaceCovariateSettingsCohortTableNames(cov2, executionSettings)
expect_equal(test3$covariateCohortDatabaseSchema, workDatabaseSchema)
expect_equal(test3$covariateCohortTable, cohortTableNames$cohortTable)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested2$covariateCohortDatabaseSchema, workDatabaseSchema)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested2$covariateCohortTable, cohortTableNames$cohortTable)

expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[1]]$covariateCohortDatabaseSchema, NULL)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[1]]$covariateCohortTable, NULL)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[2]]$covariateCohortDatabaseSchema, workDatabaseSchema)
expect_equal(testReplacedModuleSettings$analysis$somethingElse$nested3[[2]]$covariateCohortTable, cohortTableNames$cohortTable)
})

0 comments on commit ad34028

Please sign in to comment.