Add interface for model initialization (#124)
* Add interface for model initialization

* Separate finetuner from estimator

* Add missing seed to test case
lhjohn authored Aug 2, 2024
1 parent 600fc0a commit 04cd731
Showing 6 changed files with 49 additions and 27 deletions.
10 changes: 0 additions & 10 deletions R/Estimator.R
@@ -468,16 +468,6 @@ evalEstimatorSettings <- function(estimatorSettings) {
createEstimator <- function(modelParameters,
estimatorSettings) {
path <- system.file("python", package = "DeepPatientLevelPrediction")

if (modelParameters$modelType == "Finetuner") {
estimatorSettings$finetune <- TRUE
plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
estimatorSettings$finetuneModelPath <-
normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt"))
modelParameters$modelType <-
plpModel$modelDesign$modelSettings$modelType
}

model <-
reticulate::import_from_path(modelParameters$modelType,
path = path)[[modelParameters$modelType]]
14 changes: 12 additions & 2 deletions R/TransferLearning.R
@@ -42,6 +42,16 @@ setFinetuner <- function(modelPath,
modelPath))
}

plpModel <- PatientLevelPrediction::loadPlpModel(modelPath)
estimatorSettings$finetuneModelPath <-
normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt"))
modelType <-
plpModel$modelDesign$modelSettings$modelType

path <- system.file("python", package = "DeepPatientLevelPrediction")
estimatorSettings$initStrategy <-
reticulate::import_from_path("InitStrategy",
path = path)$FinetuneInitStrategy()

param <- list()
param[[1]] <- list(modelPath = modelPath)
@@ -52,9 +62,9 @@ setFinetuner <- function(modelPath,
estimatorSettings = estimatorSettings,
saveType = "file",
modelParamNames = c("modelPath"),
modelType = "Finetuner"
modelType = modelType
)
attr(results$param, "settings")$modelType <- results$modelType
attr(results$param, "settings")$modelType <- "Finetuner"

class(results) <- "modelSettings"
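
The reticulate call above amounts to importing the InitStrategy module from the package's python directory and constructing a FinetuneInitStrategy instance, which setFinetuner stores in estimatorSettings$initStrategy. A rough Python-side equivalent, as a hedged sketch (the sys.path entry below is an assumption about where InitStrategy.py sits, not part of the diff):

import sys

sys.path.insert(0, "inst/python")  # assumed location of InitStrategy.py
from InitStrategy import FinetuneInitStrategy

init_strategy = FinetuneInitStrategy()
# setFinetuner stores this object in the estimator settings; the Estimator later
# receives it as estimator_settings["init_strategy"] and calls its initialize().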

17 changes: 6 additions & 11 deletions inst/python/Estimator.py
@@ -6,7 +6,7 @@
from tqdm import tqdm

from gpu_memory_cleanup import memory_cleanup

from InitStrategy import InitStrategy, DefaultInitStrategy

class Estimator:
"""
@@ -20,17 +20,12 @@ def __init__(self, model, model_parameters, estimator_settings):
else:
self.device = estimator_settings["device"]
torch.manual_seed(seed=self.seed)
if "finetune" in estimator_settings.keys() and estimator_settings["finetune"]:
path = estimator_settings["finetune_model_path"]
fitted_estimator = torch.load(path, map_location="cpu")
fitted_parameters = fitted_estimator["model_parameters"]
self.model = model(**fitted_parameters)
self.model.load_state_dict(fitted_estimator["model_state_dict"])
for param in self.model.parameters():
param.requires_grad = False
self.model.reset_head()

if "init_strategy" in estimator_settings:
self.model = estimator_settings["init_strategy"].initialize(model, model_parameters, estimator_settings)
else:
self.model = model(**model_parameters)
self.model = DefaultInitStrategy().initialize(model, model_parameters, estimator_settings)

self.model_parameters = model_parameters
self.estimator_settings = estimator_settings
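
The constructor now delegates model construction to an init strategy: if the settings carry an "init_strategy" object it is used, otherwise DefaultInitStrategy builds the model from model_parameters as before. A minimal sketch of that dispatch outside the Estimator, assuming a toy TinyModel class and a plain settings dict (both illustrative; the real strategy classes are the ones added in inst/python/InitStrategy.py below):

import torch

# Toy stand-ins for this sketch; the real strategies live in InitStrategy.py.
class TinyModel(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.layer = torch.nn.Linear(size, 1)

class DefaultInitStrategy:
    def initialize(self, model, model_parameters, estimator_settings):
        return model(**model_parameters)

estimator_settings = {}  # no "init_strategy" key, so fall back to the default
if "init_strategy" in estimator_settings:
    model = estimator_settings["init_strategy"].initialize(
        TinyModel, {"size": 8}, estimator_settings)
else:
    model = DefaultInitStrategy().initialize(TinyModel, {"size": 8}, estimator_settings)
print(type(model).__name__)  # TinyModel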

24 changes: 24 additions & 0 deletions inst/python/InitStrategy.py
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
import torch

class InitStrategy(ABC):
@abstractmethod
def initialize(self, model, model_parameters, estimator_settings):
pass

class DefaultInitStrategy(InitStrategy):
def initialize(self, model, model_parameters, estimator_settings):
return model(**model_parameters)

class FinetuneInitStrategy(InitStrategy):
def initialize(self, model, model_parameters, estimator_settings):
path = estimator_settings["finetune_model_path"]
fitted_estimator = torch.load(path, map_location="cpu")
fitted_parameters = fitted_estimator["model_parameters"]
model_instance = model(**fitted_parameters)
model_instance.load_state_dict(fitted_estimator["model_state_dict"])
for param in model_instance.parameters():
param.requires_grad = False
model_instance.reset_head()
return model_instance
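
For the finetune path to work, the saved checkpoint has to contain model_parameters and model_state_dict entries and the model class has to provide a reset_head() method, as the strategy above assumes. A small self-contained check of that behaviour, using a toy ToyHead model and an in-memory buffer standing in for DeepEstimatorModel.pt (both are illustrative assumptions, not part of the package):

import io
import sys
import torch

sys.path.insert(0, "inst/python")  # assumed location of InitStrategy.py
from InitStrategy import FinetuneInitStrategy

class ToyHead(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.body = torch.nn.Linear(size, size)
        self.head = torch.nn.Linear(size, 1)

    def reset_head(self):
        # re-create the head so only it stays trainable after freezing
        self.head = torch.nn.Linear(self.head.in_features, 1)

# in-memory stand-in for the DeepEstimatorModel.pt checkpoint saved by fitEstimator
checkpoint = {
    "model_parameters": {"size": 4},
    "model_state_dict": ToyHead(4).state_dict(),
}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

settings = {"finetune_model_path": buffer}  # torch.load also accepts a file-like object
model = FinetuneInitStrategy().initialize(ToyHead, {"size": 4}, settings)

print(all(not p.requires_grad for p in model.body.parameters()))  # True: backbone frozen
print(all(p.requires_grad for p in model.head.parameters()))      # True: fresh, trainable head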

7 changes: 5 additions & 2 deletions tests/testthat/test-Finetuner.R
@@ -6,6 +6,10 @@ fineTunerSettings <- setFinetuner(
epochs = 1)
)

plpModel <- PatientLevelPrediction::loadPlpModel(file.path(fitEstimatorPath,
"plpModel"))
modelType <- plpModel$modelDesign$modelSettings$modelType

test_that("Finetuner settings work", {
expect_equal(fineTunerSettings$param[[1]]$modelPath,
file.path(fitEstimatorPath, "plpModel"))
@@ -14,7 +18,7 @@ test_that("Finetuner settings work", {
expect_equal(fineTunerSettings$estimatorSettings$epochs, 1)
expect_equal(fineTunerSettings$fitFunction, "fitEstimator")
expect_equal(fineTunerSettings$saveType, "file")
expect_equal(fineTunerSettings$modelType, "Finetuner")
expect_equal(fineTunerSettings$modelType, modelType)
expect_equal(fineTunerSettings$modelParamNames, "modelPath")
expect_equal(class(fineTunerSettings), "modelSettings")
expect_equal(attr(fineTunerSettings$param, "settings")$modelType, "Finetuner")
@@ -44,7 +48,6 @@ test_that("Finetuner fitEstimator works", {

fineTunedModel <- torch$load(file.path(fineTunerResults$model,
"DeepEstimatorModel.pt"))
expect_true(fineTunedModel$estimator_settings$finetune)
expect_equal(fineTunedModel$estimator_settings$finetune_model_path,
normalizePath(file.path(fitEstimatorPath, "plpModel", "model",
"DeepEstimatorModel.pt")))
4 changes: 2 additions & 2 deletions tests/testthat/test-TrainingCache.R
@@ -10,10 +10,10 @@ resNetSettings <- setResNet(numLayers = c(1, 2, 4),
device = "cpu",
batchSize = 64,
epochs = 1,
seed = NULL),
seed = 42),
hyperParamSearch = "random",
randomSample = 3,
randomSampleSeed = NULL)
randomSampleSeed = 42)

trainCache <- trainingCache$new(testLoc)
paramSearch <- resNetSettings$param
