Skip to content

Commit

Permalink
Add test to ensure prediction is cached for optimal parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Oct 6, 2023
1 parent 9f8d23a commit f93a53f
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion tests/testthat/test-TrainingCache.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
seed=NULL),
hyperParamSearch = "random",
randomSample = 3,
randomSampleSeed = NULL)
randomSampleSeed = 123)

trainCache <- TrainingCache$new(testLoc)
paramSearch <- resNetSettings$param
Expand Down Expand Up @@ -87,3 +87,44 @@ test_that("Estimator can resume training from cache", {
trainCache <- TrainingCache$new(analysisPath)
testthat::expect_equal(is.na(trainCache$getLastGridSearchIndex()), TRUE)
})

test_that("Prediction is cached for optimal parameters", {
modelPath <- tempdir()
analysisPath <- file.path(modelPath, "Analysis_TrainCacheResNet_GridSearchPred")
dir.create(analysisPath)

sink(nullfile())
res2 <- tryCatch(
{
PatientLevelPrediction::runPlp(
plpData = plpData,
outcomeId = 3,
modelSettings = resNetSettings,
analysisId = "Analysis_TrainCacheResNet_GridSearchPred",
analysisName = "Testing Training Cache - GridSearch",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
sampleSettings = PatientLevelPrediction::createSampleSettings(), # none
featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(), # none
preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
executeSettings = PatientLevelPrediction::createExecuteSettings(
runSplitData = T,
runSampleData = F,
runfeatureEngineering = F,
runPreprocessData = T,
runModelDevelopment = T,
runCovariateSummary = F
),
saveDirectory = modelPath
)
},
error = function(e) {
print(e)
return(NULL)
}
)
sink()
testCache <- readRDS(file.path(analysisPath, "paramPersistence.RDS"))
indexOfMax <- which.max(unlist(lapply(testCache$gridSearchPredictions, function(x) x$gridPerformance$cvPerformance)))
testthat::expect_equal(class(testCache$gridSearchPredictions[[indexOfMax]]$prediction), class(data.frame()))
})

0 comments on commit f93a53f

Please sign in to comment.