Version 2.0 #81

Merged · 76 commits · Sep 8, 2023
Changes from 2 commits

Commits:
64fce91 speed up data conversion about 3-5x (egillax, Nov 14, 2022)
30e67fd updated news (egillax, Nov 14, 2022)
131c434 make sure tensorList matches dataset length (egillax, Nov 14, 2022)
38504b3 user integers for tensorList (egillax, Nov 14, 2022)
5d89faa test a different way of getting torch binaries (egillax, Dec 15, 2022)
2f8d6d6 test a different way of getting torch binaries (egillax, Dec 15, 2022)
e79a43f fix tidyselect warnings (egillax, Dec 15, 2022)
6c69427 Merge branch 'main' into 43-hotfix-lantern-binaries (egillax, Dec 15, 2022)
887fac9 Update description and news (egillax, Dec 15, 2022)
72f813e remove torch install env variable from actions (egillax, Dec 15, 2022)
d66c47e merge hotfix to develop (egillax, Dec 16, 2022)
ad0ad19 fixed dataset (egillax, Jan 20, 2023)
03ff99b fix numericalIndex and address pull warning (egillax, Jan 20, 2023)
517f675 Allow dimToken and numHeads to take the form of vectors (lhjohn, Jan 26, 2023)
b87b769 Merge pull request #47 from OHDSI/transformer-bug (lhjohn, Jan 26, 2023)
dfca29a Modeltype fix (#48) (egillax, Jan 27, 2023)
65302bd update default ResNet and Transformer to have custom LR and WD (egillax, Feb 13, 2023)
353e0ba Add seed for sampling hyperparameter combinations (#50) (lhjohn, Feb 15, 2023)
926e7d0 Lr find (#51) (egillax, Mar 1, 2023)
e580c2a Derive dimension of feedforward block from embedding dimension (#53) (lhjohn, Mar 5, 2023)
fba85ec Divisible check for Transformer not comprehensive (#55) (lhjohn, Mar 5, 2023)
e50d2e4 Update NEWS.md (egillax, Mar 6, 2023)
b66de7f Merge branch 'main' into develop (egillax, Mar 22, 2023)
051ea88 update website and docs (egillax, Mar 22, 2023)
793338a remove docs folder from code branches (egillax, Mar 22, 2023)
d0cb45b render dev website (egillax, Mar 22, 2023)
609df5b fix action (egillax, Mar 22, 2023)
3330c28 fix action (egillax, Mar 22, 2023)
dcb9423 fix badge in readme (egillax, Mar 22, 2023)
ce5e235 prepare version for release (egillax, Mar 22, 2023)
2d07b6b Update DESCRIPTION (egillax, Mar 23, 2023)
05fecc2 modelType as attribute and tests to cover database upload (egillax, Mar 24, 2023)
32a7e23 modelType as attribute and tests to cover database upload (egillax, Mar 24, 2023)
0a0ac73 Merge branch '59-hotfix-modelType' of https://github.com/OHDSI/DeepPa… (egillax, Mar 24, 2023)
85ce433 Merge branch '59-hotfix-modelType' of https://github.com/OHDSI/DeepPa… (egillax, Mar 24, 2023)
6d433a3 Merge branch '59-hotfix-modelType' of https://github.com/OHDSI/DeepPa… (egillax, Mar 24, 2023)
c5a984f fix dependanceis (egillax, Mar 24, 2023)
7541a38 prepare version and news for release (egillax, Mar 24, 2023)
3b02ce6 merged with hotfix branch (egillax, Mar 24, 2023)
4cd2e30 modelType attribute back to modelSettings functions (egillax, Mar 24, 2023)
a390d4e Merge branch 'develop' of https://github.com/OHDSI/DeepPatientLevelPr… (egillax, Mar 24, 2023)
db13265 Merge branch 'main' into develop (egillax, Mar 24, 2023)
9c19173 Update DESCRIPTION (egillax, Mar 27, 2023)
808ead8 debug actions (egillax, Apr 16, 2023)
8e1de7f Update R_CDM_check_hades.yaml (egillax, Apr 17, 2023)
d18d579 torch install environment variable (egillax, Apr 17, 2023)
55bfca2 Merge branch 'debug-actions' of https://github.com/OHDSI/DeepPatientL… (egillax, Apr 17, 2023)
510f4f1 update version and news (egillax, Apr 17, 2023)
5e09d1d merged with debug-actions (egillax, Apr 17, 2023)
a1cb2e7 add device as expression with tests (#66) (egillax, Apr 18, 2023)
01ce148 Merge branch 'main' into develop (egillax, Apr 18, 2023)
2113a48 remove torchopt (egillax, Apr 18, 2023)
575d2e5 update news and version (egillax, Apr 18, 2023)
20d4a0d fix docs (egillax, Apr 18, 2023)
f97b37f update version number (egillax, Apr 19, 2023)
783f417 LRFinder works with device fun (#68) (egillax, Apr 24, 2023)
0a6f9de update version and news (egillax, Apr 24, 2023)
58df70c Merge branch 'main' into develop (egillax, Apr 24, 2023)
2883b5e update version (egillax, Apr 25, 2023)
8a01ed7 fix bug when test subject has no features (egillax, Jun 18, 2023)
e555710 Add parameter caching for training persistence and continuity (#63) (lhjohn, Jun 18, 2023)
1e640ee fix incs issue (egillax, Jun 18, 2023)
6326cd6 Merge branch 'develop' of https://github.com/OHDSI/DeepPatientLevelPr… (egillax, Jun 18, 2023)
5d9dc59 Release version and news updated (egillax, Jun 18, 2023)
af02541 Merge branch 'main' into develop (egillax, Jun 18, 2023)
506b940 Release and NEWS (egillax, Jun 18, 2023)
74608ff Resolve an issue with hidden dimension ratio (#74) (lhjohn, Jun 22, 2023)
ba60c28 Cache single hyperparameter combination (#78) (lhjohn, Jul 20, 2023)
bd9b357 Change backend to pytorch (#80) (egillax, Aug 28, 2023)
04c2b36 Merge branch 'main' into develop (egillax, Aug 28, 2023)
e501124 fix dataset (egillax, Aug 28, 2023)
31ef832 update PLP version in DESCRIPTION (egillax, Sep 7, 2023)
66b8c84 integer handling in python and input checks (#83) (egillax, Sep 7, 2023)
85be689 Ensure that param search is completed in empty cache test (#84) (lhjohn, Sep 7, 2023)
4c83897 use ubuntu 22.04 in CI (#85) (egillax, Sep 7, 2023)
904e926 Update NEWS.md (egillax, Sep 8, 2023)
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -37,7 +37,7 @@ Suggests:
   ResultModelManager (>= 0.2.0),
   DatabaseConnector (>= 6.0.0)
 Remotes:
-  ohdsi/PatientLevelPrediction,
+  ohdsi/PatientLevelPrediction@develop,
   ohdsi/FeatureExtraction,
   ohdsi/Eunomia,
   ohdsi/ResultModelManager
1 change: 1 addition & 0 deletions NAMESPACE
@@ -2,6 +2,7 @@

 export(Dataset)
 export(Estimator)
+export(TrainingCache)
 export(fitEstimator)
 export(gridCvDeep)
 export(lrFinder)
27 changes: 22 additions & 5 deletions R/Estimator.R
@@ -97,12 +97,14 @@ setEstimator <- function(learningRate='auto',
 #' @param trainData the data to use
 #' @param modelSettings modelSettings object
 #' @param analysisId Id of the analysis
+#' @param analysisPath Path of the analysis
 #' @param ... Extra inputs
 #'
 #' @export
 fitEstimator <- function(trainData,
                          modelSettings,
                          analysisId,
+                         analysisPath,
                          ...) {
   start <- Sys.time()

@@ -128,7 +130,8 @@ fitEstimator <- function(trainData,
       mappedData = mappedCovariateData,
       labels = trainData$labels,
       modelSettings = modelSettings,
-      modelLocation = outLoc
+      modelLocation = outLoc,
+      analysisPath = analysisPath
     )
   )

@@ -251,26 +254,37 @@ predictDeepEstimator <- function(plpModel,
 #' @param labels Dataframe with the outcomes
 #' @param modelSettings Settings of the model
 #' @param modelLocation Where to save the model
+#' @param analysisPath Path of the analysis
 #'
 #' @export
 gridCvDeep <- function(mappedData,
                        labels,
                        modelSettings,
-                       modelLocation) {
+                       modelLocation,
+                       analysisPath) {
   ParallelLogger::logInfo(paste0("Running hyperparameter search for ", modelSettings$modelType, " model"))

   ###########################################################################

   paramSearch <- modelSettings$param
-  gridSearchPredictons <- list()
-  length(gridSearchPredictons) <- length(paramSearch)
+  trainCache <- TrainingCache$new(analysisPath)
+
+  if (trainCache$isParamGridIdentical(paramSearch)) {
+    gridSearchPredictons <- trainCache$getGridSearchPredictions()
+  } else {
+    gridSearchPredictons <- list()
+    length(gridSearchPredictons) <- length(paramSearch)
+    trainCache$saveGridSearchPredictions(gridSearchPredictons)
+    trainCache$saveModelParams(paramSearch)
+  }

   dataset <- Dataset(mappedData$covariates, labels$outcomeCount)

   estimatorSettings <- modelSettings$estimatorSettings

   fitParams <- names(paramSearch[[1]])[grepl("^estimator", names(paramSearch[[1]]))]

-  for (gridId in 1:length(paramSearch)) {
+  for (gridId in trainCache$getLastGridSearchIndex():length(paramSearch)) {
     ParallelLogger::logInfo(paste0("Running hyperparameter combination no ", gridId))
     ParallelLogger::logInfo(paste0("HyperParameters: "))
     ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]), paramSearch[[gridId]], collapse = " | "))
@@ -336,7 +350,10 @@ gridCvDeep <- function(mappedData,
       prediction = prediction,
       param = paramSearch[[gridId]]
     )
+
+    trainCache$saveGridSearchPredictions(gridSearchPredictons)
   }

   # get best para (this could be modified to enable any metric instead of AUC, just need metric input in function)
   paramGridSearch <- lapply(gridSearchPredictons, function(x) {
     do.call(PatientLevelPrediction::computeGridPerformance, x)
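The change above saves the grid-search state to disk after every hyperparameter combination, so an interrupted run can resume instead of starting over. A minimal, self-contained sketch of that save/restore cycle (the file name `paramPersistence.rds` matches the one used by `TrainingCache`; the state contents here are made up for illustration):

```r
# Sketch of the persistence cycle behind the training cache: state is
# serialized to an .rds file after each update, and a later (restarted)
# session restores it with readRDS. Values below are illustrative only.
stateFile <- file.path(tempdir(), "paramPersistence.rds")

state <- list(
  gridSearchPredictions = list(list(auc = 0.71), NULL, NULL),
  modelParams = list(list(lr = 1e-3), list(lr = 1e-4), list(lr = 1e-5))
)
saveRDS(state, stateFile)

# A restarted run reads the file back and sees the identical state:
# combination 1 is done, combinations 2-3 still need to run
restored <- readRDS(stateFile)
stopifnot(identical(restored, state))
```

Saving after each combination (rather than once at the end) is what bounds the lost work to a single grid point.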
90 changes: 90 additions & 0 deletions R/TrainingCache-class.R
@@ -0,0 +1,90 @@
+#' TrainingCache
+#' @description
+#' Parameter caching for training persistence and continuity
+#' @export
+TrainingCache <- R6::R6Class(
+  "TrainingCache",
+
+  private = list(
+    .paramPersistence = list(
+      gridSearchPredictions = NULL,
+      modelParams = NULL
+    ),
+    .paramContinuity = list(),
+    .saveDir = NULL,
+
+    writeToFile = function() {
+      saveRDS(private$.paramPersistence, file.path(private$.saveDir))
+    },
+
+    readFromFile = function() {
+      private$.paramPersistence <- readRDS(file.path(private$.saveDir))
+    }
+  ),
+
+  public = list(
+    #' @description
+    #' Creates a new training cache
+    #' @param inDir Path to the analysis directory
+    initialize = function(inDir) {
+      private$.saveDir <- file.path(inDir, "paramPersistence.rds")
+
+      if (file.exists(private$.saveDir)) {
+        private$readFromFile()
+      } else {
+        private$writeToFile()
+      }
+    },
+
+    #' @description
+    #' Checks whether the parameter grid in the model settings is identical to
+    #' the cached parameters.
+    #' @param inModelParams Parameter grid from the model settings
+    #' @returns Whether the provided and cached parameter grid is identical
+    isParamGridIdentical = function(inModelParams) {
+      return(identical(inModelParams, private$.paramPersistence$modelParams))
+    },
+
+    #' @description
+    #' Saves the grid search results to the training cache
+    #' @param inGridSearchPredictions Grid search predictions
+    saveGridSearchPredictions = function(inGridSearchPredictions) {
+      private$.paramPersistence$gridSearchPredictions <-
+        inGridSearchPredictions
+      private$writeToFile()
+    },
+
+    #' @description
+    #' Saves the parameter grid to the training cache
+    #' @param inModelParams Parameter grid from the model settings
+    saveModelParams = function(inModelParams) {
+      private$.paramPersistence$modelParams <- inModelParams
+      private$writeToFile()
+    },
+
+    #' @description
+    #' Gets the grid search results from the training cache
+    #' @returns Grid search results from the training cache
+    getGridSearchPredictions = function() {
+      return(private$.paramPersistence$gridSearchPredictions)
+    },
+
+    #' @description
+    #' Gets the last index from the cached grid search
+    #' @returns Last grid search index
+    getLastGridSearchIndex = function() {
+      if (is.null(private$.paramPersistence$gridSearchPredictions)) {
+        return(1)
+      } else {
+        return(which(sapply(private$.paramPersistence$gridSearchPredictions,
+                            is.null))[1])
+      }
+    },
+
+    #' @description
+    #' Remove the training cache from the analysis path
+    dropCache = function() {
+      # TODO
+    }
+  )
+)
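The resume logic in `getLastGridSearchIndex` relies on unfinished grid points being `NULL` in the cached prediction list: the first `NULL` slot is where training picks up. A small sketch of that expression with made-up data:

```r
# Cached predictions after an interrupted run: combinations 1-2 finished,
# combinations 3-4 did not, so their slots are still NULL
preds <- list(list(auc = 0.70), list(auc = 0.72), NULL, NULL)

# Same expression as in getLastGridSearchIndex: index of the first NULL slot
nextIndex <- which(sapply(preds, is.null))[1]
nextIndex  # 3: training resumes at the third combination

# Edge case: when no slot is NULL (the search already completed),
# which(...) returns integer(0) and indexing it yields NA, which a caller
# must handle before using the value as a loop start
allDone <- which(sapply(list(list(auc = 0.70)), is.null))[1]
is.na(allDone)  # TRUE
```

This `NA` edge case appears to be what the later commit "Ensure that param search is completed in empty cache test (#84)" exercises.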
145 changes: 145 additions & 0 deletions man/TrainingCache.Rd
4 changes: 3 additions & 1 deletion man/fitEstimator.Rd
4 changes: 3 additions & 1 deletion man/gridCvDeep.Rd

These generated .Rd files are not rendered by default.

4 changes: 2 additions & 2 deletions tests/testthat/test-Estimator.R
@@ -138,7 +138,7 @@ modelSettings <- setResNet(
 )

 sink(nullfile())
-results <- fitEstimator(trainData$Train, modelSettings = modelSettings, analysisId = 1)
+results <- fitEstimator(trainData$Train, modelSettings = modelSettings, analysisId = 1, analysisPath = testLoc)
 sink()

 test_that("Estimator fit function works", {
@@ -149,7 +149,7 @@ test_that("Estimator fit function works", {
   expect_equal(attr(results, "saveType"), "file")
   fakeTrainData <- trainData
   fakeTrainData$train$covariateData <- list(fakeCovData <- c("Fake"))
-  expect_error(fitEstimator(fakeTrainData$train, modelSettings, analysisId = 1))
+  expect_error(fitEstimator(fakeTrainData$train, modelSettings, analysisId = 1, analysisPath = testLoc))
 })

 test_that("predictDeepEstimator works", {