From 92bd44229bb23f8601a6c94c59a51c647f1151c7 Mon Sep 17 00:00:00 2001 From: egillax Date: Sun, 26 Nov 2023 18:53:47 +0100 Subject: [PATCH] start refactoring of estimator --- R/Estimator.R | 73 +++++++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 7705279..f6ff07f 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -319,12 +319,7 @@ gridCvDeep <- function(mappedData, currentEstimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings, fitParams, paramSearch[[gridId]]) - - # initiate prediction - prediction <- NULL - - fold <- labels$index - ParallelLogger::logInfo(paste0("Max fold: ", max(fold))) + currentEstimatorSettings$modelType <- modelSettings$modelType currentModelParams$catFeatures <- dataset$get_cat_features()$shape[[1]] currentModelParams$numFeatures <- dataset$get_numerical_features()$shape[[1]] if (findLR) { @@ -337,32 +332,13 @@ gridCvDeep <- function(mappedData, currentEstimatorSettings$learningRate <- lr } - learnRates <- list() - for (i in 1:max(fold)) { - ParallelLogger::logInfo(paste0("Fold ", i)) - trainDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold != i) - 1)) # -1 for python 0-based indexing - testDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold == i) -1)) # -1 for python 0-based indexing + crossValidationResults <- doCrossvalidation(dataset, + labels=labels, + modelSettings = currentModelParams, + estimatorSettings = currentEstimatorSettings) + learnRates <- crossValidationResults$learnRates + prediction <- crossValidationResults$prediction - estimator <- createEstimator(modelType=modelSettings$modelType, - modelParameters=currentModelParams, - estimatorSettings=currentEstimatorSettings) - estimator$fit(trainDataset, testDataset) - - ParallelLogger::logInfo("Calculating predictions on left out fold set...") - - prediction <- rbind( - prediction, - predictDeepEstimator( - plpModel = estimator, - data = testDataset, - cohort = labels[fold == i, ] - ) - ) - learnRates[[i]] <- list( - LRs = estimator$learn_rate_schedule, - bestEpoch = estimator$best_epoch - ) - } maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2))) gridSearchPredictons[[gridId]] <- list( prediction = prediction, @@ -496,4 +472,39 @@ createEstimator <- function(modelType, model_parameters = modelParameters, estimator_settings = estimatorSettings) return(estimator) +} + +doCrossvalidation <- function(dataset, labels, modelSettings, estimatorSettings) { + fold <- labels$index + ParallelLogger::logInfo(paste0("Max fold: ", max(fold))) + learnRates <- list() + prediction <- NULL + for (i in 1:max(fold)) { + ParallelLogger::logInfo(paste0("Fold ", i)) + trainDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold != i) - 1)) # -1 for python 0-based indexing + testDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold == i) -1)) # -1 for python 0-based indexing + + estimator <- createEstimator(modelType=estimatorSettings$modelType, + modelParameters=modelSettings, + estimatorSettings=estimatorSettings) + estimator$fit(trainDataset, testDataset) + + ParallelLogger::logInfo("Calculating predictions on left out fold set...") + + prediction <- rbind( + prediction, + predictDeepEstimator( + plpModel = estimator, + data = testDataset, + cohort = labels[fold == i, ] + ) + ) + learnRates[[i]] <- list( + LRs = estimator$learn_rate_schedule, + bestEpoch = estimator$best_epoch + ) + } + return (results=list(prediction=prediction, + learnRates=learnRates)) + } \ No newline at end of file