Skip to content

Commit

Permalink
add a defaultResNet with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Oct 9, 2022
1 parent 7d21ed7 commit 5356946
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 2 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export(Estimator)
export(fitEstimator)
export(gridCvDeep)
export(predictDeepEstimator)
export(setDefaultResNet)
export(setMultiLayerPerceptron)
export(setResNet)
export(setTransformer)
Expand Down
40 changes: 39 additions & 1 deletion R/ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#' setDefaultResNet
#'
#' @description
#' Creates settings for a default ResNet model
#'
#' @details
#' Model architecture from by https://arxiv.org/abs/2106.11959 .
#' Hyperparameters chosen by a experience on a few prediction problems.
#'
#' @param device Which device to run analysis on, either 'cpu' or 'cuda', default: 'cpu'
#' @param batchSize Size of batch, default: 1024
#' @param epochs Number of epochs to run, default: 10
#' @param seed Random seed to use

#' @export
setDefaultResNet <- function(device='cpu',
batchSize=1024,
epochs=10,
seed=NULL) {

resnetSettings <- setResNet(numLayers = 6,
sizeHidden = 512,
hiddenFactor = 2,
residualDropout = 0.1,
hiddenDropout = 0.4,
sizeEmbedding = 256,
weightDecay = 1e-6,
learningRate = 0.01,
hyperParamSearch = 'random',
randomSample = 1,
device = device,
batchSize = batchSize,
seed = seed)
attr(resnetSettings, 'settings')$name <- 'defaultResnet'
return(resnetSettings)
}


#' setResNet
#'
#' @description
Expand All @@ -42,7 +80,7 @@
#'
#' @export
setResNet <- function(numLayers = c(1:8),
sizeHidden = c(2^(6:9)),
sizeHidden = c(2^(6:10)),
hiddenFactor = c(1:4),
residualDropout = c(seq(0, 0.5, 0.05)),
hiddenDropout = c(seq(0, 0.5, 0.05)),
Expand Down
24 changes: 24 additions & 0 deletions man/setDefaultResNet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/setResNet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/test-ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,16 @@ test_that("ResNet nn-module works ", {
# model works without numeric variables
expect_equal(output$shape, 10)
})

test_that("Default Resnet works", {
defaultResNet <- setDefaultResNet()
params <- defaultResNet$param[[1]]

expect_equal(params$numLayers, 6)
expect_equal(params$sizeHidden, 512)
expect_equal(params$hiddenFactor, 2)
expect_equal(params$residualDropout, 0.1)
expect_equal(params$hiddenDropout, 0.4)
expect_equal(params$sizeEmbedding, 256)

})

0 comments on commit 5356946

Please sign in to comment.