-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- adds caching and resuming of hyperparameter search
- Loading branch information
Showing
12 changed files
with
364 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
Package: DeepPatientLevelPrediction | ||
Type: Package | ||
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model | ||
Version: 1.1.5 | ||
Version: 1.1.6 | ||
Date: 18-04-2023 | ||
Authors@R: c( | ||
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")), | ||
|
@@ -37,7 +37,7 @@ Suggests: | |
ResultModelManager (>= 0.2.0), | ||
DatabaseConnector (>= 6.0.0) | ||
Remotes: | ||
ohdsi/PatientLevelPrediction, | ||
ohdsi/PatientLevelPrediction@develop, | ||
ohdsi/FeatureExtraction, | ||
ohdsi/Eunomia, | ||
ohdsi/ResultModelManager | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#' TrainingCache | ||
#' @description | ||
#' Parameter caching for training persistence and continuity | ||
#' @export | ||
TrainingCache <- R6::R6Class( | ||
"TrainingCache", | ||
|
||
private = list( | ||
.paramPersistence = list( | ||
gridSearchPredictions = NULL, | ||
modelParams = NULL | ||
), | ||
.paramContinuity = list(), | ||
.saveDir = NULL, | ||
|
||
writeToFile = function() { | ||
saveRDS(private$.paramPersistence, file.path(private$.saveDir)) | ||
}, | ||
|
||
readFromFile = function() { | ||
private$.paramPersistence <- readRDS(file.path(private$.saveDir)) | ||
} | ||
), | ||
|
||
public = list( | ||
#' @description | ||
#' Creates a new training cache | ||
#' @param inDir Path to the analysis directory | ||
initialize = function(inDir) { | ||
private$.saveDir <- file.path(inDir, "paramPersistence.rds") | ||
|
||
if (file.exists(private$.saveDir)) { | ||
private$readFromFile() | ||
} else { | ||
private$writeToFile() | ||
} | ||
}, | ||
|
||
#' @description | ||
#' Checks whether the parameter grid in the model settings is identical to | ||
#' the cached parameters. | ||
#' @param inModelParams Parameter grid from the model settings | ||
#' @returns Whether the provided and cached parameter grid is identical | ||
isParamGridIdentical = function(inModelParams) { | ||
return(identical(inModelParams, private$.paramPersistence$modelParams)) | ||
}, | ||
|
||
#' @description | ||
#' Saves the grid search results to the training cache | ||
#' @param inGridSearchPredictions Grid search predictions | ||
saveGridSearchPredictions = function(inGridSearchPredictions) { | ||
private$.paramPersistence$gridSearchPredictions <- | ||
inGridSearchPredictions | ||
private$writeToFile() | ||
}, | ||
|
||
#' @description | ||
#' Saves the parameter grid to the training cache | ||
#' @param inModelParams Parameter grid from the model settings | ||
saveModelParams = function(inModelParams) { | ||
private$.paramPersistence$modelParams <- inModelParams | ||
private$writeToFile() | ||
}, | ||
|
||
#' @description | ||
#' Gets the grid search results from the training cache | ||
#' @returns Grid search results from the training cache | ||
getGridSearchPredictions = function() { | ||
return(private$.paramPersistence$gridSearchPredictions) | ||
}, | ||
|
||
#' @description | ||
#' Gets the last index from the cached grid search | ||
#' @returns Last grid search index | ||
getLastGridSearchIndex = function() { | ||
if (is.null(private$.paramPersistence$gridSearchPredictions)) { | ||
return(1) | ||
} else { | ||
return(which(sapply(private$.paramPersistence$gridSearchPredictions, | ||
is.null))[1]) | ||
} | ||
}, | ||
|
||
#' @description | ||
#' Remove the training cache from the analysis path | ||
dropCache = function() { | ||
# TODO | ||
} | ||
) | ||
) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.