Release 1.1.6 (#72)
- adds caching and resuming of hyperparameter search
egillax authored Jun 18, 2023
1 parent cfb1bfa commit 2c88d3a
Showing 12 changed files with 364 additions and 16 deletions.
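The headline change persists hyperparameter-search state to a paramPersistence.rds file under the analysis path, so an interrupted search resumes at the first unfinished combination instead of restarting. A minimal sketch of that flow using the new TrainingCache class; the two-point grid and the temporary directory are illustrative, not part of this commit:

library(DeepPatientLevelPrediction)

analysisPath <- file.path(tempdir(), "Analysis_1")
dir.create(analysisPath, showWarnings = FALSE)

# toy grid; in the package this comes from modelSettings$param
paramSearch <- list(list(learnRate = 1e-3), list(learnRate = 1e-4))

trainCache <- TrainingCache$new(analysisPath)
if (!trainCache$isParamGridIdentical(paramSearch)) {
  trainCache$saveGridSearchPredictions(vector("list", length(paramSearch)))
  trainCache$saveModelParams(paramSearch)
}

# 1 on a fresh run; after an interruption, the first unfinished index
startIndex <- trainCache$getLastGridSearchIndex()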
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -1,7 +1,7 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
-Version: 1.1.5
+Version: 1.1.6
Date: 18-04-2023
Authors@R: c(
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")),
@@ -37,7 +37,7 @@ Suggests:
ResultModelManager (>= 0.2.0),
DatabaseConnector (>= 6.0.0)
Remotes:
-ohdsi/PatientLevelPrediction,
+ohdsi/PatientLevelPrediction@develop,
ohdsi/FeatureExtraction,
ohdsi/Eunomia,
ohdsi/ResultModelManager
1 change: 1 addition & 0 deletions NAMESPACE
@@ -2,6 +2,7 @@

export(Dataset)
export(Estimator)
+export(TrainingCache)
export(fitEstimator)
export(gridCvDeep)
export(lrFinder)
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
+DeepPatientLevelPrediction 1.1.6
+======================
+- Caching and resuming of hyperparameter iterations
+
DeepPatientLevelPrediction 1.1.5
======================
- Fix bug where device function was not working for LRFinder
4 changes: 2 additions & 2 deletions R/Dataset.R
Expand Up @@ -23,8 +23,8 @@ Dataset <- torch::dataset(
if (!is.null(labels)) {
self$target <- torch::torch_tensor(labels)
} else {
-self$target <- torch::torch_tensor(rep(0, data %>% dplyr::distinct(rowId)
-  %>% dplyr::collect() %>% nrow()))
+self$target <- torch::torch_tensor(rep(0, data %>% dplyr::summarize(m=max(rowId)) %>%
+  dplyr::collect() %>% dplyr::pull()))
}
# Weight to add in loss function to positive class
self$posWeight <- (self$target == 0)$sum() / self$target$sum()
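The change above swaps a distinct-row count for the maximum rowId when sizing the dummy target vector, which avoids materializing the distinct rows but assumes rowIds run 1..N without gaps. A toy illustration of the two expressions with plain dplyr, using a small data frame as a stand-in for the covariates table (an assumption of this sketch):

library(dplyr)

data <- data.frame(rowId = c(1, 1, 2, 3, 3))

# old expression: count the distinct rowIds
data %>% distinct(rowId) %>% collect() %>% nrow()            # 3

# new expression: take the maximum rowId instead
data %>% summarize(m = max(rowId)) %>% collect() %>% pull()  # 3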
31 changes: 24 additions & 7 deletions R/Estimator.R
Expand Up @@ -97,12 +97,14 @@ setEstimator <- function(learningRate='auto',
#' @param trainData the data to use
#' @param modelSettings modelSettings object
#' @param analysisId Id of the analysis
+#' @param analysisPath Path of the analysis
#' @param ... Extra inputs
#'
#' @export
fitEstimator <- function(trainData,
modelSettings,
analysisId,
+analysisPath,
...) {
start <- Sys.time()

@@ -128,15 +130,16 @@ fitEstimator <- function(trainData,
mappedData = mappedCovariateData,
labels = trainData$labels,
modelSettings = modelSettings,
-modelLocation = outLoc
+modelLocation = outLoc,
+analysisPath = analysisPath
)
)

hyperSummary <- do.call(rbind, lapply(cvResult$paramGridSearch, function(x) x$hyperSummary))
prediction <- cvResult$prediction
incs <- rep(1, covariateRef %>% dplyr::tally() %>%
-dplyr::collect ()
-%>% dplyr::pull())
+dplyr::collect () %>%
+as.integer)
covariateRef <- covariateRef %>%
dplyr::collect() %>%
dplyr::mutate(
@@ -251,26 +254,37 @@ predictDeepEstimator <- function(plpModel,
#' @param labels Dataframe with the outcomes
#' @param modelSettings Settings of the model
#' @param modelLocation Where to save the model
+#' @param analysisPath Path of the analysis
#'
#' @export
gridCvDeep <- function(mappedData,
labels,
modelSettings,
-modelLocation) {
+modelLocation,
+analysisPath) {
ParallelLogger::logInfo(paste0("Running hyperparameter search for ", modelSettings$modelType, " model"))

###########################################################################

paramSearch <- modelSettings$param
-gridSearchPredictons <- list()
-length(gridSearchPredictons) <- length(paramSearch)
+trainCache <- TrainingCache$new(analysisPath)
+
+if (trainCache$isParamGridIdentical(paramSearch)) {
+  gridSearchPredictons <- trainCache$getGridSearchPredictions()
+} else {
+  gridSearchPredictons <- list()
+  length(gridSearchPredictons) <- length(paramSearch)
+  trainCache$saveGridSearchPredictions(gridSearchPredictons)
+  trainCache$saveModelParams(paramSearch)
+}

dataset <- Dataset(mappedData$covariates, labels$outcomeCount)

estimatorSettings <- modelSettings$estimatorSettings

fitParams <- names(paramSearch[[1]])[grepl("^estimator", names(paramSearch[[1]]))]

-for (gridId in 1:length(paramSearch)) {
+for (gridId in trainCache$getLastGridSearchIndex():length(paramSearch)) {
ParallelLogger::logInfo(paste0("Running hyperparameter combination no ", gridId))
ParallelLogger::logInfo(paste0("HyperParameters: "))
ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]), paramSearch[[gridId]], collapse = " | "))
@@ -336,7 +350,10 @@ gridCvDeep <- function(mappedData,
prediction = prediction,
param = paramSearch[[gridId]]
)

+trainCache$saveGridSearchPredictions(gridSearchPredictons)
}

# get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
paramGridSearch <- lapply(gridSearchPredictons, function(x) {
do.call(PatientLevelPrediction::computeGridPerformance, x)
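For callers, the visible change in this file is the analysisPath argument threaded from fitEstimator down to gridCvDeep so the cache has somewhere to live. A hypothetical call might look as follows; outputFolder, trainData, and modelSettings are placeholders rather than objects defined by this commit:

plpModel <- fitEstimator(
  trainData = trainData,
  modelSettings = modelSettings,
  analysisId = "Analysis_1",
  analysisPath = file.path(outputFolder, "Analysis_1")
)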
90 changes: 90 additions & 0 deletions R/TrainingCache-class.R
@@ -0,0 +1,90 @@
#' TrainingCache
#' @description
#' Parameter caching for training persistence and continuity
#' @export
TrainingCache <- R6::R6Class(
  "TrainingCache",

  private = list(
    .paramPersistence = list(
      gridSearchPredictions = NULL,
      modelParams = NULL
    ),
    .paramContinuity = list(),
    .saveDir = NULL,

    writeToFile = function() {
      saveRDS(private$.paramPersistence, file.path(private$.saveDir))
    },

    readFromFile = function() {
      private$.paramPersistence <- readRDS(file.path(private$.saveDir))
    }
  ),

  public = list(
    #' @description
    #' Creates a new training cache
    #' @param inDir Path to the analysis directory
    initialize = function(inDir) {
      private$.saveDir <- file.path(inDir, "paramPersistence.rds")

      if (file.exists(private$.saveDir)) {
        private$readFromFile()
      } else {
        private$writeToFile()
      }
    },

    #' @description
    #' Checks whether the parameter grid in the model settings is identical to
    #' the cached parameters.
    #' @param inModelParams Parameter grid from the model settings
    #' @returns Whether the provided and cached parameter grids are identical
    isParamGridIdentical = function(inModelParams) {
      return(identical(inModelParams, private$.paramPersistence$modelParams))
    },

    #' @description
    #' Saves the grid search results to the training cache
    #' @param inGridSearchPredictions Grid search predictions
    saveGridSearchPredictions = function(inGridSearchPredictions) {
      private$.paramPersistence$gridSearchPredictions <-
        inGridSearchPredictions
      private$writeToFile()
    },

    #' @description
    #' Saves the parameter grid to the training cache
    #' @param inModelParams Parameter grid from the model settings
    saveModelParams = function(inModelParams) {
      private$.paramPersistence$modelParams <- inModelParams
      private$writeToFile()
    },

    #' @description
    #' Gets the grid search results from the training cache
    #' @returns Grid search results from the training cache
    getGridSearchPredictions = function() {
      return(private$.paramPersistence$gridSearchPredictions)
    },

    #' @description
    #' Gets the last index from the cached grid search
    #' @returns Last grid search index
    getLastGridSearchIndex = function() {
      if (is.null(private$.paramPersistence$gridSearchPredictions)) {
        return(1)
      } else {
        return(which(sapply(private$.paramPersistence$gridSearchPredictions,
                            is.null))[1])
      }
    },

    #' @description
    #' Removes the training cache from the analysis path
    dropCache = function() {
      # TODO
    }
  )
)
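Everything the class persists lives in a single RDS file, paramPersistence.rds, inside the analysis directory. Purely for illustration (assuming analysisPath holds a cache written by an earlier run), the file can be inspected directly to see the two slots the class maintains:

cache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
names(cache)       # "gridSearchPredictions" "modelParams"
cache$modelParams  # the grid that isParamGridIdentical() compares against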
145 changes: 145 additions & 0 deletions man/TrainingCache.Rd


4 changes: 3 additions & 1 deletion man/fitEstimator.Rd


