From bdf8bba57807c99aab2d0c438f472b9ce92e9e24 Mon Sep 17 00:00:00 2001
From: egillax
Date: Fri, 13 Oct 2023 11:11:29 +0200
Subject: [PATCH] optimize tests

---
Notes (ignored by `git am`):

What this patch does:
  * setup.R now builds one shared ResNet `modelSettings`, creates
    `fitEstimatorPath` under `testLoc`, and runs `fitEstimator()` once,
    storing the result in `fitEstimatorResults`.
  * test-Estimator.R drops its own top-level fit and reads
    `fitEstimatorResults` instead; the `device = 1` and
    `scheduler = "notList"` input checks are removed.
  * test-TrainingCache.R replaces both `runPlp()` calls with direct use of
    the `paramPersistence.rds` file written under `fitEstimatorPath`.

Amendments relative to commit bdf8bba (test-TrainingCache.R only):
  * `dir.create(newPath)` is guarded with `dir.exists()`, matching the
    guard this patch adds in setup.R.
  * The "remove last row" comment is reworded to describe what the two
    statements below it do.
  The post-image blob id on that file's `index` line therefore no longer
  matches the amended content.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line structure was rebuilt from the hunk headers; leading indentation
inside the hunks is reconstructed (2-space R style) and should be confirmed
against the target files, or applied with `git apply --ignore-whitespace`.

 tests/testthat/setup.R              | 18 ++++++
 tests/testthat/test-Estimator.R     | 25 ++------
 tests/testthat/test-TrainingCache.R | 95 +++++++----------------------
 3 files changed, 46 insertions(+), 92 deletions(-)

diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R
index adf0dcb..7cd0fee 100644
--- a/tests/testthat/setup.R
+++ b/tests/testthat/setup.R
@@ -78,3 +78,21 @@ dataset <- Dataset$Data(
 )
 small_dataset <- torch$utils$data$Subset(dataset, (1:round(length(dataset)/3)))
 
