diff --git a/R/Estimator.R b/R/Estimator.R index 1452503..246310f 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -353,12 +353,17 @@ gridCvDeep <- function(mappedData, trainCache$saveGridSearchPredictions(gridSearchPredictons) } - + browser() + learnSchedules <- lapply(gridSearchPredictons, function(x) {x$param$learnSchedule}) + gridSearchPredictons <- lapply(gridSearchPredictons, function(x) {x$param <- x$param[names(x$param) != "learnSchedule"]; x}) # get best para (this could be modified to enable any metric instead of AUC, just need metric input in function) paramGridSearch <- lapply(gridSearchPredictons, function(x) { do.call(PatientLevelPrediction::computeGridPerformance, x) }) # cvAUCmean, cvAUC, param - + paramGridSearch <- mapply(function(x, y) {x$hyperSummary$learnRates <- rep(list(unlist(y$LRs)),4) ; + x$param$learnSchedule <- y; x}, + paramGridSearch, learnSchedules, SIMPLIFY = FALSE) + optimalParamInd <- which.max(unlist(lapply(paramGridSearch, function(x) x$cvPerformance))) finalParam <- paramGridSearch[[optimalParamInd]]$param