From 216c7afb8b7255d9da149c140341e0ce232de231 Mon Sep 17 00:00:00 2001 From: Henrik John Date: Fri, 6 Oct 2023 10:27:24 +0200 Subject: [PATCH] Only cache best prediction --- R/Estimator.R | 28 +++++++++++++++++++--------- R/TrainingCache-class.R | 3 ++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 573c6a6..2b08b73 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -369,21 +369,31 @@ gridCvDeep <- function(mappedData, gridSearchPredictons[[gridId]] <- list( prediction = prediction, - param = paramSearch[[gridId]] + param = paramSearch[[gridId]], + gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]]) ) - + + # remove all predictions that are not the max performance + indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)) + for (i in seq_along(gridSearchPredictons)) { + if (!is.null(gridSearchPredictons[[i]])) { + if (i != indexOfMax) { + gridSearchPredictons[[i]]$prediction <- NULL + } + } + } + trainCache$saveGridSearchPredictions(gridSearchPredictons) } - # get best para (this could be modified to enable any metric instead of AUC, just need metric input in function) - paramGridSearch <- lapply(gridSearchPredictons, function(x) { - do.call(PatientLevelPrediction::computeGridPerformance, x) - }) # cvAUCmean, cvAUC, param + paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance) - optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance))) - finalParam <- paramGridSearch[[optimalParamInd]]$param + # get best params + indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)) + finalParam <- gridSearchPredictons[[indexOfMax]]$param - cvPrediction <- gridSearchPredictons[[optimalParamInd]]$prediction + # get best CV prediction + cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction cvPrediction$evaluationType <- "CV" ParallelLogger::logInfo("Training final model using optimal parameters") diff --git a/R/TrainingCache-class.R b/R/TrainingCache-class.R index 8577f31..0dc4c02 100644 --- a/R/TrainingCache-class.R +++ b/R/TrainingCache-class.R @@ -8,7 +8,8 @@ TrainingCache <- R6::R6Class( private = list( .paramPersistence = list( gridSearchPredictions = NULL, - modelParams = NULL + modelParams = NULL, + gridPerformance = NULL ), .paramContinuity = list(), .saveDir = NULL,