Skip to content

Commit

Permalink
Merge pull request #90 from OHDSI/88-reduce-cache-size
Browse files Browse the repository at this point in the history
Reduce cache size
  • Loading branch information
lhjohn authored Oct 6, 2023
2 parents ff9e22e + 53af472 commit 1f50fa5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
28 changes: 19 additions & 9 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,21 +369,31 @@ gridCvDeep <- function(mappedData,

gridSearchPredictons[[gridId]] <- list(
prediction = prediction,
param = paramSearch[[gridId]]
param = paramSearch[[gridId]],
gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
)


# remove all predictions that are not the max performance
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
for (i in seq_along(gridSearchPredictons)) {
if (!is.null(gridSearchPredictons[[i]])) {
if (i != indexOfMax) {
gridSearchPredictons[[i]]$prediction <- list(NULL)
}
}
}
ParallelLogger::logInfo(paste0("Caching all grid search results and prediction for best combination ", indexOfMax))
trainCache$saveGridSearchPredictions(gridSearchPredictons)
}

# get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
paramGridSearch <- lapply(gridSearchPredictons, function(x) {
do.call(PatientLevelPrediction::computeGridPerformance, x)
}) # cvAUCmean, cvAUC, param
paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)

optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance)))
finalParam <- paramGridSearch[[optimalParamInd]]$param
# get best params
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
finalParam <- gridSearchPredictons[[indexOfMax]]$param

cvPrediction <- gridSearchPredictons[[optimalParamInd]]$prediction
# get best CV prediction
cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
cvPrediction$evaluationType <- "CV"

ParallelLogger::logInfo("Training final model using optimal parameters")
Expand Down
43 changes: 42 additions & 1 deletion tests/testthat/test-TrainingCache.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
seed=NULL),
hyperParamSearch = "random",
randomSample = 3,
randomSampleSeed = NULL)
randomSampleSeed = 123)

trainCache <- TrainingCache$new(testLoc)
paramSearch <- resNetSettings$param
Expand Down Expand Up @@ -87,3 +87,44 @@ test_that("Estimator can resume training from cache", {
trainCache <- TrainingCache$new(analysisPath)
testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)
})

# Runs a complete runPlp() analysis with the ResNet settings, then reads
# "paramPersistence.rds" back from the analysis directory and checks that the
# grid-search entry with the highest cvPerformance still carries a data-frame
# prediction.
test_that("Prediction is cached for optimal parameters", {
  modelPath <- tempdir()
  analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
  dir.create(analysisPath)

  # Divert console output for the duration of the run. The matching sink() is
  # placed in `finally` so the diversion is removed on every exit path, not
  # only when tryCatch() returns normally.
  sink(nullfile())
  res2 <- tryCatch(
    {
      PatientLevelPrediction::runPlp(
        plpData = plpData,
        outcomeId = 3,
        modelSettings = resNetSettings,
        analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
        analysisName = "Testing Training Cache - GridSearch",
        populationSettings = populationSet,
        splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
        sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
        featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
        preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
        # Spelled-out TRUE/FALSE: T and F are ordinary variables that can be
        # reassigned, the logical constants cannot.
        executeSettings = PatientLevelPrediction::createExecuteSettings(
          runSplitData = TRUE,
          runSampleData = FALSE,
          runfeatureEngineering = FALSE,
          runPreprocessData = TRUE,
          runModelDevelopment = TRUE,
          runCovariateSummary = FALSE
        ),
        saveDirectory = modelPath
      )
    },
    error = function(e) {
      # The handler runs before `finally`, so this print still goes to the
      # diverted output, exactly as in the previous version of the test.
      print(e)
      NULL
    },
    finally = sink()
  )

  # Locate the best-performing grid entry in the persisted cache and make sure
  # its prediction is a plain data.frame.
  testCache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
  indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
  testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
})

0 comments on commit 1f50fa5

Please sign in to comment.