Skip to content

Commit

Permalink
Fix lr schedule (#91)
Browse files Browse the repository at this point in the history
* circumvent unnesting of lrSchedule in PLP

* make lrSchedule fix compatible with new cache code
  • Loading branch information
egillax authored Oct 10, 2023
1 parent 1f50fa5 commit 84bbb18
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,16 @@ gridCvDeep <- function(mappedData,
)
}
maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2)))
paramSearch[[gridId]]$learnSchedule <- learnRates[[maxIndex]]

gridSearchPredictons[[gridId]] <- list(
prediction = prediction,
param = paramSearch[[gridId]],
gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
gridPerformance = PatientLevelPrediction::computeGridPerformance(prediction, paramSearch[[gridId]])
)
gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary$learnRates <- rep(list(unlist(learnRates[[maxIndex]]$LRs)),
nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary))
gridSearchPredictons[[gridId]]$param$learnSchedule <- learnRates[[maxIndex]]


# remove all predictions that are not the max performance
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
for (i in seq_along(gridSearchPredictons)) {
Expand All @@ -387,10 +389,11 @@ gridCvDeep <- function(mappedData,
}

paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)

# get best params
indexOfMax <- which.max(unlist(lapply(gridSearchPredictons, function(x) x$gridPerformance$cvPerformance)))
finalParam <- gridSearchPredictons[[indexOfMax]]$param

paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)

# get best CV prediction
cvPrediction <- gridSearchPredictons[[indexOfMax]]$prediction
Expand Down

0 comments on commit 84bbb18

Please sign in to comment.