Transfer learning first prototype working
egillax committed Dec 6, 2023
1 parent 1170c83 commit 8fa3711
Showing 19 changed files with 216 additions and 150 deletions.
54 changes: 26 additions & 28 deletions R/Estimator.R
@@ -54,7 +54,8 @@ setEstimator <- function(learningRate = "auto",
earlyStopping = list(useEarlyStopping = TRUE,
params = list(patience = 4)),
metric = "auc",
seed = NULL
seed = NULL,
modelType = NULL
) {

checkIsClass(learningRate, c("numeric", "character"))
@@ -89,7 +90,8 @@ setEstimator <- function(learningRate = "auto",
earlyStopping = earlyStopping,
findLR = findLR,
metric = metric,
seed = seed[1])
seed = seed[1],
modelType = modelType)

optimizer <- rlang::enquo(optimizer)
estimatorSettings$optimizer <- function() rlang::eval_tidy(optimizer)
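
With modelType now an argument of setEstimator, the model family travels inside estimatorSettings rather than as a separate field on the settings list. A minimal sketch of the new call shape (values illustrative; modelType is normally filled in by the set<Model> functions rather than by hand):

estimatorSettings <- setEstimator(learningRate = "auto",
                                  batchSize = 512,
                                  epochs = 10,
                                  device = "cpu",
                                  modelType = NULL)  # set later by e.g. setResNet()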
@@ -143,7 +145,7 @@ fitEstimator <- function(trainData,
trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
}

if (modelSettings$modelType == "Finetuner") {
if (modelSettings$estimatorSettings$modelType == "Finetuner") {
# make sure to use same mapping from covariateIds to columns if finetuning
path <- modelSettings$param[[1]]$modelPath
oldCovImportance <- utils::read.csv(file.path(path,
@@ -227,7 +229,7 @@ fitEstimator <- function(trainData,
attrition = attr(trainData, "metaData")$attrition,
trainingTime = paste(as.character(abs(comp)), attr(comp, "units")),
trainingDate = Sys.Date(),
modelName = modelSettings$modelType,
modelName = modelSettings$estimatorSettings$modelType,
finalModelParameters = cvResult$finalParam,
hyperParamSearch = hyperSummary
),
@@ -271,13 +273,14 @@ predictDeepEstimator <- function(plpModel,
# get predictions
prediction <- cohort
if (is.character(plpModel$model)) {
modelSettings <- plpModel$modelDesign$modelSettings
model <- torch$load(file.path(plpModel$model,
"DeepEstimatorModel.pt"),
map_location = "cpu")
estimator <- createEstimator(modelType = modelSettings$modelType,
modelParameters = model$model_parameters,
estimatorSettings = model$estimator_settings)
estimator <-
createEstimator(modelParameters =
snakeCaseToCamelCaseNames(model$model_parameters),
estimatorSettings =
snakeCaseToCamelCaseNames(model$estimator_settings))
estimator$model$load_state_dict(model$model_state_dict)
prediction$value <- estimator$predict_proba(data)
} else {
@@ -308,7 +311,8 @@ gridCvDeep <- function(mappedData,
modelLocation,
analysisPath) {
ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
modelSettings$modelType, " model"))
modelSettings$estimatorSettings$modelType,
" model"))

###########################################################################

@@ -343,15 +347,12 @@ gridCvDeep <- function(mappedData,
fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
paramSearch[[gridId]])
currentEstimatorSettings$modelType <- modelSettings$modelType
currentModelParams$catFeatures <- dataset$get_cat_features()$max()
currentModelParams$numFeatures <-
dataset$get_numerical_features()$max()
dataset$get_numerical_features()$len()
if (findLR) {
lrFinder <- createLRFinder(modelType = modelSettings$modelType,
modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings
)
lrFinder <- createLRFinder(modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings)
lr <- lrFinder$get_lr(dataset)
ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr))
currentEstimatorSettings$learningRate <- lr
@@ -418,15 +419,14 @@ gridCvDeep <- function(mappedData,
}

modelParams$catFeatures <- dataset$get_cat_features()$max()
modelParams$numFeatures <- dataset$get_numerical_features()$max()
modelParams$numFeatures <- dataset$get_numerical_features()$len()


estimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
finalParam)
estimatorSettings$learningRate <- finalParam$learnSchedule$LRs[[1]]
estimator <- createEstimator(modelType = modelSettings$modelType,
modelParameters = modelParams,
estimator <- createEstimator(modelParameters = modelParams,
estimatorSettings = estimatorSettings)

numericalIndex <- dataset$get_numerical_features()
@@ -492,23 +492,22 @@ evalEstimatorSettings <- function(estimatorSettings) {
estimatorSettings
}

createEstimator <- function(modelType,
modelParameters,
createEstimator <- function(modelParameters,
estimatorSettings) {
path <- system.file("python", package = "DeepPatientLevelPrediction")

if (modelType == "Finetuner") {
if (estimatorSettings$modelType == "Finetuner") {
estimatorSettings$finetune <- TRUE
plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
estimatorSettings$finetuneModelPath <-
file.path(normalizePath(plpModel$model), "DeepEstimatorModel.pt")
modelType <- plpModel$modelDesign$modelSettings$modelType
oldModelParameters <- modelParameters
modelParameters <-
plpModel$trainDetails$finalModelParameters[plpModel$modelDesign$modelSettings$modelParamNames]
estimatorSettings$modelType <-
plpModel$modelDesign$modelSettings$estimatorSettings$modelType
}

model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
model <-
reticulate::import_from_path(estimatorSettings$modelType,
path = path)[[estimatorSettings$modelType]]
estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator

modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
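
Since createEstimator now reads the model class from estimatorSettings$modelType, callers drop the separate modelType argument. A hedged sketch with illustrative hyperparameters (catFeatures and numFeatures are normally derived from the dataset, as in gridCvDeep above):

estimatorSettings <- setEstimator(learningRate = 3e-4)
estimatorSettings$modelType <- "ResNet"  # normally set by setResNet()
modelParameters <- list(catFeatures = 100L, numFeatures = 2L,
                        numLayers = 2, sizeHidden = 64,
                        hiddenFactor = 1, residualDropout = 0.1,
                        hiddenDropout = 0.1, sizeEmbedding = 16)
estimator <- createEstimator(modelParameters = modelParameters,
                             estimatorSettings = estimatorSettings)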
@@ -541,8 +540,7 @@ doCrossvalidation <- function(dataset,
testDataset <- torch$utils$data$Subset(dataset,
indices =
as.integer(which(fold == i) - 1))
estimator <- createEstimator(modelType = estimatorSettings$modelType,
modelParameters = modelSettings,
estimator <- createEstimator(modelParameters = modelSettings,
estimatorSettings = estimatorSettings)
estimator$fit(trainDataset, testDataset)

27 changes: 27 additions & 0 deletions R/HelperFunctions.R
@@ -29,6 +29,33 @@ camelCaseToSnakeCase <- function(string) {
return(string)
}

#' Convert a snake case string to camel case
#'
#' @param string The string to be converted
#'
#' @return
#' A string
#'
snakeCaseToCamelCase <- function(string) {
string <- tolower(string)
for (letter in letters) {
string <- gsub(paste("_", letter, sep = ""), toupper(letter), string)
}
string <- gsub("_([0-9])", "\\1", string)
return(string)
}

#' Convert the names of an object from snake case to camel case
#'
#' @param object The object of which the names should be converted
#'
#' @return
#' The same object, but with converted names.
snakeCaseToCamelCaseNames <- function(object) {
names(object) <- snakeCaseToCamelCase(names(object))
return(object)
}
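
A quick illustration of the two new helpers, the inverse of the existing camelCaseToSnakeCase pair:

snakeCaseToCamelCase("model_state_dict")           # "modelStateDict"
snakeCaseToCamelCaseNames(list(batch_size = 512))  # list(batchSize = 512)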

#' Convert the names of an object from camel case to snake case
#'
#' @param object The object of which the names should be converted
18 changes: 3 additions & 15 deletions R/LRFinder.R
@@ -15,32 +15,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
createLRFinder <- function(modelType,
modelParameters,
createLRFinder <- function(modelParameters,
estimatorSettings,
lrSettings = NULL) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
lrFinderClass <-
reticulate::import_from_path("LrFinder", path = path)$LrFinder

estimatorSettings <- evalEstimatorSettings(estimatorSettings)

model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
estimatorSettings <- evalEstimatorSettings(estimatorSettings)
browser()
estimator <- createEstimator(modelType = estimatorSettings$modelType,
modelParameters = modelParameters,
estimator <- createEstimator(modelParameters = modelParameters,
estimatorSettings = estimatorSettings)

if (!is.null(lrSettings)) {
lrSettings <- camelCaseToSnakeCaseNames(lrSettings)
}

lrFinder <- lrFinderClass(model = model,
model_parameters = modelParameters,
estimator_settings = estimatorSettings,
lrFinder <- lrFinderClass(estimator = estimator,
lr_settings = lrSettings)

return(lrFinder)
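
Because the estimator is now built inside createLRFinder, the caller passes only parameters and settings, mirroring the gridCvDeep call site above. A rough usage sketch; the lrSettings names shown are assumptions and are converted to snake_case before reaching the Python LrFinder:

lrFinder <- createLRFinder(modelParameters = currentModelParams,
                           estimatorSettings = currentEstimatorSettings,
                           lrSettings = list(minLr = 1e-7,  # assumed names
                                             maxLr = 1))
lr <- lrFinder$get_lr(dataset)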
5 changes: 2 additions & 3 deletions R/MLP.R
@@ -93,13 +93,12 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
{param <- param[sample(length(param),
randomSample)]}))
}
attr(param, "settings")$modelType <- "MLP"

estimatorSettings$modelType <- "MLP"
attr(param, "settings")$modelType <- estimatorSettings$modelType
results <- list(
fitFunction = "fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "MLP",
saveType = "file",
modelParamNames = c(
"numLayers", "sizeHidden",
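The same pattern, modelType written onto estimatorSettings and mirrored onto the param attribute, repeats for setResNet, setFinetuner, and setTransformer below. A hedged sketch of the effect, assuming the remaining hyperparameters keep their defaults:

mlpSettings <- setMultiLayerPerceptron(estimatorSettings = setEstimator())
mlpSettings$estimatorSettings$modelType        # "MLP"
attr(mlpSettings$param, "settings")$modelType  # also "MLP"
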
4 changes: 2 additions & 2 deletions R/ResNet.R
@@ -137,12 +137,12 @@ setResNet <- function(numLayers = c(1:8),
{param <- param[sample(length(param),
randomSample)]}))
}
attr(param, "settings")$modelType <- "ResNet"
estimatorSettings$modelType <- "ResNet"
attr(param, "settings")$modelType <- estimatorSettings$modelType
results <- list(
fitFunction = "fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "ResNet",
saveType = "file",
modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor",
"residualDropout", "hiddenDropout", "sizeEmbedding")
20 changes: 6 additions & 14 deletions R/TransferLearning.R
@@ -26,33 +26,25 @@
#' @param estimatorSettings settings created with `setEstimator`
#' @export
setFinetuner <- function(modelPath,
estimatorSettings =
setEstimator(learningRate = learningRate,
weightDecay = weightDecay,
batchSize = batchSize,
epochs = epochs,
device = device,
optimizer = optimizer,
scheduler = scheduler,
criterion = criterion,
earlyStopping = earlyStopping,
metric = metric,
seed = seed)
estimatorSettings = setEstimator()
) {

if (!dir.exists(modelPath)) {
stop(paste0("supplied modelPath does not exist, you supplied: modelPath = ",
modelPath))
}

# TODO check if it's a valid path to a plpModel

param <- list()
param[[1]] <- list(modelPath = modelPath)

attr(param, "settings")$modelType <- "FineTuner"
estimatorSettings$modelType <- "Finetuner"
attr(param, "settings")$modelType <- estimatorSettings$modelType
results <- list(
fitFunction = "fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "Finetuner",
saveType = "file",
modelParamNames = c("modelPath")
)
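
A hedged usage sketch (the path is illustrative and must contain a previously fitted deep plpModel):

finetuneSettings <- setFinetuner(
  modelPath = "./pretrainedResNet",  # assumed path to a saved plpModel
  estimatorSettings = setEstimator(learningRate = "auto", epochs = 5)
)
finetuneSettings$estimatorSettings$modelType  # "Finetuner"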
4 changes: 2 additions & 2 deletions R/Transformer.R
@@ -180,12 +180,12 @@ setTransformer <- function(numBlocks = 3,
{param <- param[sample(length(param),
randomSample)]}))
}
attr(param, "settings")$modelType <- "Transformer"
estimatorSettings$modelType <- "Transformer"
attr(param, "settings")$modelType <- estimatorSettings$modelType
results <- list(
fitFunction = "fitEstimator",
param = param,
estimatorSettings = estimatorSettings,
modelType = "Transformer",
saveType = "file",
modelParamNames = c(
"numBlocks", "dimToken", "dimOut", "numHeads",
15 changes: 11 additions & 4 deletions inst/python/Dataset.py
@@ -30,6 +30,7 @@ def __init__(self, data, labels=None, numerical_features=None):
.n_unique()
.filter(pl.col("covariateValue") > 1)
.select("columnId")
.sort("columnId")
.collect()["columnId"]
)
else:
@@ -69,13 +70,19 @@ def __init__(self, data, labels=None, numerical_features=None):
if pl.count(self.numerical_features) == 0:
self.num = None
else:
map_numerical = dict(zip(self.numerical_features.sort().to_list(),
list(range(len(self.numerical_features)))))
map_numerical = dict(
zip(
self.numerical_features.sort().to_list(),
list(range(len(self.numerical_features))),
)
)

numerical_data = (
data.filter(pl.col("columnId").is_in(self.numerical_features))
.with_columns(pl.col("columnId").replace(map_numerical),
pl.col("rowId") - 1)
.sort("columnId")
.with_columns(
pl.col("columnId").replace(map_numerical), pl.col("rowId") - 1
)
.select(
pl.col("rowId"),
pl.col("columnId"),
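The added sort and the map_numerical dict reindex the numerical columnIds to contiguous zero-based positions; a small R analogue of that mapping (ids illustrative):

columnIds <- c(1105, 3, 42)  # numerical feature columnIds
mapNumerical <- setNames(seq_along(columnIds) - 1L, sort(columnIds))
mapNumerical  # 3 -> 0, 42 -> 1, 1105 -> 2
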
18 changes: 15 additions & 3 deletions inst/python/Estimator.py
@@ -3,7 +3,6 @@

import torch
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler
import torch.nn.functional as F
from tqdm import tqdm


@@ -20,12 +19,25 @@ def __init__(self, model, model_parameters, estimator_settings):
self.device = estimator_settings["device"]

torch.manual_seed(seed=self.seed)
self.model = model(**model_parameters)
if estimator_settings["finetune"]:
path = estimator_settings["finetune_model_path"]
fitted_estimator = torch.load(path, map_location="cpu")
fitted_parameters = fitted_estimator["model_parameters"]
self.model = model(**fitted_parameters)
self.model.load_state_dict(fitted_estimator["model_state_dict"])
for param in self.model.parameters():
param.requires_grad = False
self.model.reset_head()
else:
self.model = model(**model_parameters)
self.model_parameters = model_parameters
self.estimator_settings = estimator_settings

self.epochs = int(estimator_settings.get("epochs", 5))
self.learning_rate = estimator_settings.get("learning_rate", 3e-4)
if estimator_settings["find_l_r"]:
self.learning_rate = 3e-4
else:
self.learning_rate = estimator_settings.get("learning_rate", 3e-4)
self.weight_decay = estimator_settings.get("weight_decay", 1e-5)
self.batch_size = int(estimator_settings.get("batch_size", 1024))
self.prefix = estimator_settings.get("prefix", self.model.name)
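Taken together with setFinetuner, the finetuning branch above loads the fitted weights, freezes them, and reinitializes the model head via reset_head. A rough end-to-end sketch from the R side; the runPlp call is the usual PatientLevelPrediction entry point, and pretrainedPath and plpData are assumed to exist:

modelSettings <- setFinetuner(
  modelPath = pretrainedPath,  # assumed: directory of a fitted deep plpModel
  estimatorSettings = setEstimator(learningRate = "auto",
                                   epochs = 5,
                                   device = "cpu")
)
results <- PatientLevelPrediction::runPlp(plpData = plpData,  # assumed
                                          modelSettings = modelSettings,
                                          saveDirectory = "./output")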