From 74608ffb7543371f11863d41c8beb3ca28abc6dd Mon Sep 17 00:00:00 2001 From: Henrik Date: Thu, 22 Jun 2023 10:55:44 +0200 Subject: [PATCH] Resolve an issue with hidden dimension ratio (#74) * Resolve an issue with hidden dimension ratio * Optimize solution * add test case --------- Co-authored-by: egillax --- R/Estimator.R | 4 ++-- R/Transformer.R | 9 ++++++++- extras/example.R | 8 +++----- tests/testthat/test-Transformer.R | 14 ++++++++++++++ 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 529c315..1452503 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -138,8 +138,8 @@ fitEstimator <- function(trainData, hyperSummary <- do.call(rbind, lapply(cvResult$paramGridSearch, function(x) x$hyperSummary)) prediction <- cvResult$prediction incs <- rep(1, covariateRef %>% dplyr::tally() %>% - dplyr::collect () %>% - as.integer) + dplyr::collect() %>% + as.integer()) covariateRef <- covariateRef %>% dplyr::collect() %>% dplyr::mutate( diff --git a/R/Transformer.R b/R/Transformer.R index 7e25ff7..04212fd 100644 --- a/R/Transformer.R +++ b/R/Transformer.R @@ -91,7 +91,7 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1, )) } else { if (!is.null(dimHiddenRatio)) { - dimHidden <- round(dimToken*dimHiddenRatio, digits = 0) + dimHidden <- dimHiddenRatio } } @@ -110,6 +110,13 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1, param <- PatientLevelPrediction::listCartesian(paramGrid) + if (!is.null(dimHiddenRatio)) { + param <- lapply(param, function(x) { + x$dimHidden <- round(x$dimToken*x$dimHidden, digits = 0) + return(x) + }) + } + if (hyperParamSearch == "random" && randomSample>length(param)) { stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.", "\n randomSample:", randomSample,"\n Possible hyperparameter combinations:", length(param), diff --git a/extras/example.R b/extras/example.R index fa989b1..49ead98 100644 --- a/extras/example.R +++ b/extras/example.R @@ -16,11 +16,9 @@ populationSet <- PatientLevelPrediction::createStudyPopulationSettings( riskWindowEnd = 365) -modelSettings <- setResNet(numLayers = 2, sizeHidden = 64, hiddenFactor = 1, - residualDropout = 0, hiddenDropout = 0.2, normalization = 'BatchNorm', - activation = 'RelU', sizeEmbedding = 512, weightDecay = 1e-6, - learningRate = 3e-4, seed = 42, hyperParamSearch = 'random', - randomSample = 1, device = 'cuda:0',batchSize = 32,epochs = 10) +modelSettings <- setDefaultResNet(estimatorSettings = setEstimator(epochs=1L, + device='cuda:0', + batchSize=128L)) # modelSettings <- setTransformer(numBlocks=1, dimToken = 33, dimOut = 1, numHeads = 3, # attDropout = 0.2, ffnDropout = 0.2, resDropout = 0, diff --git a/tests/testthat/test-Transformer.R b/tests/testthat/test-Transformer.R index 94c2e48..52139ad 100644 --- a/tests/testthat/test-Transformer.R +++ b/tests/testthat/test-Transformer.R @@ -90,3 +90,17 @@ test_that("Errors are produced by settings function", { expect_error(setTransformer(randomSample = randomSample)) }) + +test_that("dimHidden ratio works as expected", { + randomSample <- 4 + dimToken <- c(64, 128, 256, 512) + dimHiddenRatio <- 2 + modelSettings <- setTransformer(dimToken = dimToken, + dimHiddenRatio = dimHiddenRatio, + dimHidden = NULL, + randomSample = randomSample) + dimHidden <- unlist(lapply(modelSettings$param, function(x) x$dimHidden)) + tokens <- unlist(lapply(modelSettings$param, function(x) x$dimToken)) + expect_true(all(dimHidden == dimHiddenRatio * tokens)) + +})