From 9dd84bde981b82e29ec8bcc28d610e1abf64f0e8 Mon Sep 17 00:00:00 2001 From: Egill Fridgeirsson Date: Wed, 7 Sep 2022 14:55:17 +0200 Subject: [PATCH] Fix randomSample default value for transformer and add error messages if randomSamples is higher than amount of hyperparameters --- R/MLP.R | 7 ++++++- R/ResNet.R | 6 ++++++ R/Transformer.R | 8 +++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/R/MLP.R b/R/MLP.R index dc5f56a..7df166a 100644 --- a/R/MLP.R +++ b/R/MLP.R @@ -66,7 +66,12 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8), ) param <- PatientLevelPrediction::listCartesian(paramGrid) - + if (randomSample>length(param)) { + stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.", + "\n randomSample:", randomSample,"\n Possible hyperparameter combinations:", length(param), + "\n Please lower the amount of randomSamples")) + } + if (hyperParamSearch == "random") { param <- param[sample(length(param), randomSample)] } diff --git a/R/ResNet.R b/R/ResNet.R index e8ce322..fafb40b 100644 --- a/R/ResNet.R +++ b/R/ResNet.R @@ -73,6 +73,12 @@ setResNet <- function(numLayers = c(1:8), param <- PatientLevelPrediction::listCartesian(paramGrid) + if (randomSample>length(param)) { + stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.", + "\n randomSample:", randomSample,"\n Possible hyperparameter combinations:", length(param), + "\n Please lower the amount of randomSamples")) + } + if (hyperParamSearch == "random") { param <- param[sample(length(param), randomSample)] } diff --git a/R/Transformer.R b/R/Transformer.R index 0ab7227..ec4d469 100644 --- a/R/Transformer.R +++ b/R/Transformer.R @@ -26,7 +26,7 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1, resDropout = 0, dimHidden = 512, weightDecay = 1e-6, learningRate = 3e-4, batchSize = 1024, epochs = 10, device = "cpu", hyperParamSearch = "random", - randomSamples = 100, seed = NULL) { + randomSample = 1, seed = NULL) { if (is.null(seed)) { seed <- as.integer(sample(1e5, 1)) } @@ -54,6 +54,12 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1, param <- PatientLevelPrediction::listCartesian(paramGrid) + if (randomSample>length(param)) { + stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.", + "\n randomSample:", randomSample,"\n Possible hyperparameter combinations:", length(param), + "\n Please lower the amount of randomSample")) + } + if (hyperParamSearch == "random") { param <- param[sample(length(param), randomSamples)] }