+modelSettings <- setResNet(
+  numLayers = 1, sizeHidden = 16, hiddenFactor = 1,
+  residualDropout = c(0, 0.2), hiddenDropout = 0,
+  sizeEmbedding = 16, hyperParamSearch = "random",
+  randomSample = 2,
+  setEstimator(epochs=1,
+               learningRate = 3e-4)
+)
+fitEstimatorPath <- file.path(testLoc, 'fitEstimator')
+if (!dir.exists(fitEstimatorPath)) {
+  dir.create(fitEstimatorPath)
+}
+fitEstimatorResults <- fitEstimator(trainData$Train,
+                                    modelSettings = modelSettings,
+                                    analysisId = 1,
+                                    analysisPath = fitEstimatorPath)
+
+
diff --git a/tests/testthat/test-Estimator.R b/tests/testthat/test-Estimator.R
index 0c6708f..b4dd0a4 100644
--- a/tests/testthat/test-Estimator.R
+++ b/tests/testthat/test-Estimator.R
@@ -55,8 +55,6 @@ test_that("Estimator detects wrong inputs", {
   testthat::expect_error(setEstimator(batchSize = "text"))
   testthat::expect_error(setEstimator(epochs = 0))
   testthat::expect_error(setEstimator(epochs = "test"))
-  testthat::expect_error(setEstimator(device = 1))
-  testthat::expect_error(setEstimator(scheduler = "notList"))
   testthat::expect_error(setEstimator(earlyStopping = "notListorNull"))
   testthat::expect_error(setEstimator(metric = 1))
   testthat::expect_error(setEstimator(seed = "32"))
@@ -148,25 +146,12 @@ test_that("early stopping works", {
   testthat::expect_true(earlyStop$early_stop)
 })
 
-modelSettings <- setResNet(
-  numLayers = 1, sizeHidden = 16, hiddenFactor = 1,
-  residualDropout = 0, hiddenDropout = 0,
-  sizeEmbedding = 16, hyperParamSearch = "random",
-  randomSample = 1,
-  setEstimator(epochs=1,
-               learningRate = 3e-4)
-)
-
-sink(nullfile())
-results <- fitEstimator(trainData$Train, modelSettings = modelSettings, analysisId = 1, analysisPath = testLoc)
-sink()
-
 test_that("Estimator fit function works", {
-  expect_true(!is.null(results$trainDetails$trainingTime))
+  expect_true(!is.null(fitEstimatorResults$trainDetails$trainingTime))
 
-  expect_equal(class(results), "plpModel")
-  expect_equal(attr(results, "modelType"), "binary")
-  expect_equal(attr(results, "saveType"), "file")
+  expect_equal(class(fitEstimatorResults), "plpModel")
+  expect_equal(attr(fitEstimatorResults, "modelType"), "binary")
+  expect_equal(attr(fitEstimatorResults, "saveType"), "file")
   fakeTrainData <- trainData
   fakeTrainData$train$covariateData <- list(fakeCovData <- c("Fake"))
   expect_error(fitEstimator(fakeTrainData$train, modelSettings, analysisId = 1, analysisPath = testLoc))
@@ -186,7 +171,7 @@ test_that("predictDeepEstimator works", {
   # input is a plpModel and data
   sink(nullfile())
   predictions <- predictDeepEstimator(
-    plpModel = results, data = trainData$Test,
+    plpModel = fitEstimatorResults, data = trainData$Test,
     trainData$Test$labels
   )
   sink()
diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R
index b577e13..debe95c 100644
--- a/tests/testthat/test-TrainingCache.R
+++ b/tests/testthat/test-TrainingCache.R
@@ -47,84 +47,35 @@ test_that("Param grid predictions can be cached", {
 })
 
 test_that("Estimator can resume training from cache", {
-  modelPath <- tempdir()
-  analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet")
-  dir.create(analysisPath)
-  trainCache <- TrainingCache$new(analysisPath)
-  trainCache$saveModelParams(paramSearch)
+  trainCache <- readRDS(file.path(fitEstimatorPath, "paramPersistence.rds"))
+  newPath <- file.path(testLoc, 'resume')
+  if (!dir.exists(newPath)) {
+    dir.create(newPath)
+  }
+
+  # blank out the cached second grid-search entry: `[[2]] <- NULL` drops it,
+  # then `length<-` pads the list back to two slots with NULL
+  trainCache$gridSearchPredictions[[2]] <- NULL
+  length(trainCache$gridSearchPredictions) <- 2
+
+  # save new cache
+  saveRDS(trainCache, file=file.path(newPath, "paramPersistence.rds"))
 
   sink(nullfile())
-  res2 <- tryCatch(
-    {
-      PatientLevelPrediction::runPlp(
-        plpData = plpData,
-        outcomeId = 3,
-        modelSettings = resNetSettings,
-        analysisId = "Analysis_TrainCacheResNet",
-        analysisName = "Testing Training Cache",
-        populationSettings = populationSet,
-        splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
-        sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
-        featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
-        preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
-        executeSettings = PatientLevelPrediction::createExecuteSettings(
-          runSplitData = T,
-          runSampleData = F,
-          runfeatureEngineering = F,
-          runPreprocessData = T,
-          runModelDevelopment = T,
-          runCovariateSummary = F
-        ),
-        saveDirectory = modelPath
-      )
-    },
-    error = function(e) {
-      print(e)
-      return(NULL)
-    }
-  )
+  fitEstimatorResults <- fitEstimator(trainData$Train,
+                                      modelSettings = modelSettings,
+                                      analysisId = 1,
+                                      analysisPath = newPath)
   sink()
-  trainCache <- TrainingCache$new(analysisPath)
-  testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)
+
+  newCache <- readRDS(file.path(newPath, "paramPersistence.rds"))
+  testthat::expect_equal(nrow(newCache$gridSearchPredictions[[2]]$gridPerformance$hyperSummary), 4)
 })
 
 test_that("Prediction is cached for optimal parameters", {
-  modelPath <- tempdir()
-  analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
-  dir.create(analysisPath)
-
-  sink(nullfile())
-  res2 <- tryCatch(
-    {
-      PatientLevelPrediction::runPlp(
-        plpData = plpData,
-        outcomeId = 3,
-        modelSettings = resNetSettings,
-        analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
-        analysisName = "Testing Training Cache - GridSearch",
-        populationSettings = populationSet,
-        splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
-        sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
-        featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
-        preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
-        executeSettings = PatientLevelPrediction::createExecuteSettings(
-          runSplitData = T,
-          runSampleData = F,
-          runfeatureEngineering = F,
-          runPreprocessData = T,
-          runModelDevelopment = T,
-          runCovariateSummary = F
-        ),
-        saveDirectory = modelPath
-      )
-    },
-    error = function(e) {
-      print(e)
-      return(NULL)
-    }
-  )
-  sink()
-  testCache <- readRDS(file.path(analysisPath, "paramPersistence.rds"))
+  testCache <- readRDS(file.path(fitEstimatorPath, "paramPersistence.rds"))
   indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
+  indexOfMin <- which.min(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
   testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
+  testthat::expect_null(testCache$gridSearchPredictions[[indexOfMin]]$prediction[[1]])
 })