Skip to content

Commit

Permalink
Add fix to the full cache issue (#99)
Browse files Browse the repository at this point in the history
Co-authored-by: Xinzhuo Jiang <[email protected]>
  • Loading branch information
egillax and xj2193 authored Oct 21, 2023
1 parent 74c346f commit 69b8bec
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ config.yml
docs
.idea/
renv.lock
extras/
extras/
.Renviron
5 changes: 3 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ gridCvDeep <- function(mappedData,

fitParams <- names(paramSearch[[1]])[grepl("^estimator", names(paramSearch[[1]]))]
findLR <- modelSettings$estimatorSettings$findLR
for (gridId in trainCache$getLastGridSearchIndex():length(paramSearch)) {
if (!trainCache$isFull()) {
for (gridId in trainCache$getLastGridSearchIndex():length(paramSearch)) {
ParallelLogger::logInfo(paste0("Running hyperparameter combination no ", gridId))
ParallelLogger::logInfo(paste0("HyperParameters: "))
ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]), paramSearch[[gridId]], collapse = " | "))
Expand Down Expand Up @@ -385,7 +386,7 @@ gridCvDeep <- function(mappedData,
ParallelLogger::logInfo(paste0("Caching all grid search results and prediction for best combination ", indexOfMax))
trainCache$saveGridSearchPredictions(gridSearchPredictons)
}

}
paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)
# get best params
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
Expand Down
7 changes: 7 additions & 0 deletions R/TrainingCache-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ TrainingCache <- R6::R6Class(
return(private$.paramPersistence$gridSearchPredictions)
},

#' @description
#' Check if cache is full
#' @returns Boolen
isFull = function() {
return(all(unlist(lapply(private$.paramPersistence$gridSearchPredictions, function(x) !is.null(x$gridPerformance)))))
},

#' @description
#' Gets the last index from the cached grid search
#' @returns Last grid search index
Expand Down
21 changes: 11 additions & 10 deletions extras/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ populationSet <- PatientLevelPrediction::createStudyPopulationSettings(
# epochs = 10L
# ))

modelSettings <- setDefaultResNet(estimatorSettings = setEstimator(
learningRate = "auto",
weightDecay = 1e-06,
device="cuda:0",
batchSize=128L,
epochs=50L,
seed=42
))
# modelSettings <- setDefaultResNet(estimatorSettings = setEstimator(
# learningRate = "auto",
# weightDecay = 1e-06,
# device="cuda:0",
# batchSize=128L,
# epochs=50L,
# seed=42
# ))

modelSettings <- setResNet(numLayers = c(1L, 2L),
sizeHidden = 72L,
Expand All @@ -45,7 +45,8 @@ modelSettings <- setResNet(numLayers = c(1L, 2L),
device = "cpu",
seed = 42
),
randomSample = 2)
randomSample = 2,
randomSampleSeed = 1)

res2 <- PatientLevelPrediction::runPlp(
plpData = plpData,
Expand All @@ -67,7 +68,7 @@ res2 <- PatientLevelPrediction::runPlp(
runModelDevelopment = T,
runCovariateSummary = F
),
saveDirectory = '~/test/resnet/'
saveDirectory = '~/deep_plp_test/resnet/'
)


0 comments on commit 69b8bec

Please sign in to comment.