From fb6552f31c24cd079017d948927bd9e2337f18fc Mon Sep 17 00:00:00 2001 From: Henrik John Date: Wed, 31 Jul 2024 15:02:43 +0200 Subject: [PATCH 1/3] Add interface for model initialization --- R/Estimator.R | 7 +++++-- inst/python/Estimator.py | 18 +++++------------- inst/python/InitStrategy.py | 24 ++++++++++++++++++++++++ tests/testthat/test-Finetuner.R | 1 - 4 files changed, 34 insertions(+), 16 deletions(-) create mode 100644 inst/python/InitStrategy.py diff --git a/R/Estimator.R b/R/Estimator.R index ece685c..4825601 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -470,12 +470,14 @@ createEstimator <- function(modelParameters, path <- system.file("python", package = "DeepPatientLevelPrediction") if (modelParameters$modelType == "Finetuner") { - estimatorSettings$finetune <- TRUE plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath) estimatorSettings$finetuneModelPath <- normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt")) modelParameters$modelType <- plpModel$modelDesign$modelSettings$modelType + initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$FinetuneInitStrategy() + } else { + initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$DefaultInitStrategy() } model <- @@ -490,7 +492,8 @@ createEstimator <- function(modelParameters, estimator <- estimator( model = model, model_parameters = modelParameters, - estimator_settings = estimatorSettings + estimator_settings = estimatorSettings, + init_strategy = initStrategy ) return(estimator) } diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py index f0e020d..93a2f4d 100644 --- a/inst/python/Estimator.py +++ b/inst/python/Estimator.py @@ -6,31 +6,23 @@ from tqdm import tqdm from gpu_memory_cleanup import memory_cleanup - +from InitStrategy import InitStrategy, DefaultInitStrategy class Estimator: """ A class that wraps around pytorch models. 
""" - def __init__(self, model, model_parameters, estimator_settings): + def __init__(self, model, model_parameters, estimator_settings, init_strategy: InitStrategy = DefaultInitStrategy()): self.seed = estimator_settings["seed"] if callable(estimator_settings["device"]): self.device = estimator_settings["device"]() else: self.device = estimator_settings["device"] torch.manual_seed(seed=self.seed) - if "finetune" in estimator_settings.keys() and estimator_settings["finetune"]: - path = estimator_settings["finetune_model_path"] - fitted_estimator = torch.load(path, map_location="cpu") - fitted_parameters = fitted_estimator["model_parameters"] - self.model = model(**fitted_parameters) - self.model.load_state_dict(fitted_estimator["model_state_dict"]) - for param in self.model.parameters(): - param.requires_grad = False - self.model.reset_head() - else: - self.model = model(**model_parameters) + + self.model = init_strategy.initialize(model, model_parameters, estimator_settings) + self.model_parameters = model_parameters self.estimator_settings = estimator_settings diff --git a/inst/python/InitStrategy.py b/inst/python/InitStrategy.py new file mode 100644 index 0000000..d723d4e --- /dev/null +++ b/inst/python/InitStrategy.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +import torch + +class InitStrategy(ABC): + @abstractmethod + def initialize(self, model, model_parameters, estimator_settings): + pass + +class DefaultInitStrategy(InitStrategy): + def initialize(self, model, model_parameters, estimator_settings): + return model(**model_parameters) + +class FinetuneInitStrategy(InitStrategy): + def initialize(self, model, model_parameters, estimator_settings): + path = estimator_settings["finetune_model_path"] + fitted_estimator = torch.load(path, map_location="cpu") + fitted_parameters = fitted_estimator["model_parameters"] + model_instance = model(**fitted_parameters) + model_instance.load_state_dict(fitted_estimator["model_state_dict"]) + for param in 
model_instance.parameters(): param.requires_grad = False model_instance.reset_head() return model_instance + diff --git a/tests/testthat/test-Finetuner.R b/tests/testthat/test-Finetuner.R index a8f7cb3..162da9f 100644 --- a/tests/testthat/test-Finetuner.R +++ b/tests/testthat/test-Finetuner.R @@ -44,7 +44,6 @@ test_that("Finetuner fitEstimator works", { fineTunedModel <- torch$load(file.path(fineTunerResults$model, "DeepEstimatorModel.pt")) - expect_true(fineTunedModel$estimator_settings$finetune) expect_equal(fineTunedModel$estimator_settings$finetune_model_path, normalizePath(file.path(fitEstimatorPath, "plpModel", "model", "DeepEstimatorModel.pt"))) From ef1263038263dbd0bea7f100ece31e38a7188450 Mon Sep 17 00:00:00 2001 From: Henrik John Date: Thu, 1 Aug 2024 14:16:27 +0200 Subject: [PATCH 2/3] Separate finetuner from estimator --- R/Estimator.R | 15 +-------------- R/TransferLearning.R | 14 ++++++++++++-- inst/python/Estimator.py | 9 ++++++--- tests/testthat/test-Finetuner.R | 6 +++++- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 4825601..19f414e 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -468,18 +468,6 @@ evalEstimatorSettings <- function(estimatorSettings) { createEstimator <- function(modelParameters, estimatorSettings) { path <- system.file("python", package = "DeepPatientLevelPrediction") - - if (modelParameters$modelType == "Finetuner") { - plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath) - estimatorSettings$finetuneModelPath <- - normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt")) - modelParameters$modelType <- - plpModel$modelDesign$modelSettings$modelType - initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$FinetuneInitStrategy() - } else { - initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$DefaultInitStrategy() - } - model <- reticulate::import_from_path(modelParameters$modelType, path = 
path)[[modelParameters$modelType]] @@ -492,8 +480,7 @@ createEstimator <- function(modelParameters, estimator <- estimator( model = model, model_parameters = modelParameters, - estimator_settings = estimatorSettings, - init_strategy = initStrategy + estimator_settings = estimatorSettings ) return(estimator) } diff --git a/R/TransferLearning.R b/R/TransferLearning.R index 95d1c04..e47a942 100644 --- a/R/TransferLearning.R +++ b/R/TransferLearning.R @@ -42,6 +42,16 @@ setFinetuner <- function(modelPath, modelPath)) } + plpModel <- PatientLevelPrediction::loadPlpModel(modelPath) + estimatorSettings$finetuneModelPath <- + normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt")) + modelType <- + plpModel$modelDesign$modelSettings$modelType + + path <- system.file("python", package = "DeepPatientLevelPrediction") + estimatorSettings$initStrategy <- + reticulate::import_from_path("InitStrategy", + path = path)$FinetuneInitStrategy() param <- list() param[[1]] <- list(modelPath = modelPath) @@ -52,9 +62,9 @@ setFinetuner <- function(modelPath, estimatorSettings = estimatorSettings, saveType = "file", modelParamNames = c("modelPath"), - modelType = "Finetuner" + modelType = modelType ) - attr(results$param, "settings")$modelType <- results$modelType + attr(results$param, "settings")$modelType <- "Finetuner" class(results) <- "modelSettings" diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py index 93a2f4d..1b6ac18 100644 --- a/inst/python/Estimator.py +++ b/inst/python/Estimator.py @@ -13,7 +13,7 @@ class Estimator: A class that wraps around pytorch models. 
""" - def __init__(self, model, model_parameters, estimator_settings, init_strategy: InitStrategy = DefaultInitStrategy()): + def __init__(self, model, model_parameters, estimator_settings): self.seed = estimator_settings["seed"] if callable(estimator_settings["device"]): self.device = estimator_settings["device"]() @@ -21,8 +21,11 @@ def __init__(self, model, model_parameters, estimator_settings, init_strategy: I self.device = estimator_settings["device"] torch.manual_seed(seed=self.seed) - self.model = init_strategy.initialize(model, model_parameters, estimator_settings) - + if "init_strategy" in estimator_settings: + self.model = estimator_settings["init_strategy"].initialize(model, model_parameters, estimator_settings) + else: + self.model = DefaultInitStrategy().initialize(model, model_parameters, estimator_settings) + self.model_parameters = model_parameters self.estimator_settings = estimator_settings diff --git a/tests/testthat/test-Finetuner.R b/tests/testthat/test-Finetuner.R index 162da9f..e6489e2 100644 --- a/tests/testthat/test-Finetuner.R +++ b/tests/testthat/test-Finetuner.R @@ -6,6 +6,10 @@ fineTunerSettings <- setFinetuner( epochs = 1) ) +plpModel <- PatientLevelPrediction::loadPlpModel(file.path(fitEstimatorPath, + "plpModel")) +modelType <- plpModel$modelDesign$modelSettings$modelType + test_that("Finetuner settings work", { expect_equal(fineTunerSettings$param[[1]]$modelPath, file.path(fitEstimatorPath, "plpModel")) @@ -14,7 +18,7 @@ test_that("Finetuner settings work", { expect_equal(fineTunerSettings$estimatorSettings$epochs, 1) expect_equal(fineTunerSettings$fitFunction, "fitEstimator") expect_equal(fineTunerSettings$saveType, "file") - expect_equal(fineTunerSettings$modelType, "Finetuner") + expect_equal(fineTunerSettings$modelType, modelType) expect_equal(fineTunerSettings$modelParamNames, "modelPath") expect_equal(class(fineTunerSettings), "modelSettings") expect_equal(attr(fineTunerSettings$param, "settings")$modelType, "Finetuner") From 
ea07c75e5dce2c9ca6dfa85b58eb68c885e474ce Mon Sep 17 00:00:00 2001 From: Henrik John Date: Thu, 1 Aug 2024 14:31:10 +0200 Subject: [PATCH 3/3] Add missing seed to test case --- tests/testthat/test-TrainingCache.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-TrainingCache.R b/tests/testthat/test-TrainingCache.R index ec5b063..5663beb 100644 --- a/tests/testthat/test-TrainingCache.R +++ b/tests/testthat/test-TrainingCache.R @@ -10,10 +10,10 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4), device = "cpu", batchSize = 64, epochs = 1, - seed = NULL), + seed = 42), hyperParamSearch = "random", randomSample = 3, - randomSampleSeed = NULL) + randomSampleSeed = 42) trainCache <- trainingCache$new(testLoc) paramSearch <- resNetSettings$param