From d2469f6d6b4d788c39f378d81a6693642b92ad2e Mon Sep 17 00:00:00 2001
From: egillax
Date: Sat, 23 Dec 2023 15:49:10 +0100
Subject: [PATCH] chore: styling

---
 R/Dataset.R             |   9 +-
 R/Estimator.R           | 263 +++++++++++++++++++++++++---------------
 R/LRFinder.R            |  21 ++--
 R/MLP.R                 |  29 +++--
 R/ResNet.R              |  85 +++++++------
 R/TrainingCache-class.R |  14 ++-
 R/Transformer.R         |  77 +++++++-----
 7 files changed, 301 insertions(+), 197 deletions(-)

diff --git a/R/Dataset.R b/R/Dataset.R
index 4ba4e81..f71b3db 100644
--- a/R/Dataset.R
+++ b/R/Dataset.R
@@ -23,13 +23,16 @@ createDataset <- function(data, labels, plpModel = NULL) {
     attributes(data)$path <- attributes(data)$dbname
   }
   if (is.null(plpModel)) {
-    data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
-                    r_to_py(labels$outcomeCount))
+    data <- dataset(
+      r_to_py(normalizePath(attributes(data)$path)),
+      r_to_py(labels$outcomeCount)
+    )
   } else {
     numericalFeatures <-
       r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
     data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
-                    numerical_features = numericalFeatures)
+      numerical_features = numericalFeatures
+    )
   }
 
   return(data)

diff --git a/R/Estimator.R b/R/Estimator.R
index 376b49d..0c7698e 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -42,24 +42,27 @@
 #' outputs a score.
 #' @param seed seed to initialize weights of model with
 #' @export
-setEstimator <- function(learningRate = "auto",
-                         weightDecay = 0.0,
-                         batchSize = 512,
-                         epochs = 30,
-                         device = "cpu",
-                         optimizer = torch$optim$AdamW,
-                         scheduler = list(fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
-                                          params = list(patience = 1)),
-                         criterion = torch$nn$BCEWithLogitsLoss,
-                         earlyStopping = list(useEarlyStopping = TRUE,
-                                              params = list(patience = 4)),
-                         metric = "auc",
-                         seed = NULL
-) {
-
+setEstimator <- function(
+    learningRate = "auto",
+    weightDecay = 0.0,
+    batchSize = 512,
+    epochs = 30,
+    device = "cpu",
+    optimizer = torch$optim$AdamW,
+    scheduler = list(
+      fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
+      params = list(patience = 1)
+    ),
+    criterion = torch$nn$BCEWithLogitsLoss,
+    earlyStopping = list(
+      useEarlyStopping = TRUE,
+      params = list(patience = 4)
+    ),
+    metric = "auc",
+    seed = NULL) {
  checkIsClass(learningRate, c("numeric", "character"))
  if (inherits(learningRate, "character") && learningRate != "auto") {
-    stop(paste0('Learning rate should be either a numeric or "auto", 
+    stop(paste0('Learning rate should be either a numeric or "auto",
                you provided: ', learningRate))
  }
  checkIsClass(weightDecay, "numeric")
@@ -81,34 +84,44 @@ setEstimator <- function(learningRate = "auto",
   if (is.null(seed)) {
     seed <- as.integer(sample(1e5, 1))
   }
-  estimatorSettings <- list(learningRate = learningRate,
-                            weightDecay = weightDecay,
-                            batchSize = batchSize,
-                            epochs = epochs,
-                            device = device,
-                            earlyStopping = earlyStopping,
-                            findLR = findLR,
-                            metric = metric,
-                            seed = seed[1])
+  estimatorSettings <- list(
+    learningRate = learningRate,
+    weightDecay = weightDecay,
+    batchSize = batchSize,
+    epochs = epochs,
+    device = device,
+    earlyStopping = earlyStopping,
+    findLR = findLR,
+    metric = metric,
+    seed = seed[1]
+  )
 
   optimizer <- rlang::enquo(optimizer)
   estimatorSettings$optimizer <- function() rlang::eval_tidy(optimizer)
-  class(estimatorSettings$optimizer) <- c("delayed",
-                                          class(estimatorSettings$optimizer))
+  class(estimatorSettings$optimizer) <- c(
+    "delayed",
+    class(estimatorSettings$optimizer)
+  )
 
   criterion <- rlang::enquo(criterion)
   estimatorSettings$criterion <- function() rlang::eval_tidy(criterion)
-  class(estimatorSettings$criterion) <- c("delayed",
-                                          class(estimatorSettings$criterion))
+  class(estimatorSettings$criterion) <- c(
+    "delayed",
+    class(estimatorSettings$criterion)
+  )
 
   scheduler <- rlang::enquo(scheduler)
   estimatorSettings$scheduler <- function() rlang::eval_tidy(scheduler)
-  class(estimatorSettings$scheduler) <- c("delayed",
-                                          class(estimatorSettings$scheduler))
+  class(estimatorSettings$scheduler) <- c(
+    "delayed",
+    class(estimatorSettings$scheduler)
+  )
 
   if (is.function(device)) {
-    class(estimatorSettings$device) <- c("delayed",
-                                         class(estimatorSettings$device))
+    class(estimatorSettings$device) <- c(
+      "delayed",
+      class(estimatorSettings$device)
+    )
   }
 
   estimatorSettings$paramsToTune <- extractParamsToTune(estimatorSettings)
@@ -161,13 +174,16 @@ fitEstimator <- function(trainData,
     )
   )
 
-  hyperSummary <- do.call(rbind, lapply(cvResult$paramGridSearch,
-                                        function(x) x$hyperSummary))
+  hyperSummary <- do.call(rbind, lapply(
+    cvResult$paramGridSearch,
+    function(x) x$hyperSummary
+  ))
   prediction <- cvResult$prediction
   incs <- rep(1, covariateRef %>%
-                dplyr::tally() %>%
-                dplyr::collect() %>%
-                as.integer())
+    dplyr::tally() %>%
+    dplyr::collect() %>%
+    as.integer()
+  )
   covariateRef <- covariateRef %>%
     dplyr::arrange("columnId") %>%
     dplyr::collect() %>%
@@ -180,26 +196,35 @@
   comp <- start - Sys.time()
   result <- list(
     model = cvResult$estimator,
-
     preprocessing = list(
-      featureEngineering = attr(trainData$covariateData,
-                                "metaData")$featureEngineering,
-      tidyCovariates = attr(trainData$covariateData,
-                            "metaData")$tidyCovariateDataSettings,
+      featureEngineering = attr(
+        trainData$covariateData,
+        "metaData"
+      )$featureEngineering,
+      tidyCovariates = attr(
+        trainData$covariateData,
+        "metaData"
+      )$tidyCovariateDataSettings,
       requireDenseMatrix = FALSE
     ),
     prediction = prediction,
     modelDesign = PatientLevelPrediction::createModelDesign(
       targetId = attr(trainData, "metaData")$targetId,
       outcomeId = attr(trainData, "metaData")$outcomeId,
-      restrictPlpDataSettings = attr(trainData,
-                                     "metaData")$restrictPlpDataSettings,
+      restrictPlpDataSettings = attr(
+        trainData,
+        "metaData"
+      )$restrictPlpDataSettings,
       covariateSettings = attr(trainData, "metaData")$covariateSettings,
       populationSettings = attr(trainData, "metaData")$populationSettings,
-      featureEngineeringSettings = attr(trainData$covariateData,
-                                        "metaData")$featureEngineeringSettings,
-      preprocessSettings = attr(trainData$covariateData,
-                                "metaData")$preprocessSettings,
+      featureEngineeringSettings = attr(
+        trainData$covariateData,
+        "metaData"
+      )$featureEngineeringSettings,
+      preprocessSettings = attr(
+        trainData$covariateData,
+        "metaData"
+      )$preprocessSettings,
       modelSettings = modelSettings,
       splitSettings = attr(trainData, "metaData")$splitSettings,
       sampleSettings = attr(trainData, "metaData")$sampleSettings
@@ -257,12 +282,18 @@ predictDeepEstimator <- function(plpModel,
   prediction <- cohort
   if (is.character(plpModel$model)) {
     modelSettings <- plpModel$modelDesign$modelSettings
-    model <- torch$load(file.path(plpModel$model,
-                                  "DeepEstimatorModel.pt"),
-                        map_location = "cpu")
-    estimator <- createEstimator(modelType = modelSettings$modelType,
-                                 modelParameters = model$model_parameters,
-                                 estimatorSettings = model$estimator_settings)
+    model <- torch$load(
+      file.path(
+        plpModel$model,
+        "DeepEstimatorModel.pt"
+      ),
+      map_location = "cpu"
+    )
+    estimator <- createEstimator(
+      modelType = modelSettings$modelType,
+      modelParameters = model$model_parameters,
+      estimatorSettings = model$estimator_settings
+    )
     estimator$model$load_state_dict(model$model_state_dict)
     prediction$value <- estimator$predict_proba(data)
   } else {
@@ -292,8 +323,10 @@ gridCvDeep <- function(mappedData,
                        modelSettings,
                        modelLocation,
                        analysisPath) {
-  ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
-                                 modelSettings$modelType, " model"))
+  ParallelLogger::logInfo(paste0(
+    "Running hyperparameter search for ",
+    modelSettings$modelType, " model"
+  ))
 
   ###########################################################################
 
@@ -312,28 +345,37 @@ gridCvDeep <- function(mappedData,
 
   dataset <- createDataset(data = mappedData, labels = labels)
 
-  fitParams <- names(paramSearch[[1]])[grepl("^estimator",
-                                             names(paramSearch[[1]]))]
+  fitParams <- names(paramSearch[[1]])[grepl(
+    "^estimator",
+    names(paramSearch[[1]])
+  )]
   findLR <- modelSettings$estimatorSettings$findLR
 
   if (!trainCache$isFull()) {
     for (gridId in trainCache$getLastGridSearchIndex():length(paramSearch)) {
-      ParallelLogger::logInfo(paste0("Running hyperparameter combination no ",
-                                     gridId))
+      ParallelLogger::logInfo(paste0(
+        "Running hyperparameter combination no ",
+        gridId
+      ))
       ParallelLogger::logInfo(paste0("HyperParameters: "))
       ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]),
-                                    paramSearch[[gridId]], collapse = " | "))
+        paramSearch[[gridId]],
+        collapse = " | "
+      ))
       currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]
       currentEstimatorSettings <-
-        fillEstimatorSettings(modelSettings$estimatorSettings,
-                              fitParams,
-                              paramSearch[[gridId]])
+        fillEstimatorSettings(
+          modelSettings$estimatorSettings,
+          fitParams,
+          paramSearch[[gridId]]
+        )
       currentEstimatorSettings$modelType <- modelSettings$modelType
       currentModelParams$catFeatures <- dataset$get_cat_features()$max()
       currentModelParams$numFeatures <- dataset$get_numerical_features()$max()
 
       if (findLR) {
-        lrFinder <- createLRFinder(modelType = modelSettings$modelType,
+        lrFinder <- createLRFinder(
+          modelType = modelSettings$modelType,
           modelParameters = currentModelParams,
           estimatorSettings = currentEstimatorSettings
         )
@@ -344,15 +386,18 @@ gridCvDeep <- function(mappedData,
 
       crossValidationResults <- doCrossvalidation(dataset,
-                                                  labels = labels,
-                                                  modelSettings = currentModelParams,
-                                                  estimatorSettings = currentEstimatorSettings)
+        labels = labels,
+        modelSettings = currentModelParams,
+        estimatorSettings = currentEstimatorSettings
+      )
 
       learnRates <- crossValidationResults$learnRates
       prediction <- crossValidationResults$prediction
       gridPerformance <-
-        PatientLevelPrediction::computeGridPerformance(prediction,
-                                                       paramSearch[[gridId]])
+        PatientLevelPrediction::computeGridPerformance(
+          prediction,
+          paramSearch[[gridId]]
+        )
 
       maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2)))
       gridSearchPredictons[[gridId]] <- list(
         prediction = prediction,
@@ -360,30 +405,38 @@ gridCvDeep <- function(mappedData,
         gridPerformance = gridPerformance
       )
       gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary$learnRates <-
-        rep(list(unlist(learnRates[[maxIndex]]$LRs)),
-            nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary))
+        rep(
+          list(unlist(learnRates[[maxIndex]]$LRs)),
+          nrow(gridSearchPredictons[[gridId]]$gridPerformance$hyperSummary)
+        )
       gridSearchPredictons[[gridId]]$param$learnSchedule <- learnRates[[maxIndex]]
 
       # remove all predictions that are not the max performance
       indexOfMax <-
-        which.max(unlist(lapply(gridSearchPredictons,
-                                function(x) x$gridPerformance$cvPerformance)))
+        which.max(unlist(lapply(
+          gridSearchPredictons,
+          function(x) x$gridPerformance$cvPerformance
+        )))
       for (i in seq_along(gridSearchPredictons)) {
         if (!is.null(gridSearchPredictons[[i]]) && i != indexOfMax) {
           gridSearchPredictons[[i]]$prediction <- list(NULL)
         }
       }
-      ParallelLogger::logInfo(paste0("Caching all grid search results and
+      ParallelLogger::logInfo(paste0(
+        "Caching all grid search results and
                                      prediction for best combination ",
-                                     indexOfMax))
+        indexOfMax
+      ))
       trainCache$saveGridSearchPredictions(gridSearchPredictons)
     }
   }
 
   paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)
   # get best params
   indexOfMax <-
-    which.max(unlist(lapply(gridSearchPredictons,
-                            function(x) x$gridPerformance$cvPerformance)))
+    which.max(unlist(lapply(
+      gridSearchPredictons,
+      function(x) x$gridPerformance$cvPerformance
+    )))
   finalParam <- gridSearchPredictons[[indexOfMax]]$param
   paramGridSearch <- lapply(gridSearchPredictons, function(x) x$gridPerformance)
@@ -405,13 +458,17 @@ gridCvDeep <- function(mappedData,
   modelParams$catFeatures <- dataset$get_cat_features()$max()
   modelParams$numFeatures <- dataset$get_numerical_features()$len()
-  estimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings,
-                                             fitParams,
-                                             finalParam)
+  estimatorSettings <- fillEstimatorSettings(
+    modelSettings$estimatorSettings,
+    fitParams,
+    finalParam
+  )
   estimatorSettings$learningRate <- finalParam$learnSchedule$LRs[[1]]
-  estimator <- createEstimator(modelType = modelSettings$modelType,
-                               modelParameters = modelParams,
-                               estimatorSettings = estimatorSettings)
+  estimator <- createEstimator(
+    modelType = modelSettings$modelType,
+    modelParameters = modelParams,
+    estimatorSettings = estimatorSettings
+  )
 
   numericalIndex <- dataset$get_numerical_features()
 
   estimator$fit_whole_training_set(dataset, finalParam$learnSchedule$LRs)
@@ -433,7 +490,8 @@ gridCvDeep <- function(mappedData,
     dplyr::select(-"index")
 
   prediction$cohortStartDate <- as.Date(prediction$cohortStartDate,
-                                        origin = "1970-01-01")
+    origin = "1970-01-01"
+  )
 
   # save torch code here
 
@@ -488,9 +546,11 @@ createEstimator <- function(modelType,
   estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
   estimatorSettings <- evalEstimatorSettings(estimatorSettings)
 
-  estimator <- estimator(model = model,
-                         model_parameters = modelParameters,
-                         estimator_settings = estimatorSettings)
+  estimator <- estimator(
+    model = model,
+    model_parameters = modelParameters,
+    estimator_settings = estimatorSettings
+  )
 
   return(estimator)
 }
@@ -507,16 +567,20 @@ doCrossvalidation <- function(dataset,
     # -1 for python 0-based indexing
     trainDataset <- torch$utils$data$Subset(dataset,
-                                            indices =
-                                              as.integer(which(fold != i) - 1))
+      indices =
+        as.integer(which(fold != i) - 1)
+    )
 
     # -1 for python 0-based indexing
     testDataset <- torch$utils$data$Subset(dataset,
-                                           indices =
-                                             as.integer(which(fold == i) - 1))
-    estimator <- createEstimator(modelType = estimatorSettings$modelType,
-                                 modelParameters = modelSettings,
-                                 estimatorSettings = estimatorSettings)
+      indices =
+        as.integer(which(fold == i) - 1)
+    )
+    estimator <- createEstimator(
+      modelType = estimatorSettings$modelType,
+      modelParameters = modelSettings,
+      estimatorSettings = estimatorSettings
+    )
 
     estimator$fit(trainDataset, testDataset)
 
     ParallelLogger::logInfo("Calculating predictions on left out fold set...")
@@ -534,9 +598,10 @@ doCrossvalidation <- function(dataset,
       bestEpoch = estimator$best_epoch
     )
   }
-  return(results = list(prediction = prediction,
-                        learnRates = learnRates))
-
+  return(results = list(
+    prediction = prediction,
+    learnRates = learnRates
+  ))
 }
 
 extractParamsToTune <- function(estimatorSettings) {

diff --git a/R/LRFinder.R b/R/LRFinder.R
index 4ab006c..a7cba73 100644
--- a/R/LRFinder.R
+++ b/R/LRFinder.R
@@ -22,25 +22,24 @@ createLRFinder <- function(modelType,
   path <- system.file("python", package = "DeepPatientLevelPrediction")
   lrFinderClass <-
     reticulate::import_from_path("LrFinder", path = path)$LrFinder
-  
-  
+
+
   model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
 
   modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
   estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
   estimatorSettings <- evalEstimatorSettings(estimatorSettings)
-
-  # estimator <- createEstimator(modelType = estimatorSettings$modelType,
-  #                              modelParameters = modelParameters,
-  #                              estimatorSettings = estimatorSettings)
+
   if (!is.null(lrSettings)) {
     lrSettings <- camelCaseToSnakeCaseNames(lrSettings)
   }
-  
-  lrFinder <- lrFinderClass(model = model,
-                            model_parameters = modelParameters,
-                            estimator_settings = estimatorSettings,
-                            lr_settings = lrSettings)
+
+  lrFinder <- lrFinderClass(
+    model = model,
+    model_parameters = modelParameters,
+    estimator_settings = estimatorSettings,
+    lr_settings = lrSettings
+  )
 
   return(lrFinder)
 }

diff --git a/R/MLP.R b/R/MLP.R
index cc309c8..0d2e14b 100644
--- a/R/MLP.R
+++ b/R/MLP.R
@@ -43,11 +43,13 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
                                     dropout = c(seq(0, 0.3, 0.05)),
                                     sizeEmbedding = c(2^(6:9)),
                                     estimatorSettings =
-                                      setEstimator(learningRate = "auto",
-                                                   weightDecay = c(1e-6, 1e-3),
-                                                   batchSize = 1024,
-                                                   epochs = 30,
-                                                   device = "cpu"),
+                                      setEstimator(
+                                        learningRate = "auto",
+                                        weightDecay = c(1e-6, 1e-3),
+                                        batchSize = 1024,
+                                        epochs = 30,
+                                        device = "cpu"
+                                      ),
                                     hyperParamSearch = "random",
                                     randomSample = 100,
                                     randomSampleSeed = NULL) {
@@ -81,17 +83,22 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
   param <- PatientLevelPrediction::listCartesian(paramGrid)
 
   if (hyperParamSearch == "random" && randomSample > length(param)) {
-    stop(paste("\n Chosen amount of randomSamples is higher than the
+    stop(paste(
+      "\n Chosen amount of randomSamples is higher than the
               amount of possible hyperparameter combinations.",
-               "\n randomSample:", randomSample, "\n Possible hyperparameter
+      "\n randomSample:", randomSample, "\n Possible hyperparameter
               combinations:", length(param),
-               "\n Please lower the amount of randomSamples"))
+      "\n Please lower the amount of randomSamples"
+    ))
   }
 
   if (hyperParamSearch == "random") {
-    suppressWarnings(withr::with_seed(randomSampleSeed,
-                                      {param <- param[sample(length(param),
-                                                             randomSample)]}))
+    suppressWarnings(withr::with_seed(randomSampleSeed, {
+      param <- param[sample(
+        length(param),
+        randomSample
+      )]
+    }))
   }
 
   attr(param, "settings")$modelType <- "MLP"

diff --git a/R/ResNet.R b/R/ResNet.R
index c12c6bb..698048b 100644
--- a/R/ResNet.R
+++ b/R/ResNet.R
@@ -29,21 +29,25 @@
 #' @export
 setDefaultResNet <- function(estimatorSettings =
-                               setEstimator(learningRate = "auto",
-                                            weightDecay = 1e-6,
-                                            device = "cpu",
-                                            batchSize = 1024,
-                                            epochs = 50,
-                                            seed = NULL)) {
-  resnetSettings <- setResNet(numLayers = 6,
-                              sizeHidden = 512,
-                              hiddenFactor = 2,
-                              residualDropout = 0.1,
-                              hiddenDropout = 0.4,
-                              sizeEmbedding = 256,
-                              estimatorSettings = estimatorSettings,
-                              hyperParamSearch = "random",
-                              randomSample = 1)
+                               setEstimator(
+                                 learningRate = "auto",
+                                 weightDecay = 1e-6,
+                                 device = "cpu",
+                                 batchSize = 1024,
+                                 epochs = 50,
+                                 seed = NULL
+                               )) {
+  resnetSettings <- setResNet(
+    numLayers = 6,
+    sizeHidden = 512,
+    hiddenFactor = 2,
+    residualDropout = 0.1,
+    hiddenDropout = 0.4,
+    sizeEmbedding = 256,
+    estimatorSettings = estimatorSettings,
+    hyperParamSearch = "random",
+    randomSample = 1
+  )
   attr(resnetSettings, "settings")$name <- "defaultResnet"
   return(resnetSettings)
 }
 
@@ -83,12 +87,14 @@ setResNet <- function(numLayers = c(1:8),
                       hiddenDropout = c(seq(0, 0.5, 0.05)),
                       sizeEmbedding = c(2^(6:9)),
                       estimatorSettings =
-                        setEstimator(learningRate = "auto",
-                                     weightDecay = c(1e-6, 1e-3),
-                                     device = "cpu",
-                                     batchSize = 1024,
-                                     epochs = 30,
-                                     seed = NULL),
+                        setEstimator(
+                          learningRate = "auto",
+                          weightDecay = c(1e-6, 1e-3),
+                          device = "cpu",
+                          batchSize = 1024,
+                          epochs = 30,
+                          seed = NULL
+                        ),
                       hyperParamSearch = "random",
                       randomSample = 100,
                       randomSampleSeed = NULL) {
@@ -114,28 +120,35 @@ setResNet <- function(numLayers = c(1:8),
   checkIsClass(randomSampleSeed, c("numeric", "integer", "NULL"))
 
-  paramGrid <- list(numLayers = numLayers,
-                    sizeHidden = sizeHidden,
-                    hiddenFactor = hiddenFactor,
-                    residualDropout = residualDropout,
-                    hiddenDropout = hiddenDropout,
-                    sizeEmbedding = sizeEmbedding)
+  paramGrid <- list(
+    numLayers = numLayers,
+    sizeHidden = sizeHidden,
+    hiddenFactor = hiddenFactor,
+    residualDropout = residualDropout,
+    hiddenDropout = hiddenDropout,
+    sizeEmbedding = sizeEmbedding
+  )
 
   paramGrid <- c(paramGrid, estimatorSettings$paramsToTune)
 
   param <- PatientLevelPrediction::listCartesian(paramGrid)
 
   if (hyperParamSearch == "random" && randomSample > length(param)) {
-    stop(paste("\n Chosen amount of randomSamples is higher than the amount of
+    stop(paste(
+      "\n Chosen amount of randomSamples is higher than the amount of
              possible hyperparameter combinations.", "\n randomSample:",
-               randomSample, "\n Possible hyperparameter combinations:",
-               length(param), "\n Please lower the amount of randomSamples"))
+      randomSample, "\n Possible hyperparameter combinations:",
+      length(param), "\n Please lower the amount of randomSamples"
+    ))
   }
 
   if (hyperParamSearch == "random") {
-    suppressWarnings(withr::with_seed(randomSampleSeed,
-                                      {param <- param[sample(length(param),
-                                                             randomSample)]}))
+    suppressWarnings(withr::with_seed(randomSampleSeed, {
+      param <- param[sample(
+        length(param),
+        randomSample
+      )]
+    }))
   }
   attr(param, "settings")$modelType <- "ResNet"
 
   results <- list(
@@ -144,8 +157,10 @@ setResNet <- function(numLayers = c(1:8),
     estimatorSettings = estimatorSettings,
     modelType = "ResNet",
     saveType = "file",
-    modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor",
-                        "residualDropout", "hiddenDropout", "sizeEmbedding")
+    modelParamNames = c(
+      "numLayers", "sizeHidden", "hiddenFactor",
+      "residualDropout", "hiddenDropout", "sizeEmbedding"
+    )
   )
 
   class(results) <- "modelSettings"

diff --git a/R/TrainingCache-class.R b/R/TrainingCache-class.R
index fe7dfc7..6995162 100644
--- a/R/TrainingCache-class.R
+++ b/R/TrainingCache-class.R
@@ -69,10 +69,12 @@ trainingCache <- R6::R6Class(
     #' Check if cache is full
     #' @returns Boolen
     isFull = function() {
-      return(all(unlist(lapply(private$.paramPersistence$gridSearchPredictions,
-                               function(x) !is.null(x$gridPerformance)))))
+      return(all(unlist(lapply(
+        private$.paramPersistence$gridSearchPredictions,
+        function(x) !is.null(x$gridPerformance)
+      ))))
     },
-    
+
     #' @description
     #' Gets the last index from the cached grid search
     #' @returns Last grid search index
@@ -84,8 +86,10 @@ trainingCache <- R6::R6Class(
       if (length(private$.paramPersistence$gridSearchPredictions) == 1) {
         return(1)
       } else {
-        return(which(sapply(private$.paramPersistence$gridSearchPredictions,
-                            is.null))[1])
+        return(which(sapply(
+          private$.paramPersistence$gridSearchPredictions,
+          is.null
+        ))[1])
       }
     }
   },

diff --git a/R/Transformer.R b/R/Transformer.R
index 6992d47..de28e4b 100644
--- a/R/Transformer.R
+++ b/R/Transformer.R
@@ -25,24 +25,27 @@
 #'
 #' @export
 setDefaultTransformer <- function(estimatorSettings =
-                                    setEstimator(learningRate = "auto",
-                                                 weightDecay = 1e-4,
-                                                 batchSize = 512,
-                                                 epochs = 10,
-                                                 seed = NULL,
-                                                 device = "cpu")
-) {
-  transformerSettings <- setTransformer(numBlocks = 3,
-                                        dimToken = 192,
-                                        dimOut = 1,
-                                        numHeads = 8,
-                                        attDropout = 0.2,
-                                        ffnDropout = 0.1,
-                                        resDropout = 0.0,
-                                        dimHidden = 256,
-                                        estimatorSettings = estimatorSettings,
-                                        hyperParamSearch = "random",
-                                        randomSample = 1)
+                                    setEstimator(
+                                      learningRate = "auto",
+                                      weightDecay = 1e-4,
+                                      batchSize = 512,
+                                      epochs = 10,
+                                      seed = NULL,
+                                      device = "cpu"
+                                    )) {
+  transformerSettings <- setTransformer(
+    numBlocks = 3,
+    dimToken = 192,
+    dimOut = 1,
+    numHeads = 8,
+    attDropout = 0.2,
+    ffnDropout = 0.1,
+    resDropout = 0.0,
+    dimHidden = 256,
+    estimatorSettings = estimatorSettings,
+    hyperParamSearch = "random",
+    randomSample = 1
+  )
   attr(transformerSettings, "settings")$name <- "defaultTransformer"
   return(transformerSettings)
 }
@@ -81,14 +84,15 @@ setTransformer <- function(numBlocks = 3,
                            resDropout = 0,
                            dimHidden = 512,
                            dimHiddenRatio = NULL,
-                           estimatorSettings = setEstimator(weightDecay = 1e-6,
-                                                            batchSize = 1024,
-                                                            epochs = 10,
-                                                            seed = NULL),
+                           estimatorSettings = setEstimator(
+                             weightDecay = 1e-6,
+                             batchSize = 1024,
+                             epochs = 10,
+                             seed = NULL
+                           ),
                           hyperParamSearch = "random",
                           randomSample = 1,
                           randomSampleSeed = NULL) {
-
  checkIsClass(numBlocks, c("integer", "numeric"))
  checkHigherEqual(numBlocks, 1)
 
@@ -127,16 +131,18 @@ setTransformer <- function(numBlocks = 3,
   checkIsClass(randomSampleSeed, c("numeric", "integer", "NULL"))
 
-  if (any(with(expand.grid(dimToken = dimToken, numHeads = numHeads),
-               dimToken %% numHeads != 0))) {
+  if (any(with(
+    expand.grid(dimToken = dimToken, numHeads = numHeads),
+    dimToken %% numHeads != 0
+  ))) {
     stop(paste(
       "dimToken needs to divisible by numHeads. dimToken =", dimToken,
       "is not divisible by numHeads =", numHeads
     ))
   }
 
-  if (is.null(dimHidden) && is.null(dimHiddenRatio)
-      || !is.null(dimHidden) && !is.null(dimHiddenRatio)) {
+  if (is.null(dimHidden) && is.null(dimHiddenRatio) ||
+    !is.null(dimHidden) && !is.null(dimHiddenRatio)) {
     stop(paste(
       "dimHidden and dimHiddenRatio cannot be both set or both NULL"
     ))
   }
@@ -169,16 +175,21 @@ setTransformer <- function(numBlocks = 3,
   }
 
   if (hyperParamSearch == "random" && randomSample > length(param)) {
-    stop(paste("\n Chosen amount of randomSamples is higher than the amount of
+    stop(paste(
+      "\n Chosen amount of randomSamples is higher than the amount of
              possible hyperparameter combinations.", "\n randomSample:",
-               randomSample, "\n Possible hyperparameter combinations:",
-               length(param), "\n Please lower the amount of randomSample"))
+      randomSample, "\n Possible hyperparameter combinations:",
+      length(param), "\n Please lower the amount of randomSample"
+    ))
   }
 
   if (hyperParamSearch == "random") {
-    suppressWarnings(withr::with_seed(randomSampleSeed,
-                                      {param <- param[sample(length(param),
-                                                             randomSample)]}))
+    suppressWarnings(withr::with_seed(randomSampleSeed, {
+      param <- param[sample(
+        length(param),
+        randomSample
+      )]
+    }))
   }
   attr(param, "settings")$modelType <- "Transformer"
   results <- list(
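
The hunks in this patch are purely mechanical: multi-argument calls are broken after the opening parenthesis, each argument gets its own line, the closing parenthesis returns to the indentation of the call, and trailing whitespace and commented-out code are dropped. The commit message does not name the tool that produced the restyling; as a hedged sketch, this layout matches what the styler package's default tidyverse style guide emits, so a comparable result could plausibly be reproduced as follows (the tool choice is an assumption, not stated in the patch):

# Hypothetical reproduction of this kind of restyling; assumes the
# {styler} package. tidyverse_style() is styler's default style guide
# and yields the one-argument-per-line call layout seen in the hunks above.
library(styler)
style_pkg(
  pkg = ".",       # package root containing the R/ directory
  style = tidyverse_style,
  filetype = "R"   # limit styling to the .R sources touched by this commit
)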