From 216c7afb8b7255d9da149c140341e0ce232de231 Mon Sep 17 00:00:00 2001
From: Henrik John
Date: Fri, 6 Oct 2023 10:27:24 +0200
Subject: [PATCH 1/5] Only cache best prediction

---
 R/Estimator.R           | 28 +++++++++++++++++++---------
 R/TrainingCache-class.R |  3 ++-
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/R/Estimator.R b/R/Estimator.R
index 573c6a6..2b08b73 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -369,21 +369,31 @@ gridCvDeep <- function(mappedData,
 
     gridSearchPredictons[[gridId]] <- list(
       prediction = prediction,
-      param = paramSearch[[gridId]]
+      param = paramSearch[[gridId]],
+      gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
     )
-
+
+    # remove all predictions that are not the max performance
+    indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
+    for (i in seq_along(gridSearchPredictons)) {
+      if (!is.null(gridSearchPredictons[[i]])) {
+        if (i != indexOfMax) {
+          gridSearchPredictons[[i]]$prediction <- NULL
+        }
+      }
+    }
+
     trainCache$saveGridSearchPredictions(gridSearchPredictons)
   }
 
-  # get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
-  paramGridSearch <- lapply(gridSearchPredictons, function(x) {
-    do.call(PatientLevelPrediction::computeGridPerformance, x)
-  }) # cvAUCmean, cvAUC, param
+  paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)
 
-  optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance)))
-  finalParam <- paramGridSearch[[optimalParamInd]]$param
+  # get best params
+  indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
+  finalParam <- gridSearchPredictons[[indexOfMax]]$param
 
-  cvPrediction <- gridSearchPredictons[[optimalParamInd]]$prediction
+  # get best CV prediction
+  cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
   cvPrediction$evaluationType <- "CV"
 
   ParallelLogger::logInfo("Training final model using optimal parameters")
diff --git a/R/TrainingCache-class.R b/R/TrainingCache-class.R
index 8577f31..0dc4c02 100644
--- a/R/TrainingCache-class.R
+++ b/R/TrainingCache-class.R
@@ -8,7 +8,8 @@ TrainingCache <- R6::R6Class(
   private = list(
     .paramPersistence = list(
       gridSearchPredictions = NULL,
-      modelParams = NULL
+      modelParams = NULL,
+      gridPerformance = NULL
     ),
     .paramContinuity = list(),
     .saveDir = NULL,

From 2220ff2cf1962e381277871694271a87086982a0 Mon Sep 17 00:00:00 2001
From: Henrik John
Date: Fri, 6 Oct 2023 11:25:49 +0200
Subject: [PATCH 2/5] Clean up

---
 R/Estimator.R           | 6 +++---
 R/TrainingCache-class.R | 3 +--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/R/Estimator.R b/R/Estimator.R
index 2b08b73..c2dba62 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -374,11 +374,11 @@ gridCvDeep <- function(mappedData,
     )
 
     # remove all predictions that are not the max performance
-    indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
+    indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
     for (i in seq_along(gridSearchPredictons)) {
       if (!is.null(gridSearchPredictons[[i]])) {
         if (i != indexOfMax) {
-          gridSearchPredictons[[i]]$prediction <- NULL
+          gridSearchPredictons[[i]]$prediction <- list(NULL)
         }
       }
     }
 
     trainCache$saveGridSearchPredictions(gridSearchPredictons)
   }
@@ -389,7 +389,7 @@ gridCvDeep <- function(mappedData,
   paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)
 
   # get best params
-  indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
+  indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
   finalParam <- gridSearchPredictons[[indexOfMax]]$param
 
   # get best CV prediction
   cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
diff --git a/R/TrainingCache-class.R b/R/TrainingCache-class.R
index 0dc4c02..8577f31 100644
--- a/R/TrainingCache-class.R
+++ b/R/TrainingCache-class.R
@@ -8,8 +8,7 @@ TrainingCache <- R6::R6Class(
   private = list(
     .paramPersistence = list(
       gridSearchPredictions = NULL,
-      modelParams = NULL,
-      gridPerformance = NULL
+      modelParams = NULL
     ),
     .paramContinuity = list(),
     .saveDir = NULL,

From 9f8d23a920c37de7c9e49711a094aa8490dcfbb1 Mon Sep 17 00:00:00 2001
From: Henrik John
Date: Fri, 6 Oct 2023 11:50:40 +0200
Subject: [PATCH 3/5] Add logger message when caching

---
 R/Estimator.R | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/R/Estimator.R b/R/Estimator.R
index c2dba62..1ff990f 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -382,7 +382,7 @@ gridCvDeep <- function(mappedData,
       }
     }
   }
-
+    ParallelLogger::logInfo(paste0("Caching all grid search results and prediction for best combination ", indexOfMax))
     trainCache$saveGridSearchPredictions(gridSearchPredictons)
   }
 

From f93a53f2e4c940d0b18cb08bf2b2c71e547f826d Mon Sep 17 00:00:00 2001
From: Henrik John
Date: Fri, 6 Oct 2023 13:21:07 +0200
Subject: [PATCH 4/5] Add test to ensure prediction is cached for optimal parameters

---
 tests/testthat/test-TrainingCache.R | 43 ++++++++++++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R
index eb4ab17..5d639af 100644
--- a/tests/testthat/test-TrainingCache.R
+++ b/tests/testthat/test-TrainingCache.R
@@ -12,7 +12,7 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
                                                              seed=NULL),
                            hyperParamSearch = "random",
                            randomSample = 3,
-                           randomSampleSeed = NULL)
+                           randomSampleSeed = 123)
 
 trainCache <- TrainingCache$new(testLoc)
 paramSearch <- resNetSettings$param
@@ -87,3 +87,44 @@ test_that("Estimator can resume training from cache", {
   trainCache <- TrainingCache$new(analysisPath)
   testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)
 })
+
+test_that("Prediction is cached for optimal parameters", {
+  modelPath <- tempdir()
+  analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
+  dir.create(analysisPath)
+
+  sink(nullfile())
+  res2 <- tryCatch(
+    {
+      PatientLevelPrediction::runPlp(
+        plpData = plpData,
+        outcomeId = 3,
+        modelSettings = resNetSettings,
+        analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
+        analysisName = "Testing Training Cache - GridSearch",
+        populationSettings = populationSet,
+        splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
+        sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
+        featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
+        preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
+        executeSettings = PatientLevelPrediction::createExecuteSettings(
+          runSplitData = T,
+          runSampleData = F,
+          runfeatureEngineering = F,
+          runPreprocessData = T,
+          runModelDevelopment = T,
+          runCovariateSummary = F
+        ),
+        saveDirectory = modelPath
+      )
+    },
+    error = function(e) {
+      print(e)
+      return(NULL)
+    }
+  )
+  sink()
+  testCache <- readRDS(file.path(analysisPath, "paramPersistence.RDS"))
+  indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
+  testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
+})

From 53af472f5efef6e5b5e5542f00910a72986afa3e Mon Sep 17 00:00:00 2001
From: Henrik John
Date: Fri, 6 Oct 2023 13:34:16 +0200
Subject: [PATCH 5/5] Resolve an issue with case sensitivity on Ubuntu

---
 tests/testthat/test-TrainingCache.R | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R
index 5d639af..aacd9a0 100644
--- a/tests/testthat/test-TrainingCache.R
+++ b/tests/testthat/test-TrainingCache.R
@@ -124,7 +124,7 @@ test_that("Prediction is cached for optimal parameters", {
     }
   )
   sink()
-  testCache <- readRDS(file.path(analysisPath, "paramPersistence.RDS"))
+  testCache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
   indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
   testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
 })
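
Reviewer note (not part of the patch series): the R sketch below is a minimal, self-contained illustration of the caching behaviour the series introduces in gridCvDeep. The toy gridSearchPredictons list and its performance values are invented for the example; only the pruning logic mirrors the patched code.

# Toy stand-in for the per-combination results collected during the grid search;
# each entry carries a prediction and its cross-validated performance (values invented).
gridSearchPredictons <- list(
  list(prediction = data.frame(value = c(0.2, 0.8)), gridPerformance = list(cvPerformance = 0.61)),
  list(prediction = data.frame(value = c(0.4, 0.6)), gridPerformance = list(cvPerformance = 0.74)),
  list(prediction = data.frame(value = c(0.1, 0.9)), gridPerformance = list(cvPerformance = 0.58))
)

# Index of the best combination by cross-validated performance.
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))

# Blank out every other prediction before the list is written to the cache.
# Assigning list(NULL) keeps the "prediction" slot in place (assigning NULL would
# delete it from the inner list), so only the bulky prediction data is dropped.
for (i in seq_along(gridSearchPredictons)) {
  if (!is.null(gridSearchPredictons[[i]])) {
    if (i != indexOfMax) {
      gridSearchPredictons[[i]]$prediction <- list(NULL)
    }
  }
}

# Only the winning combination still holds a full data.frame, which is what the
# new unit test checks after reading paramPersistence.rds back from disk.
class(gridSearchPredictons[[indexOfMax]]$prediction) == class(data.frame())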