Skip to content

Commit

Permalink
Fix randomSample default value for transformer and add error messages…
Browse files Browse the repository at this point in the history
… if randomSamples is higher than amount of hyperparameters
  • Loading branch information
egillax committed Sep 7, 2022
1 parent 59446fc commit 4430e13
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
7 changes: 6 additions & 1 deletion R/MLP.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
)

param <- PatientLevelPrediction::listCartesian(paramGrid)

if (randomSamples>length(param)) {
stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.",
"\n randomSamples:", randomSamples,"\n Possible hyperparameter combinations:", length(param),
"\n Please lower the amount of randomSamples"))
}

if (hyperParamSearch == "random") {
param <- param[sample(length(param), randomSample)]
}
Expand Down
6 changes: 6 additions & 0 deletions R/ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ setResNet <- function(numLayers = c(1:8),

param <- PatientLevelPrediction::listCartesian(paramGrid)

if (randomSamples>length(param)) {
stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.",
"\n randomSamples:", randomSamples,"\n Possible hyperparameter combinations:", length(param),
"\n Please lower the amount of randomSamples"))
}

if (hyperParamSearch == "random") {
param <- param[sample(length(param), randomSample)]
}
Expand Down
8 changes: 7 additions & 1 deletion R/Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1,
resDropout = 0, dimHidden = 512, weightDecay = 1e-6,
learningRate = 3e-4, batchSize = 1024,
epochs = 10, device = "cpu", hyperParamSearch = "random",
randomSamples = 100, seed = NULL) {
randomSamples = 1, seed = NULL) {
if (is.null(seed)) {
seed <- as.integer(sample(1e5, 1))
}
Expand Down Expand Up @@ -54,6 +54,12 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1,

param <- PatientLevelPrediction::listCartesian(paramGrid)

if (randomSamples>length(param)) {
stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.",
"\n randomSamples:", randomSamples,"\n Possible hyperparameter combinations:", length(param),
"\n Please lower the amount of randomSamples"))
}

if (hyperParamSearch == "random") {
param <- param[sample(length(param), randomSamples)]
}
Expand Down

0 comments on commit 4430e13

Please sign in to comment.