Skip to content

Commit

Permalink
initial transfer learning changes
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Dec 4, 2023
1 parent 02a3535 commit 5836121
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 31 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export(predictDeepEstimator)
export(setDefaultResNet)
export(setDefaultTransformer)
export(setEstimator)
export(setFinetuner)
export(setMultiLayerPerceptron)
export(setResNet)
export(setTransformer)
Expand Down
12 changes: 10 additions & 2 deletions R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,18 @@ createDataset <- function(data, labels, plpModel = NULL) {
# sqlite object
attributes(data)$path <- attributes(data)$dbname
}
if (is.null(plpModel)) {
if (is.null(plpModel) && is.null(data$numericalIndex)) {
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount))
} else {
}
else if (!is.null(data$numericalIndex)) {
numericalIndex <-
r_to_py(as.array(data$numericalIndex %>% dplyr::pull()))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount),
numericalIndex)
}
else {
numericalFeatures <-
r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
Expand Down
39 changes: 33 additions & 6 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,26 @@ fitEstimator <- function(trainData,
if (!is.null(trainData$folds)) {
trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
}
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels
)

if (modelSettings$modelType == "Finetuner") {
# make sure to use same mapping from covariateIds to columns if finetuning
path <- modelSettings$param[[1]]$modelPath
oldCovImportance <- utils::read.csv(file.path(path,
"covariateImportance.csv"))
mapping <- oldCovImportance %>% dplyr::select("columnId", "covariateId")
numericalIndex <- which(oldCovImportance %>% dplyr::pull("isNumeric"))
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels,
mapping = mapping
)
mappedCovariateData$numericalIndex <- as.data.frame(numericalIndex)
} else {
mappedCovariateData <- PatientLevelPrediction::MapIds(
covariateData = trainData$covariateData,
cohort = trainData$labels
)
}

covariateRef <- mappedCovariateData$covariateRef

Expand Down Expand Up @@ -322,7 +338,7 @@ gridCvDeep <- function(mappedData,
ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]),
paramSearch[[gridId]], collapse = " | "))
currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]

attr(currentModelParams, "metaData")$names <- modelSettings$modelParamNames
currentEstimatorSettings <-
fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
Expand Down Expand Up @@ -480,7 +496,18 @@ createEstimator <- function(modelType,
modelParameters,
estimatorSettings) {
path <- system.file("python", package = "DeepPatientLevelPrediction")


if (modelType == "Finetuner") {
estimatorSettings$finetune <- TRUE
plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
estimatorSettings$finetuneModelPath <-
file.path(normalizePath(plpModel$model), "DeepEstimatorModel.pt")
modelType <- plpModel$modelDesign$modelSettings$modelType
oldModelParameters <- modelParameters
modelParameters <-
plpModel$trainDetails$finalModelParameters[plpModel$modelDesign$modelSettings$modelParamNames]
}

model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator

Expand Down
15 changes: 9 additions & 6 deletions R/LRFinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ createLRFinder <- function(modelType,
path <- system.file("python", package = "DeepPatientLevelPrediction")
lrFinderClass <-
reticulate::import_from_path("LrFinder", path = path)$LrFinder



model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
estimatorSettings <- evalEstimatorSettings(estimatorSettings)

browser()
estimator <- createEstimator(modelType = estimatorSettings$modelType,
modelParameters = modelParameters,
estimatorSettings = estimatorSettings)
if (!is.null(lrSettings)) {
lrSettings <- camelCaseToSnakeCaseNames(lrSettings)
}

estimatorSettings <- evalEstimatorSettings(estimatorSettings)


lrFinder <- lrFinderClass(model = model,
model_parameters = modelParameters,
estimator_settings = estimatorSettings,
Expand Down
63 changes: 63 additions & 0 deletions R/TransferLearning.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# @file TransferLearning.R
#
# Copyright 2023 Observational Health Data Sciences and Informatics
#
# This file is part of DeepPatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' setFinetuner
#'
#' @description
#' creates settings for using transfer learning to finetune a model
#'
#' @name setFinetuner
#' @param modelPath path to existing plpModel directory
#' @param estimatorSettings settings created with `setEstimator`
#' @export
setFinetuner <- function(modelPath,
estimatorSettings =
setEstimator(learningRate = learningRate,
weightDecay = weightDecay,
batchSize = batchSize,
epochs = epochs,
device = device,
optimizer = optimizer,
scheduler = scheduler,
criterion = criterion,
earlyStopping = earlyStopping,
metric = metric,
seed = seed)
) {

if (!dir.exists(modelPath)) {
stop(paste0("supplied modelPath does not exist, you supplied: modelPath = ",
modelPath))
}
param <- list()
param[[1]] <- list(modelPath = modelPath)

attr(param, "settings")$modelType <- "FineTuner"
results <- list(
fitFunction = "fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "Finetuner",
saveType = "file",
modelParamNames = c("modelPath")
)

class(results) <- "modelSettings"

return(results)
}
20 changes: 3 additions & 17 deletions inst/python/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,10 @@ def __init__(self, data, labels=None, numerical_features=None):
self.target = torch.zeros(size=(observations,))

# filter by categorical columns,
# sort and group_by columnId
# create newColumnId from 1 (or zero?) until # catColumns
# select rowId and newColumnId
# rename newColumnId to columnId and sort by it
# select rowId and columnId
data_cat = (
data.filter(~pl.col("columnId").is_in(self.numerical_features))
.sort(by="columnId")
.with_row_count("newColumnId")
.with_columns(
pl.col("newColumnId").first().over("columnId").rank(method="dense")
)
.select(pl.col("rowId"), pl.col("newColumnId").alias("columnId"))
.select(pl.col("rowId"), pl.col("columnId"))
.sort("rowId")
.with_columns(pl.col("rowId") - 1)
.collect()
Expand Down Expand Up @@ -80,15 +72,9 @@ def __init__(self, data, labels=None, numerical_features=None):
numerical_data = (
data.filter(pl.col("columnId").is_in(self.numerical_features))
.sort(by="columnId")
.with_row_count("newColumnId")
.with_columns(
pl.col("newColumnId").first().over("columnId").rank(method="dense")
- 1,
pl.col("rowId") - 1,
)
.select(
pl.col("rowId"),
pl.col("newColumnId").alias("columnId"),
pl.col("columnId"),
pl.col("covariateValue"),
)
.collect()
Expand Down
22 changes: 22 additions & 0 deletions man/setFinetuner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions tests/testthat/test-Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,9 @@ test_that(".getbatch works", {

expect_equal(out[[2]]$shape[0], 16)
})

test_that("Column order is preserved in presence of missing features", {
# important for both external validation and transfer learning


})

0 comments on commit 5836121

Please sign in to comment.