Skip to content

Commit

Permalink
Only cache best prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Oct 6, 2023
1 parent ff9e22e commit 216c7af
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
28 changes: 19 additions & 9 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,21 +369,31 @@ gridCvDeep <- function(mappedData,

gridSearchPredictons[[gridId]] <- list(
prediction = prediction,
param = paramSearch[[gridId]]
param = paramSearch[[gridId]],
gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
)


# remove all predictions that are not the max performance
indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
for (i in seq_along(gridSearchPredictons)) {
if (!is.null(gridSearchPredictons[[i]])) {
if (i != indexOfMax) {
gridSearchPredictons[[i]]$prediction <- NULL
}
}
}

trainCache$saveGridSearchPredictions(gridSearchPredictons)
}

# get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
paramGridSearch <- lapply(gridSearchPredictons, function(x) {
do.call(PatientLevelPrediction::computeGridPerformance, x)
}) # cvAUCmean, cvAUC, param
paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)

optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance)))
finalParam <- paramGridSearch[[optimalParamInd]]$param
# get best params
indexOfMax <- which.max(sapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance))
finalParam <- gridSearchPredictons[[indexOfMax]]$param

cvPrediction <- gridSearchPredictons[[optimalParamInd]]$prediction
# get best CV prediction
cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
cvPrediction$evaluationType <- "CV"

ParallelLogger::logInfo("Training final model using optimal parameters")
Expand Down
3 changes: 2 additions & 1 deletion R/TrainingCache-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ TrainingCache <- R6::R6Class(
private = list(
.paramPersistence = list(
gridSearchPredictions = NULL,
modelParams = NULL
modelParams = NULL,
gridPerformance = NULL
),
.paramContinuity = list(),
.saveDir = NULL,
Expand Down

0 comments on commit 216c7af

Please sign in to comment.