Skip to content

Commit

Permalink
Resolve an issue with hidden dimension ratio (#74)
Browse files Browse the repository at this point in the history
* Resolve an issue with hidden dimension ratio

* Optimize solution

* add test case

---------

Co-authored-by: egillax <[email protected]>
  • Loading branch information
lhjohn and egillax authored Jun 22, 2023
1 parent 506b940 commit 74608ff
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
4 changes: 2 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ fitEstimator <- function(trainData,
hyperSummary <- do.call(rbind, lapply(cvResult$paramGridSearch, function(x) x$hyperSummary))
prediction <- cvResult$prediction
incs <- rep(1, covariateRef %>% dplyr::tally() %>%
dplyr::collect () %>%
as.integer)
dplyr::collect() %>%
as.integer())
covariateRef <- covariateRef %>%
dplyr::collect() %>%
dplyr::mutate(
Expand Down
9 changes: 8 additions & 1 deletion R/Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1,
))
} else {
if (!is.null(dimHiddenRatio)) {
dimHidden <- round(dimToken*dimHiddenRatio, digits = 0)
dimHidden <- dimHiddenRatio
}
}

Expand All @@ -110,6 +110,13 @@ setTransformer <- function(numBlocks = 3, dimToken = 96, dimOut = 1,

param <- PatientLevelPrediction::listCartesian(paramGrid)

if (!is.null(dimHiddenRatio)) {
param <- lapply(param, function(x) {
x$dimHidden <- round(x$dimToken*x$dimHidden, digits = 0)
return(x)
})
}

if (hyperParamSearch == "random" && randomSample>length(param)) {
stop(paste("\n Chosen amount of randomSamples is higher than the amount of possible hyperparameter combinations.",
"\n randomSample:", randomSample,"\n Possible hyperparameter combinations:", length(param),
Expand Down
8 changes: 3 additions & 5 deletions extras/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ populationSet <- PatientLevelPrediction::createStudyPopulationSettings(
riskWindowEnd = 365)


modelSettings <- setResNet(numLayers = 2, sizeHidden = 64, hiddenFactor = 1,
residualDropout = 0, hiddenDropout = 0.2, normalization = 'BatchNorm',
activation = 'RelU', sizeEmbedding = 512, weightDecay = 1e-6,
learningRate = 3e-4, seed = 42, hyperParamSearch = 'random',
randomSample = 1, device = 'cuda:0',batchSize = 32,epochs = 10)
modelSettings <- setDefaultResNet(estimatorSettings = setEstimator(epochs=1L,
device='cuda:0',
batchSize=128L))

# modelSettings <- setTransformer(numBlocks=1, dimToken = 33, dimOut = 1, numHeads = 3,
# attDropout = 0.2, ffnDropout = 0.2, resDropout = 0,
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test-Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,17 @@ test_that("Errors are produced by settings function", {

expect_error(setTransformer(randomSample = randomSample))
})

test_that("dimHidden ratio works as expected", {
randomSample <- 4
dimToken <- c(64, 128, 256, 512)
dimHiddenRatio <- 2
modelSettings <- setTransformer(dimToken = dimToken,
dimHiddenRatio = dimHiddenRatio,
dimHidden = NULL,
randomSample = randomSample)
dimHidden <- unlist(lapply(modelSettings$param, function(x) x$dimHidden))
tokens <- unlist(lapply(modelSettings$param, function(x) x$dimToken))
expect_true(all(dimHidden == dimHiddenRatio * tokens))

})

0 comments on commit 74608ff

Please sign in to comment.