Skip to content

Commit

Permalink
Add interface for model initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Jul 31, 2024
1 parent 600fc0a commit fb6552f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
7 changes: 5 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,14 @@ createEstimator <- function(modelParameters,
path <- system.file("python", package = "DeepPatientLevelPrediction")

if (modelParameters$modelType == "Finetuner") {
estimatorSettings$finetune <- TRUE
plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
estimatorSettings$finetuneModelPath <-
normalizePath(file.path(plpModel$model, "DeepEstimatorModel.pt"))
modelParameters$modelType <-
plpModel$modelDesign$modelSettings$modelType
initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$FinetuneInitStrategy()
} else {
initStrategy <- reticulate::import_from_path("InitStrategy", path = path)$DefaultInitStrategy()
}

model <-
Expand All @@ -490,7 +492,8 @@ createEstimator <- function(modelParameters,
estimator <- estimator(
model = model,
model_parameters = modelParameters,
estimator_settings = estimatorSettings
estimator_settings = estimatorSettings,
init_strategy = initStrategy
)
return(estimator)
}
Expand Down
18 changes: 5 additions & 13 deletions inst/python/Estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,23 @@
from tqdm import tqdm

from gpu_memory_cleanup import memory_cleanup

from InitStrategy import InitStrategy, DefaultInitStrategy

class Estimator:
"""
A class that wraps around pytorch models.
"""

def __init__(self, model, model_parameters, estimator_settings):
def __init__(self, model, model_parameters, estimator_settings, init_strategy: InitStrategy = DefaultInitStrategy()):
self.seed = estimator_settings["seed"]
if callable(estimator_settings["device"]):
self.device = estimator_settings["device"]()
else:
self.device = estimator_settings["device"]
torch.manual_seed(seed=self.seed)
if "finetune" in estimator_settings.keys() and estimator_settings["finetune"]:
path = estimator_settings["finetune_model_path"]
fitted_estimator = torch.load(path, map_location="cpu")
fitted_parameters = fitted_estimator["model_parameters"]
self.model = model(**fitted_parameters)
self.model.load_state_dict(fitted_estimator["model_state_dict"])
for param in self.model.parameters():
param.requires_grad = False
self.model.reset_head()
else:
self.model = model(**model_parameters)

self.model = init_strategy.initialize(model, model_parameters, estimator_settings)

self.model_parameters = model_parameters
self.estimator_settings = estimator_settings

Expand Down
24 changes: 24 additions & 0 deletions inst/python/InitStrategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
import torch

class InitStrategy(ABC):
    """Strategy interface for constructing the model used by the Estimator."""

    @abstractmethod
    def initialize(self, model, model_parameters, estimator_settings):
        """Build and return a model instance.

        Args:
            model: Model class (callable) to instantiate.
            model_parameters: Keyword arguments used to construct the model.
            estimator_settings: Estimator configuration dictionary.

        Returns:
            An initialized model instance ready for training.
        """
        ...

class DefaultInitStrategy(InitStrategy):
    """Default strategy: construct the model directly from its parameters."""

    def initialize(self, model, model_parameters, estimator_settings):
        """Instantiate ``model`` with ``model_parameters``.

        ``estimator_settings`` is accepted for interface compatibility but
        is not used by this strategy.
        """
        instance = model(**model_parameters)
        return instance

class FinetuneInitStrategy(InitStrategy):
    """Fine-tuning strategy: rebuild a previously fitted model from a
    checkpoint, freeze its loaded weights, and reset its head so that only
    the (re-created) head is trained."""

    def initialize(self, model, model_parameters, estimator_settings):
        """Load a fitted checkpoint and prepare the model for fine-tuning.

        Args:
            model: Model class (callable) to instantiate.
            model_parameters: Not used for construction; the checkpoint's
                stored ``model_parameters`` are used instead so the
                architecture matches the saved weights.
            estimator_settings: Must contain ``"finetune_model_path"``, the
                path to a checkpoint dict with ``"model_parameters"`` and
                ``"model_state_dict"`` entries.

        Returns:
            A model instance with all loaded parameters frozen
            (``requires_grad=False``) and its head reset for training.

        Raises:
            ValueError: If ``"finetune_model_path"`` is missing from
                ``estimator_settings``.
        """
        path = estimator_settings.get("finetune_model_path")
        if path is None:
            raise ValueError(
                "estimator_settings must contain 'finetune_model_path' "
                "when using FinetuneInitStrategy"
            )
        # weights_only=False is required: the checkpoint is a plain dict
        # holding non-tensor objects (model_parameters), which torch >= 2.6
        # refuses to load under its new weights_only=True default. NOTE:
        # this path uses pickle under the hood, so the checkpoint must come
        # from a trusted source.
        fitted_estimator = torch.load(path, map_location="cpu",
                                      weights_only=False)
        fitted_parameters = fitted_estimator["model_parameters"]
        model_instance = model(**fitted_parameters)
        model_instance.load_state_dict(fitted_estimator["model_state_dict"])
        # Freeze everything that was loaded; reset_head() is expected to
        # re-create the final layer(s) with requires_grad=True so only the
        # head is updated during fine-tuning.
        for param in model_instance.parameters():
            param.requires_grad = False
        model_instance.reset_head()
        return model_instance

1 change: 0 additions & 1 deletion tests/testthat/test-Finetuner.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ test_that("Finetuner fitEstimator works", {

fineTunedModel <- torch$load(file.path(fineTunerResults$model,
"DeepEstimatorModel.pt"))
expect_true(fineTunedModel$estimator_settings$finetune)
expect_equal(fineTunedModel$estimator_settings$finetune_model_path,
normalizePath(file.path(fitEstimatorPath, "plpModel", "model",
"DeepEstimatorModel.pt")))
Expand Down

0 comments on commit fb6552f

Please sign in to comment.