Reviewer-amended patch: cache only the best grid-search prediction.

Text above the first "diff --git" line is commentary and is skipped by
`git apply`.

Summary of the change carried by this patch:
  * R/Estimator.R (gridCvDeep): the grid performance is computed once per
    grid entry and stored next to the prediction; before the cache is saved,
    the prediction of every entry except the best-scoring one is replaced by
    list(NULL).
  * tests/testthat/test-TrainingCache.R: fixed random-search seed and a new
    test that the best entry in the cache file still holds its prediction.

Amendments relative to the submitted patch:
  * The best index is computed with vapply(..., numeric(1)), giving NA_real_
    for entries without a score. unlist(lapply(...)) drops NULL entries, so
    its which.max() position is not guaranteed to be a position in
    gridSearchPredictons.
  * T / F replaced by TRUE / FALSE in the new test.
  * expect_equal(class(x), class(data.frame())) replaced by
    expect_s3_class(x, "data.frame").
  * The whitespace-only change to the blank line after the list(...) call was
    dropped.

NOTE(review): the submitted patch had lost its line breaks and indentation.
Leading whitespace of context lines and the position of blank context lines
were reconstructed from the hunk line counts and must be confirmed against
the target files; `git apply --ignore-space-change` may be needed. The
post-image blob ids on the "index" lines belong to the submitted patch, not
to this amended one.

diff --git a/R/Estimator.R b/R/Estimator.R
index 573c6a6..1ff990f 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -369,21 +369,40 @@ gridCvDeep <- function(mappedData,

     gridSearchPredictons[[gridId]] <- list(
       prediction = prediction,
-      param = paramSearch[[gridId]]
+      param = paramSearch[[gridId]],
+      gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
     )

+    # remove all predictions that are not the max performance
+    # vapply keeps one slot per grid entry (NA for entries without a score),
+    # so indexOfMax always refers to a position in gridSearchPredictons
+    cvPerformances <- vapply(gridSearchPredictons, function(x) {
+      perf <- x$gridPerformance$cvPerformance
+      if (is.null(perf)) NA_real_ else as.numeric(perf)
+    }, numeric(1))
+    indexOfMax <- which.max(cvPerformances)
+    for (i in seq_along(gridSearchPredictons)) {
+      if (!is.null(gridSearchPredictons[[i]])) {
+        if (i != indexOfMax) {
+          gridSearchPredictons[[i]]$prediction <- list(NULL)
+        }
+      }
+    }
+    ParallelLogger::logInfo(paste0("Caching all grid search results and prediction for best combination ", indexOfMax))
     trainCache$saveGridSearchPredictions(gridSearchPredictons)
   }

-  # get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
-  paramGridSearch <- lapply(gridSearchPredictons, function(x) {
-    do.call(PatientLevelPrediction::computeGridPerformance, x)
-  }) # cvAUCmean, cvAUC, param
+  paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)

-  optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance)))
-  finalParam <- paramGridSearch[[optimalParamInd]]$param
+  # get best params (NA-padded so the index lines up with gridSearchPredictons)
+  cvPerformances <- vapply(paramGridSearch, function(x) {
+    if (is.null(x$cvPerformance)) NA_real_ else as.numeric(x$cvPerformance)
+  }, numeric(1))
+  indexOfMax <- which.max(cvPerformances)
+  finalParam <- gridSearchPredictons[[indexOfMax]]$param

-  cvPrediction <- gridSearchPredictons[[optimalParamInd]]$prediction
+  # get best CV prediction
+  cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
   cvPrediction$evaluationType <- "CV"

   ParallelLogger::logInfo("Training final model using optimal parameters")
diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R
index eb4ab17..aacd9a0 100644
--- a/tests/testthat/test-TrainingCache.R
+++ b/tests/testthat/test-TrainingCache.R
@@ -12,7 +12,7 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
                             seed=NULL),
                             hyperParamSearch = "random",
                             randomSample = 3,
-                            randomSampleSeed = NULL)
+                            randomSampleSeed = 123)

 trainCache <- TrainingCache$new(testLoc)
 paramSearch <- resNetSettings$param
@@ -87,3 +87,46 @@ test_that("Estimator can resume training from cache", {
   trainCache <- TrainingCache$new(analysisPath)
   testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)
 })
+
+test_that("Prediction is cached for optimal parameters", {
+  modelPath <- tempdir()
+  analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
+  dir.create(analysisPath)
+
+  # silence console output from the run; restored by sink() below
+  sink(nullfile())
+  res2 <- tryCatch(
+    {
+      PatientLevelPrediction::runPlp(
+        plpData = plpData,
+        outcomeId = 3,
+        modelSettings = resNetSettings,
+        analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
+        analysisName = "Testing Training Cache - GridSearch",
+        populationSettings = populationSet,
+        splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
+        sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
+        featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
+        preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
+        executeSettings = PatientLevelPrediction::createExecuteSettings(
+          runSplitData = TRUE,
+          runSampleData = FALSE,
+          runfeatureEngineering = FALSE,
+          runPreprocessData = TRUE,
+          runModelDevelopment = TRUE,
+          runCovariateSummary = FALSE
+        ),
+        saveDirectory = modelPath
+      )
+    },
+    error = function(e) {
+      print(e)
+      return(NULL)
+    }
+  )
+  sink()
+  # the entry with the best CV performance must still carry its prediction
+  testCache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
+  indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
+  testthat::expect_s3_class(testCache$gridSearchPredictions[[indexOfMax]]$prediction, "data.frame")
+})