From f93a53f2e4c940d0b18cb08bf2b2c71e547f826d Mon Sep 17 00:00:00 2001 From: Henrik John Date: Fri, 6 Oct 2023 13:21:07 +0200 Subject: [PATCH] Add test to ensure prediction is cached for optimal parameters --- tests/testthat/test-TrainingCache.R | 43 ++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R index eb4ab17..5d639af 100644 --- a/tests/testthat/test-TrainingCache.R +++ b/tests/testthat/test-TrainingCache.R @@ -12,7 +12,7 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4), seed=NULL), hyperParamSearch = "random", randomSample = 3, - randomSampleSeed = NULL) + randomSampleSeed = 123) trainCache <- TrainingCache$new(testLoc) paramSearch <- resNetSettings$param @@ -87,3 +87,44 @@ test_that("Estimator can resume training from cache", { trainCache <- TrainingCache$new(analysisPath) testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE) }) + +test_that("Prediction is cached for optimal parameters", { + modelPath <- tempdir() + analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred") + dir.create(analysisPath) + + sink(nullfile()) + res2 <- tryCatch( + { + PatientLevelPrediction::runPlp( + plpData = plpData, + outcomeId = 3, + modelSettings = resNetSettings, + analysisId = "Analysis_TrainCacheResNet_GridSearchPred", + analysisName = "Testing Training Cache - GridSearch", + populationSettings = populationSet, + splitSettings = PatientLevelPrediction::createDefaultSplitSetting(), + sampleSettings = PatientLevelPrediction::createSampleSettings(), # none + featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none + preprocessSettings = PatientLevelPrediction::createPreprocessSettings(), + executeSettings = PatientLevelPrediction::createExecuteSettings( + runSplitData = T, + runSampleData = F, + runfeatureEngineering = F, + runPreprocessData = T, + runModelDevelopment = T, + runCovariateSummary = F + ), + saveDirectory = modelPath + ) + }, + error = function(e) { + print(e) + return(NULL) + } + ) + sink() + testCache <- readRDS(file.path(analysisPath, "paramPersistence.RDS")) + indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance))) + testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame())) +})