
Commit

Refactoring gridCvDeep (#104)
egillax authored Nov 27, 2023
1 parent 8900d8c commit d49b50e
Showing 1 changed file with 85 additions and 71 deletions.
156 changes: 85 additions & 71 deletions R/Estimator.R
@@ -58,11 +58,9 @@ setEstimator <- function(learningRate = "auto",
 ) {

   checkIsClass(learningRate, c("numeric", "character"))
-  if (inherits(learningRate, "character")) {
-    if (learningRate != "auto") {
-      stop(paste0('Learning rate should be either a numeric or "auto",
-                  you provided: ', learningRate))
-    }
+  if (inherits(learningRate, "character") && learningRate != "auto") {
+    stop(paste0('Learning rate should be either a numeric or "auto",
+                you provided: ', learningRate))
   }
   checkIsClass(weightDecay, "numeric")
   checkHigherEqual(weightDecay, 0.0)
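The collapsed condition is equivalent to the old nested check. A minimal standalone sketch of the validation, assuming only base R (the package's checkIsClass helper is omitted):

  learningRate <- "fast"  # a numeric such as 3e-4, or the string "auto", would pass
  if (inherits(learningRate, "character") && learningRate != "auto") {
    stop(paste0('Learning rate should be either a numeric or "auto", ',
                'you provided: ', learningRate))
  }
  #> Error: Learning rate should be either a numeric or "auto", you provided: fast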
@@ -113,23 +111,7 @@ setEstimator <- function(learningRate = "auto",
                 class(estimatorSettings$device))
   }

-  paramsToTune <- list()
-  for (name in names(estimatorSettings)) {
-    param <- estimatorSettings[[name]]
-    if (length(param) > 1 && is.atomic(param)) {
-      paramsToTune[[paste0("estimator.", name)]] <- param
-    }
-    if ("params" %in% names(param)) {
-      for (name2 in names(param[["params"]])) {
-        param2 <- param[["params"]][[name2]]
-        if (length(param2) > 1) {
-          paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2
-        }
-      }
-    }
-  }
-  estimatorSettings$paramsToTune <- paramsToTune
-
+  estimatorSettings$paramsToTune <- extractParamsToTune(estimatorSettings)
   return(estimatorSettings)
 }

@@ -342,14 +324,10 @@ gridCvDeep <- function(mappedData,
   currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]

   currentEstimatorSettings <-
-    fillEstimatorSettings(modelSettings$estimatorSettings, fitParams,
+    fillEstimatorSettings(modelSettings$estimatorSettings,
+                          fitParams,
                           paramSearch[[gridId]])

-  # initiate prediction
-  prediction <- NULL
-
-  fold <- labels$index
-  ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
   currentEstimatorSettings$modelType <- modelSettings$modelType
   currentModelParams$catFeatures <- dataset$get_cat_features()$shape[[1]]
   currentModelParams$numFeatures <-
     dataset$get_numerical_features()$shape[[1]]
@@ -363,63 +341,35 @@
     currentEstimatorSettings$learningRate <- lr
   }

-  learnRates <- list()
-  for (i in 1:max(fold)) {
-    ParallelLogger::logInfo(paste0("Fold ", i))
-    trainDataset <-
-      torch$utils$data$Subset(dataset,
-                              indices = as.integer(which(fold != i) - 1))
-    # -1 for python 0-based indexing
-    testDataset <-
-      torch$utils$data$Subset(dataset,
-                              indices = as.integer(which(fold == i) - 1))
-    # -1 for python 0-based indexing
-
-    estimator <- createEstimator(modelType = modelSettings$modelType,
-                                 modelParameters = currentModelParams,
-                                 estimatorSettings =
-                                   currentEstimatorSettings)
-    estimator$fit(trainDataset, testDataset)
-
-    ParallelLogger::logInfo("Calculating predictions on left out
-                            fold set...")
-
-    prediction <- rbind(
-      prediction,
-      predictDeepEstimator(
-        plpModel = estimator,
-        data = testDataset,
-        cohort = labels[fold == i, ]
-      )
-    )
-    learnRates[[i]] <- list(
-      LRs = estimator$learn_rate_schedule,
-      bestEpoch = estimator$best_epoch
-    )
-  }
+  crossValidationResults <-
+    doCrossvalidation(dataset,
+                      labels = labels,
+                      modelSettings = currentModelParams,
+                      estimatorSettings = currentEstimatorSettings)
+  learnRates <- crossValidationResults$learnRates
+  prediction <- crossValidationResults$prediction
+
+  gridPerformance <-
+    PatientLevelPrediction::computeGridPerformance(prediction,
+                                                   paramSearch[[gridId]])
   maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2)))
   gridSearchPredictons[[gridId]] <- list(
     prediction = prediction,
     param = paramSearch[[gridId]],
-    gridPerformance =
-      PatientLevelPrediction::computeGridPerformance(prediction,
-                                                     paramSearch[[gridId]])
+    gridPerformance = gridPerformance
   )
   gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary$learnRates <-
     rep(list(unlist(learnRates[[maxIndex]]$LRs)),
         nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary))
   gridSearchPredictons[[gridId]]$param$learnSchedule <-
     learnRates[[maxIndex]]

   # remove all predictions that are not the max performance
   indexOfMax <-
     which.max(unlist(lapply(gridSearchPredictons,
                             function(x) x$gridPerformance$cvPerformance)))
   for (i in seq_along(gridSearchPredictons)) {
-    if (!is.null(gridSearchPredictons[[i]])) {
-      if (i != indexOfMax) {
-        gridSearchPredictons[[i]]$prediction <- list(NULL)
-      }
+    if (!is.null(gridSearchPredictons[[i]]) && i != indexOfMax) {
+      gridSearchPredictons[[i]]$prediction <- list(NULL)
     }
   }
   ParallelLogger::logInfo(paste0("Caching all grid search results and
@@ -543,3 +493,67 @@ createEstimator <- function(modelType,
                                estimator_settings = estimatorSettings)
   return(estimator)
 }
+
+doCrossvalidation <- function(dataset,
+                              labels,
+                              modelSettings,
+                              estimatorSettings) {
+  fold <- labels$index
+  ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
+  learnRates <- list()
+  prediction <- NULL
+  for (i in 1:max(fold)) {
+    ParallelLogger::logInfo(paste0("Fold ", i))
+
+    # -1 for python 0-based indexing
+    trainDataset <- torch$utils$data$Subset(dataset,
+                                            indices =
+                                              as.integer(which(fold != i) - 1))
+
+    # -1 for python 0-based indexing
+    testDataset <- torch$utils$data$Subset(dataset,
+                                           indices =
+                                             as.integer(which(fold == i) - 1))
+    estimator <- createEstimator(modelType = estimatorSettings$modelType,
+                                 modelParameters = modelSettings,
+                                 estimatorSettings = estimatorSettings)
+    estimator$fit(trainDataset, testDataset)
+
+    ParallelLogger::logInfo("Calculating predictions on left out fold set...")
+
+    prediction <- rbind(
+      prediction,
+      predictDeepEstimator(
+        plpModel = estimator,
+        data = testDataset,
+        cohort = labels[fold == i, ]
+      )
+    )
+    learnRates[[i]] <- list(
+      LRs = estimator$learn_rate_schedule,
+      bestEpoch = estimator$best_epoch
+    )
+  }
+  return(results = list(prediction = prediction,
+                        learnRates = learnRates))
+
+}
+
+extractParamsToTune <- function(estimatorSettings) {
+  paramsToTune <- list()
+  for (name in names(estimatorSettings)) {
+    param <- estimatorSettings[[name]]
+    if (length(param) > 1 && is.atomic(param)) {
+      paramsToTune[[paste0("estimator.", name)]] <- param
+    }
+    if ("params" %in% names(param)) {
+      for (name2 in names(param[["params"]])) {
+        param2 <- param[["params"]][[name2]]
+        if (length(param2) > 1) {
+          paramsToTune[[paste0("estimator.", name, ".", name2)]] <- param2
+        }
+      }
+    }
+  }
+  return(paramsToTune)
+}
