Skip to content

Commit

Permalink
fix param vs modelSettings input to fitFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Aug 17, 2022
1 parent 4874f1f commit 628d07b
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 10 deletions.
5 changes: 3 additions & 2 deletions R/DeepNNTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ setDeepNNTorch <- function(

#' Fits a deep neural network
#' @param trainData Training data object
#' @param param Hyperparameters to search over
#' @param modelSettings modelSettings object
#' @param search Which kind of search strategy to use
#' @param analysisId Analysis Id
#' @export
fitDeepNNTorch <- function(
trainData,
param,
modelSettings,
search='grid',
analysisId)
{
Expand All @@ -80,6 +80,7 @@ fitDeepNNTorch <- function(
stop('DeepNNTorch requires correct covariateData')
}

param <- modelSettings$param
# get the settings from the param
settings <- attr(param, 'settings')

Expand Down
6 changes: 4 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
#' fits a deep learning estimator to data.
#'
#' @param trainData the data to use
#' @param param model parameters
#' @param modelSettings modelSettings object
#' @param analysisId Id of the analysis
#' @param ... Extra inputs
#'
#' @export
fitEstimator <- function(
trainData,
param,
modelSettings,
analysisId,
...
) {
Expand All @@ -39,6 +39,8 @@ fitEstimator <- function(
# check covariate data
if(!FeatureExtraction::isCovariateData(trainData$covariateData)){stop("Needs correct covariateData")}

param <- modelSettings$param

# get the settings from the param
settings <- attr(param, 'settings')
if(!is.null(trainData$folds)){
Expand Down
4 changes: 2 additions & 2 deletions man/fitDeepNNTorch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/fitEstimator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ modelSettings <- setResNet(numLayers=1, sizeHidden=16, hiddenFactor=1,
randomSample = 1, epochs=1)

sink(nullfile())
results <- fitEstimator(trainData$Train, param = modelSettings$param, analysisId = 1)
results <- fitEstimator(trainData$Train, modelSettings = modelSettings, analysisId = 1)
sink()

test_that('Estimator fit function works', {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ test_that('Transformer settings work', {

test_that('fitEstimator with Transformer works', {

results <- fitEstimator(trainData$Train, settings$param, analysisId=1)
results <- fitEstimator(trainData$Train, settings, analysisId=1)

expect_equal(class(results), 'plpModel')
expect_equal(attr(results, 'modelType'), 'binary')
Expand Down

0 comments on commit 628d07b

Please sign in to comment.