
Commit: optimize tests
egillax committed Oct 13, 2023
1 parent def727f commit bdf8bba
Showing 3 changed files with 43 additions and 92 deletions.
18 changes: 18 additions & 0 deletions tests/testthat/setup.R
@@ -78,3 +78,21 @@ dataset <- Dataset$Data(
)
small_dataset <- torch$utils$data$Subset(dataset, (1:round(length(dataset)/3)))

modelSettings <- setResNet(
numLayers = 1, sizeHidden = 16, hiddenFactor = 1,
residualDropout = c(0, 0.2), hiddenDropout = 0,
sizeEmbedding = 16, hyperParamSearch = "random",
randomSample = 2,
setEstimator(epochs=1,
learningRate = 3e-4)
)
fitEstimatorPath <- file.path(testLoc, 'fitEstimator')
if (!dir.exists(fitEstimatorPath)) {
dir.create(fitEstimatorPath)
}
fitEstimatorResults <- fitEstimator(trainData$Train,
modelSettings = modelSettings,
analysisId = 1,
analysisPath = fitEstimatorPath)


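Note: testthat sources tests/testthat/setup.R before any test file runs, so the objects created above (small_dataset, modelSettings, fitEstimatorPath, fitEstimatorResults) become shared fixtures that the test files below can reference directly instead of refitting a model per file. A minimal sketch of that pattern, using illustrative names that are not part of this commit:

library(testthat)

# tests/testthat/setup.R -- sourced once before the test files run
sharedFixture <- list(trainDetails = list(trainingTime = 1.5))  # stand-in for an expensive fitted model

# tests/testthat/test-example.R -- any test file can use the fixture directly
test_that("fixture created in setup.R is visible to test files", {
  expect_true(!is.null(sharedFixture$trainDetails$trainingTime))
})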
25 changes: 5 additions & 20 deletions tests/testthat/test-Estimator.R
@@ -55,8 +55,6 @@ test_that("Estimator detects wrong inputs", {
testthat::expect_error(setEstimator(batchSize = "text"))
testthat::expect_error(setEstimator(epochs = 0))
testthat::expect_error(setEstimator(epochs = "test"))
testthat::expect_error(setEstimator(device = 1))
testthat::expect_error(setEstimator(scheduler = "notList"))
testthat::expect_error(setEstimator(earlyStopping = "notListorNull"))
testthat::expect_error(setEstimator(metric = 1))
testthat::expect_error(setEstimator(seed = "32"))
@@ -148,25 +146,12 @@ test_that("early stopping works", {
testthat::expect_true(earlyStop$early_stop)
})

modelSettings <- setResNet(
numLayers = 1, sizeHidden = 16, hiddenFactor = 1,
residualDropout = 0, hiddenDropout = 0,
sizeEmbedding = 16, hyperParamSearch = "random",
randomSample = 1,
setEstimator(epochs=1,
learningRate = 3e-4)
)

sink(nullfile())
results <- fitEstimator(trainData$Train, modelSettings = modelSettings, analysisId = 1, analysisPath = testLoc)
sink()

test_that("Estimator fit function works", {
expect_true(!is.null(results$trainDetails$trainingTime))
expect_true(!is.null(fitEstimatorResults$trainDetails$trainingTime))

expect_equal(class(results), "plpModel")
expect_equal(attr(results, "modelType"), "binary")
expect_equal(attr(results, "saveType"), "file")
expect_equal(class(fitEstimatorResults), "plpModel")
expect_equal(attr(fitEstimatorResults, "modelType"), "binary")
expect_equal(attr(fitEstimatorResults, "saveType"), "file")
fakeTrainData <- trainData
fakeTrainData$train$covariateData <- list(fakeCovData <- c("Fake"))
expect_error(fitEstimator(fakeTrainData$train, modelSettings, analysisId = 1, analysisPath = testLoc))
@@ -186,7 +171,7 @@ test_that("predictDeepEstimator works", {
# input is a plpModel and data
sink(nullfile())
predictions <- predictDeepEstimator(
plpModel = results, data = trainData$Test,
plpModel = fitEstimatorResults, data = trainData$Test,
trainData$Test$labels
)
sink()
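The sink(nullfile())/sink() pair wrapped around the fit and predict calls is plain base R (nullfile() requires R >= 3.6.0): it redirects console output to the null device so verbose calls stay quiet during tests. A small self-contained sketch, where the noisy call is only a placeholder:

sink(nullfile())   # send printed output to the null device
print("noisy output that would otherwise clutter the test log")  # placeholder for a verbose call
sink()             # restore normal console output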
92 changes: 20 additions & 72 deletions tests/testthat/test-TrainingCache.R
@@ -47,84 +47,32 @@ test_that("Param grid predictions can be cached", {
})

test_that("Estimator can resume training from cache", {
modelPath <- tempdir()
analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet")
dir.create(analysisPath)
trainCache <- TrainingCache$new(analysisPath)
trainCache$saveModelParams(paramSearch)
trainCache <- readRDS(file.path(fitEstimatorPath, "paramPersistence.rds"))
newPath <- file.path(testLoc, 'resume')
dir.create(newPath)

# remove last row
trainCache$gridSearchPredictions[[2]] <- NULL
length(trainCache$gridSearchPredictions) <- 2

# save new cache
saveRDS(trainCache, file=file.path(newPath, "paramPersistence.rds"))

sink(nullfile())
res2 <- tryCatch(
{
PatientLevelPrediction::runPlp(
plpData = plpData,
outcomeId = 3,
modelSettings = resNetSettings,
analysisId = "Analysis_TrainCacheResNet",
analysisName = "Testing Training Cache",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
executeSettings = PatientLevelPrediction::createExecuteSettings(
runSplitData = T,
runSampleData = F,
runfeatureEngineering = F,
runPreprocessData = T,
runModelDevelopment = T,
runCovariateSummary = F
),
saveDirectory = modelPath
)
},
error = function(e) {
print(e)
return(NULL)
}
)
fitEstimatorResults <- fitEstimator(trainData$Train,
modelSettings = modelSettings,
analysisId = 1,
analysisPath = newPath)
sink()
trainCache <- TrainingCache$new(analysisPath)
testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)

newCache <- readRDS(file.path(newPath, "paramPersistence.rds"))
testthat::expect_equal(nrow(newCache$gridSearchPredictions[[2]]$gridPerformance$hyperSummary), 4)
})
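The two lines under "# remove last row" rely on base R list semantics: assigning NULL to an element deletes it and shrinks the list, and assigning the original length back re-creates the trailing slot holding NULL, which a cache can treat as "not yet computed". A minimal illustration of that behaviour, independent of the TrainingCache class:

x <- list(first = "done", second = "done")
x[[2]] <- NULL     # removes the element; length(x) is now 1
length(x) <- 2     # restores the slot count; x[[2]] is NULL
stopifnot(length(x) == 2, is.null(x[[2]]))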

test_that("Prediction is cached for optimal parameters", {
modelPath <- tempdir()
analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
dir.create(analysisPath)

sink(nullfile())
res2 <- tryCatch(
{
PatientLevelPrediction::runPlp(
plpData = plpData,
outcomeId = 3,
modelSettings = resNetSettings,
analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
analysisName = "Testing Training Cache - GridSearch",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
executeSettings = PatientLevelPrediction::createExecuteSettings(
runSplitData = T,
runSampleData = F,
runfeatureEngineering = F,
runPreprocessData = T,
runModelDevelopment = T,
runCovariateSummary = F
),
saveDirectory = modelPath
)
},
error = function(e) {
print(e)
return(NULL)
}
)
sink()
testCache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
testCache <- readRDS(file.path(fitEstimatorPath, "paramPersistence.rds"))
indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
indexOfMin <- which.min(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
testthat::expect_null(testCache$gridSearchPredictions[[indexOfMin]]$prediction[[1]])
})
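For reference, the index lookup above only assumes that gridSearchPredictions is a list with one entry per hyperparameter combination, each carrying gridPerformance$cvPerformance and, for the best-performing entry only, a data.frame of predictions. A self-contained sketch with made-up values:

gridSearchPredictions <- list(
  list(gridPerformance = list(cvPerformance = 0.71),
       prediction = data.frame(rowId = 1:3, value = c(0.2, 0.5, 0.9))),
  list(gridPerformance = list(cvPerformance = 0.64),
       prediction = list(NULL))
)
cvPerformance <- unlist(lapply(gridSearchPredictions, function(x) x$gridPerformance$cvPerformance))
indexOfMax <- which.max(cvPerformance)   # 1: only this entry keeps its full prediction
indexOfMin <- which.min(cvPerformance)   # 2: its prediction slot is expected to be empty
stopifnot(is.data.frame(gridSearchPredictions[[indexOfMax]]$prediction),
          is.null(gridSearchPredictions[[indexOfMin]]$prediction[[1]]))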
