From 8fa37112c4047c93bd397aec0413548c176e167c Mon Sep 17 00:00:00 2001
From: egillax
Date: Wed, 6 Dec 2023 17:43:19 +0100
Subject: [PATCH] Transfer learning first prototype working

---
 R/Estimator.R                    | 54 +++++++++++++++---------------
 R/HelperFunctions.R              | 27 +++++++++++++++
 R/LRFinder.R                     | 18 ++--------
 R/MLP.R                          |  5 ++-
 R/ResNet.R                       |  4 +--
 R/TransferLearning.R             | 20 ++++--------
 R/Transformer.R                  |  4 +--
 inst/python/Dataset.py           | 15 ++++++---
 inst/python/Estimator.py         | 18 ++++++++--
 inst/python/LrFinder.py          | 40 +++++++++--------------
 inst/python/MLP.py               | 11 +++++--
 inst/python/ResNet.py            |  5 +++
 inst/python/Transformer.py       | 13 ++++++++
 man/setEstimator.Rd              |  3 +-
 man/setFinetuner.Rd              |  8 +----
 man/snakeCaseToCamelCase.Rd      | 17 ++++++++++
 man/snakeCaseToCamelCaseNames.Rd | 17 ++++++++++
 tests/testthat/test-Estimator.R  | 56 ++++++++++++++++----------------
 tests/testthat/test-LRFinder.R   | 31 +++++++++---------
 19 files changed, 216 insertions(+), 150 deletions(-)
 create mode 100644 man/snakeCaseToCamelCase.Rd
 create mode 100644 man/snakeCaseToCamelCaseNames.Rd

diff --git a/R/Estimator.R b/R/Estimator.R
index 019eba7..1023a92 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -54,7 +54,8 @@ setEstimator <- function(learningRate = "auto",
                          earlyStopping = list(useEarlyStopping = TRUE,
                                               params = list(patience = 4)),
                          metric = "auc",
-                         seed = NULL
+                         seed = NULL,
+                         modelType = NULL
 ) {
   checkIsClass(learningRate, c("numeric", "character"))
@@ -89,7 +90,8 @@
                              earlyStopping = earlyStopping,
                              findLR = findLR,
                              metric = metric,
-                             seed = seed[1])
+                             seed = seed[1],
+                             modelType = modelType)
 
   optimizer <- rlang::enquo(optimizer)
   estimatorSettings$optimizer <- function() rlang::eval_tidy(optimizer)
@@ -143,7 +145,7 @@ fitEstimator <- function(trainData,
     trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
   }
-  if (modelSettings$modelType == "Finetuner") {
+  if (modelSettings$estimatorSettings$modelType == "Finetuner") {
     # make sure to use same mapping from covariateIds to columns if finetuning
     path <- modelSettings$param[[1]]$modelPath
     oldCovImportance <- utils::read.csv(file.path(path,
@@ -227,7 +229,7 @@ fitEstimator <- function(trainData,
       attrition = attr(trainData, "metaData")$attrition,
       trainingTime = paste(as.character(abs(comp)), attr(comp, "units")),
       trainingDate = Sys.Date(),
-      modelName = modelSettings$modelType,
+      modelName = modelSettings$estimatorSettings$modelType,
       finalModelParameters = cvResult$finalParam,
       hyperParamSearch = hyperSummary
     ),
@@ -271,13 +273,14 @@ predictDeepEstimator <- function(plpModel,
   # get predictions
   prediction <- cohort
   if (is.character(plpModel$model)) {
-    modelSettings <- plpModel$modelDesign$modelSettings
     model <- torch$load(file.path(plpModel$model,
                                   "DeepEstimatorModel.pt"),
                         map_location = "cpu")
-    estimator <- createEstimator(modelType = modelSettings$modelType,
-                                 modelParameters = model$model_parameters,
-                                 estimatorSettings = model$estimator_settings)
+    estimator <-
+      createEstimator(modelParameters =
+                        snakeCaseToCamelCaseNames(model$model_parameters),
+                      estimatorSettings =
+                        snakeCaseToCamelCaseNames(model$estimator_settings))
     estimator$model$load_state_dict(model$model_state_dict)
     prediction$value <- estimator$predict_proba(data)
   } else {
@@ -308,7 +311,8 @@ gridCvDeep <- function(mappedData,
                        modelLocation,
                        analysisPath) {
   ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
-                                 modelSettings$modelType, " model"))
+                                 modelSettings$estimatorSettings$modelType,
+                                 " model"))
 ###########################################################################
@@ -343,15 +347,12 @@ gridCvDeep <- function(mappedData,
       fillEstimatorSettings(modelSettings$estimatorSettings,
                             fitParams,
                             paramSearch[[gridId]])
-    currentEstimatorSettings$modelType <- modelSettings$modelType
     currentModelParams$catFeatures <- dataset$get_cat_features()$max()
     currentModelParams$numFeatures <-
-      dataset$get_numerical_features()$max()
+      dataset$get_numerical_features()$len()
     if (findLR) {
-      lrFinder <- createLRFinder(modelType = modelSettings$modelType,
-                                 modelParameters = currentModelParams,
-                                 estimatorSettings = currentEstimatorSettings
-      )
+      lrFinder <- createLRFinder(modelParameters = currentModelParams,
+                                 estimatorSettings = currentEstimatorSettings)
       lr <- lrFinder$get_lr(dataset)
       ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr))
       currentEstimatorSettings$learningRate <- lr
@@ -418,15 +419,14 @@ gridCvDeep <- function(mappedData,
   }
   modelParams$catFeatures <- dataset$get_cat_features()$max()
-  modelParams$numFeatures <- dataset$get_numerical_features()$max()
+  modelParams$numFeatures <- dataset$get_numerical_features()$len()
 
   estimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings,
                                              fitParams,
                                              finalParam)
   estimatorSettings$learningRate <- finalParam$learnSchedule$LRs[[1]]
-  estimator <- createEstimator(modelType = modelSettings$modelType,
-                               modelParameters = modelParams,
+  estimator <- createEstimator(modelParameters = modelParams,
                                estimatorSettings = estimatorSettings)
 
   numericalIndex <- dataset$get_numerical_features()
@@ -492,23 +492,22 @@ evalEstimatorSettings <- function(estimatorSettings) {
   estimatorSettings
 }
 
-createEstimator <- function(modelType,
-                            modelParameters,
+createEstimator <- function(modelParameters,
                             estimatorSettings) {
   path <- system.file("python", package = "DeepPatientLevelPrediction")
-  if (modelType == "Finetuner") {
+  if (estimatorSettings$modelType == "Finetuner") {
     estimatorSettings$finetune <- TRUE
     plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
     estimatorSettings$finetuneModelPath <-
       file.path(normalizePath(plpModel$model),
                 "DeepEstimatorModel.pt")
-    modelType <- plpModel$modelDesign$modelSettings$modelType
-    oldModelParameters <- modelParameters
-    modelParameters <-
-      plpModel$trainDetails$finalModelParameters[plpModel$modelDesign$modelSettings$modelParamNames]
+    estimatorSettings$modelType <-
+      plpModel$modelDesign$modelSettings$estimatorSettings$modelType
   }
-  model <- reticulate::import_from_path(modelType, path = path)[[modelType]]
+  model <-
+    reticulate::import_from_path(estimatorSettings$modelType,
+                                 path = path)[[estimatorSettings$modelType]]
   estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator
 
   modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
@@ -541,8 +540,7 @@ doCrossvalidation <- function(dataset,
     testDataset <- torch$utils$data$Subset(dataset,
                                            indices =
                                              as.integer(which(fold == i) - 1))
-    estimator <- createEstimator(modelType = estimatorSettings$modelType,
-                                 modelParameters = modelSettings,
+    estimator <- createEstimator(modelParameters = modelSettings,
                                  estimatorSettings = estimatorSettings)
     estimator$fit(trainDataset, testDataset)
 
diff --git a/R/HelperFunctions.R b/R/HelperFunctions.R
index a1aa43c..08a18db 100644
--- a/R/HelperFunctions.R
+++ b/R/HelperFunctions.R
@@ -29,6 +29,33 @@ camelCaseToSnakeCase <- function(string) {
   return(string)
 }
 
+#' Convert a snake case string to camel case
+#'
+#' @param string The string to be converted
+#'
+#' @return
+#' A string +#' +snakeCaseToCamelCase <- function(string) { + string <- tolower(string) + for (letter in letters) { + string <- gsub(paste("_", letter, sep = ""), toupper(letter), string) + } + string <- gsub("_([0-9])", "\\1", string) + return(string) +} + +#' Convert the names of an object from snake case to camel case +#' +#' @param object The object of which the names should be converted +#' +#' @return +#' The same object, but with converted names. +snakeCaseToCamelCaseNames <- function(object) { + names(object) <- snakeCaseToCamelCase(names(object)) + return(object) +} + #' Convert the names of an object from camel case to snake case #' #' @param object The object of which the names should be converted diff --git a/R/LRFinder.R b/R/LRFinder.R index 4f73b29..8c3d12d 100644 --- a/R/LRFinder.R +++ b/R/LRFinder.R @@ -15,32 +15,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -createLRFinder <- function(modelType, - modelParameters, +createLRFinder <- function(modelParameters, estimatorSettings, lrSettings = NULL) { path <- system.file("python", package = "DeepPatientLevelPrediction") lrFinderClass <- reticulate::import_from_path("LrFinder", path = path)$LrFinder - estimatorSettings <- evalEstimatorSettings(estimatorSettings) - - model <- reticulate::import_from_path(modelType, path = path)[[modelType]] - modelParameters <- camelCaseToSnakeCaseNames(modelParameters) - estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings) - estimatorSettings <- evalEstimatorSettings(estimatorSettings) - browser() - estimator <- createEstimator(modelType = estimatorSettings$modelType, - modelParameters = modelParameters, + estimator <- createEstimator(modelParameters = modelParameters, estimatorSettings = estimatorSettings) if (!is.null(lrSettings)) { lrSettings <- camelCaseToSnakeCaseNames(lrSettings) } - - lrFinder <- lrFinderClass(model = model, - model_parameters = modelParameters, - estimator_settings = estimatorSettings, + lrFinder <- lrFinderClass(estimator = estimator, lr_settings = lrSettings) return(lrFinder) diff --git a/R/MLP.R b/R/MLP.R index 771e244..5c973bd 100644 --- a/R/MLP.R +++ b/R/MLP.R @@ -93,13 +93,12 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8), {param <- param[sample(length(param), randomSample)]})) } - attr(param, "settings")$modelType <- "MLP" - + estimatorSettings$modelType <- "MLP" + attr(param, "settings")$modelType <- estimatorSettings$modelType results <- list( fitFunction = "fitEstimator", param = param, estimatorSettings = estimatorSettings, - modelType = "MLP", saveType = "file", modelParamNames = c( "numLayers", "sizeHidden", diff --git a/R/ResNet.R b/R/ResNet.R index 32e7f3a..2c4a2b6 100644 --- a/R/ResNet.R +++ b/R/ResNet.R @@ -137,12 +137,12 @@ setResNet <- function(numLayers = c(1:8), {param <- param[sample(length(param), randomSample)]})) } - attr(param, "settings")$modelType <- "ResNet" + estimatorSettings$modelType <- "ResNet" + attr(param, "settings")$modelType <- estimatorSettings$modelType results <- list( fitFunction = "fitEstimator", param = param, estimatorSettings = estimatorSettings, - modelType = "ResNet", saveType = "file", modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor", "residualDropout", "hiddenDropout", "sizeEmbedding") diff --git a/R/TransferLearning.R b/R/TransferLearning.R index 8ed18a0..b99be24 100644 --- a/R/TransferLearning.R +++ b/R/TransferLearning.R @@ -26,33 +26,25 @@ #' 
@param estimatorSettings settings created with `setEstimator` #' @export setFinetuner <- function(modelPath, - estimatorSettings = - setEstimator(learningRate = learningRate, - weightDecay = weightDecay, - batchSize = batchSize, - epochs = epochs, - device = device, - optimizer = optimizer, - scheduler = scheduler, - criterion = criterion, - earlyStopping = earlyStopping, - metric = metric, - seed = seed) + estimatorSettings = setEstimator() ) { if (!dir.exists(modelPath)) { stop(paste0("supplied modelPath does not exist, you supplied: modelPath = ", modelPath)) } + + # TODO check if it's a valid path to a plpModel + param <- list() param[[1]] <- list(modelPath = modelPath) - attr(param, "settings")$modelType <- "FineTuner" + estimatorSettings$modelType <- "Finetuner" + attr(param, "settings")$modelType <- estimatorSettings$modelType results <- list( fitFunction = "fitEstimator", param = param, estimatorSettings = estimatorSettings, - modelType = "Finetuner", saveType = "file", modelParamNames = c("modelPath") ) diff --git a/R/Transformer.R b/R/Transformer.R index cd4a968..f8d4212 100644 --- a/R/Transformer.R +++ b/R/Transformer.R @@ -180,12 +180,12 @@ setTransformer <- function(numBlocks = 3, {param <- param[sample(length(param), randomSample)]})) } - attr(param, "settings")$modelType <- "Transformer" + estimatorSettings$modelType <- "Transformer" + attr(param, "settings")$modelType <- estimatorSettings$modelType results <- list( fitFunction = "fitEstimator", param = param, estimatorSettings = estimatorSettings, - modelType = "Transformer", saveType = "file", modelParamNames = c( "numBlocks", "dimToken", "dimOut", "numHeads", diff --git a/inst/python/Dataset.py b/inst/python/Dataset.py index 330eb3c..45b7f81 100644 --- a/inst/python/Dataset.py +++ b/inst/python/Dataset.py @@ -30,6 +30,7 @@ def __init__(self, data, labels=None, numerical_features=None): .n_unique() .filter(pl.col("covariateValue") > 1) .select("columnId") + .sort("columnId") .collect()["columnId"] ) else: @@ -69,13 +70,19 @@ def __init__(self, data, labels=None, numerical_features=None): if pl.count(self.numerical_features) == 0: self.num = None else: - map_numerical = dict(zip(self.numerical_features.sort().to_list(), - list(range(len(self.numerical_features))))) + map_numerical = dict( + zip( + self.numerical_features.sort().to_list(), + list(range(len(self.numerical_features))), + ) + ) numerical_data = ( data.filter(pl.col("columnId").is_in(self.numerical_features)) - .with_columns(pl.col("columnId").replace(map_numerical), - pl.col("rowId") - 1) + .sort("columnId") + .with_columns( + pl.col("columnId").replace(map_numerical), pl.col("rowId") - 1 + ) .select( pl.col("rowId"), pl.col("columnId"), diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py index 983bde6..d707a3d 100644 --- a/inst/python/Estimator.py +++ b/inst/python/Estimator.py @@ -3,7 +3,6 @@ import torch from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler -import torch.nn.functional as F from tqdm import tqdm @@ -20,12 +19,25 @@ def __init__(self, model, model_parameters, estimator_settings): self.device = estimator_settings["device"] torch.manual_seed(seed=self.seed) - self.model = model(**model_parameters) + if estimator_settings["finetune"]: + path = estimator_settings["finetune_model_path"] + fitted_estimator = torch.load(path, map_location="cpu") + fitted_parameters = fitted_estimator["model_parameters"] + self.model = model(**fitted_parameters) + 
self.model.load_state_dict(fitted_estimator["model_state_dict"]) + for param in self.model.parameters(): + param.requires_grad = False + self.model.reset_head() + else: + self.model = model(**model_parameters) self.model_parameters = model_parameters self.estimator_settings = estimator_settings self.epochs = int(estimator_settings.get("epochs", 5)) - self.learning_rate = estimator_settings.get("learning_rate", 3e-4) + if estimator_settings["find_l_r"]: + self.learning_rate = 3e-4 + else: + self.learning_rate = estimator_settings.get("learning_rate", 3e-4) self.weight_decay = estimator_settings.get("weight_decay", 1e-5) self.batch_size = int(estimator_settings.get("batch_size", 1024)) self.prefix = estimator_settings.get("prefix", self.model.name) diff --git a/inst/python/LrFinder.py b/inst/python/LrFinder.py index e7141cd..7365995 100644 --- a/inst/python/LrFinder.py +++ b/inst/python/LrFinder.py @@ -19,7 +19,7 @@ def get_lr(self): class LrFinder: - def __init__(self, model, model_parameters, estimator_settings, lr_settings): + def __init__(self, estimator, lr_settings=None): if lr_settings is None: lr_settings = {} min_lr = lr_settings.get("min_lr", 1e-7) @@ -27,30 +27,20 @@ def __init__(self, model, model_parameters, estimator_settings, lr_settings): num_lr = lr_settings.get("num_lr", 100) smooth = lr_settings.get("smooth", 0.05) divergence_threshold = lr_settings.get("divergence_threshold", 4) - torch.manual_seed(seed=estimator_settings["seed"]) - self.seed = estimator_settings["seed"] - self.model = model(**model_parameters) - if callable(estimator_settings["device"]): - self.device = estimator_settings["device"]() - else: - self.device = estimator_settings["device"] - self.model.to(device=self.device) + torch.manual_seed(seed=estimator.seed) + self.seed = estimator.seed + self.min_lr = min_lr self.max_lr = max_lr self.num_lr = num_lr self.smooth = smooth self.divergence_threshold = divergence_threshold - self.optimizer = estimator_settings["optimizer"]( - params=self.model.parameters(), lr=self.min_lr - ) - - self.scheduler = ExponentialSchedulerPerBatch( - self.optimizer, self.max_lr, self.num_lr + estimator.scheduler = ExponentialSchedulerPerBatch( + estimator.optimizer, self.max_lr, self.num_lr ) - self.criterion = estimator_settings["criterion"]() - self.batch_size = int(estimator_settings["batch_size"]) + self.estimator = estimator self.losses = None self.loss_index = None @@ -60,24 +50,24 @@ def get_lr(self, dataset): losses = torch.empty(size=(self.num_lr,), dtype=torch.float) lrs = torch.empty(size=(self.num_lr,), dtype=torch.float) for i in tqdm(range(self.num_lr)): - self.optimizer.zero_grad() - random_batch = random.sample(batch_index, self.batch_size) + self.estimator.optimizer.zero_grad() + random_batch = random.sample(batch_index, self.estimator.batch_size) batch = dataset[random_batch] - batch = batch_to_device(batch, self.device) + batch = batch_to_device(batch, self.estimator.device) - out = self.model(batch[0]) - loss = self.criterion(out, batch[1]) + out = self.estimator.model(batch[0]) + loss = self.estimator.criterion(out, batch[1]) if self.smooth is not None and i != 0: losses[i] = ( self.smooth * loss.item() + (1 - self.smooth) * losses[i - 1] ) else: losses[i] = loss.item() - lrs[i] = self.optimizer.param_groups[0]["lr"] + lrs[i] = self.estimator.optimizer.param_groups[0]["lr"] loss.backward() - self.optimizer.step() - self.scheduler.step() + self.estimator.optimizer.step() + self.estimator.scheduler.step() if i == 0: best_loss = losses[i] diff --git 
a/inst/python/MLP.py b/inst/python/MLP.py index cd049a6..7e91f36 100644 --- a/inst/python/MLP.py +++ b/inst/python/MLP.py @@ -14,7 +14,7 @@ def __init__( activation=nn.ReLU, normalization=nn.BatchNorm1d, dropout=None, - d_out: int = 1, + dim_out: int = 1, ): super(MLP, self).__init__() self.name = "MLP" @@ -23,7 +23,7 @@ def __init__( size_embedding = int(size_embedding) size_hidden = int(size_hidden) num_layers = int(num_layers) - d_out = int(d_out) + dim_out = int(dim_out) self.embedding = nn.EmbeddingBag( cat_features + 1, size_embedding, padding_idx=0 @@ -44,7 +44,9 @@ def __init__( for _ in range(num_layers) ) self.last_norm = normalization(size_hidden) - self.head = nn.Linear(size_hidden, d_out) + self.head = nn.Linear(size_hidden, dim_out) + self.size_hidden = size_hidden + self.dim_out = dim_out self.last_act = activation() @@ -65,6 +67,9 @@ def forward(self, input): x = x.squeeze(-1) return x + def reset_head(self): + self.head = nn.Linear(self.size_hidden, self.dim_out) + class MLPLayer(nn.Module): def __init__( diff --git a/inst/python/ResNet.py b/inst/python/ResNet.py index 9df35d8..453e584 100644 --- a/inst/python/ResNet.py +++ b/inst/python/ResNet.py @@ -58,6 +58,8 @@ def __init__( self.last_norm = normalization(size_hidden) self.head = nn.Linear(size_hidden, dim_out) + self.size_hidden = size_hidden + self.dim_out = dim_out self.last_act = activation() @@ -87,6 +89,9 @@ def forward(self, x): x = x.squeeze(-1) return x + def reset_head(self): + self.head = nn.Linear(self.size_hidden, self.dim_out) + class ResLayer(nn.Module): def __init__( diff --git a/inst/python/Transformer.py b/inst/python/Transformer.py index ddbe929..58625a9 100644 --- a/inst/python/Transformer.py +++ b/inst/python/Transformer.py @@ -88,6 +88,10 @@ def __init__( normalization=head_norm, dim_out=dim_out, ) + self.dim_token = dim_token + self.head_activation = head_activation + self.head_normalization = head_norm + self.dim_out = dim_out def forward(self, x): mask = torch.where(x["cat"] == 0, True, False) @@ -141,6 +145,15 @@ def forward(self, x): x = self.head(x)[:, 0] return x + def reset_head(self): + self.head = Head( + self.dim_token, + bias=True, + activation=self.head_activation, + normalization=self.head_normalization, + dim_out=self.dim_out + ) + @staticmethod def start_residual(layer, stage, x): norm = f"{stage}_norm" diff --git a/man/setEstimator.Rd b/man/setEstimator.Rd index b8424a3..2abb3ae 100644 --- a/man/setEstimator.Rd +++ b/man/setEstimator.Rd @@ -16,7 +16,8 @@ setEstimator( criterion = torch$nn$BCEWithLogitsLoss, earlyStopping = list(useEarlyStopping = TRUE, params = list(patience = 4)), metric = "auc", - seed = NULL + seed = NULL, + modelType = NULL ) } \arguments{ diff --git a/man/setFinetuner.Rd b/man/setFinetuner.Rd index d6ff412..f439350 100644 --- a/man/setFinetuner.Rd +++ b/man/setFinetuner.Rd @@ -4,13 +4,7 @@ \alias{setFinetuner} \title{setFinetuner} \usage{ -setFinetuner( - modelPath, - estimatorSettings = setEstimator(learningRate = learningRate, weightDecay = - weightDecay, batchSize = batchSize, epochs = epochs, device = device, optimizer = - optimizer, scheduler = scheduler, criterion = criterion, earlyStopping = - earlyStopping, metric = metric, seed = seed) -) +setFinetuner(modelPath, estimatorSettings = setEstimator()) } \arguments{ \item{modelPath}{path to existing plpModel directory} diff --git a/man/snakeCaseToCamelCase.Rd b/man/snakeCaseToCamelCase.Rd new file mode 100644 index 0000000..1be21a8 --- /dev/null +++ b/man/snakeCaseToCamelCase.Rd @@ -0,0 +1,17 @@ +% 
Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/HelperFunctions.R
+\name{snakeCaseToCamelCase}
+\alias{snakeCaseToCamelCase}
+\title{Convert a snake case string to camel case}
+\usage{
+snakeCaseToCamelCase(string)
+}
+\arguments{
+\item{string}{The string to be converted}
+}
+\value{
+A string
+}
+\description{
+Convert a snake case string to camel case
+}
diff --git a/man/snakeCaseToCamelCaseNames.Rd b/man/snakeCaseToCamelCaseNames.Rd
new file mode 100644
index 0000000..111dab1
--- /dev/null
+++ b/man/snakeCaseToCamelCaseNames.Rd
@@ -0,0 +1,17 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/HelperFunctions.R
+\name{snakeCaseToCamelCaseNames}
+\alias{snakeCaseToCamelCaseNames}
+\title{Convert the names of an object from snake case to camel case}
+\usage{
+snakeCaseToCamelCaseNames(object)
+}
+\arguments{
+\item{object}{The object of which the names should be converted}
+}
+\value{
+The same object, but with converted names.
+}
+\description{
+Convert the names of an object from snake case to camel case
+}
diff --git a/tests/testthat/test-Estimator.R b/tests/testthat/test-Estimator.R
index ba82835..98f335f 100644
--- a/tests/testthat/test-Estimator.R
+++ b/tests/testthat/test-Estimator.R
@@ -1,5 +1,5 @@
 catFeatures <- smallDataset$dataset$get_cat_features()$max()
-numFeatures <- smallDataset$dataset$get_numerical_features()$max()
+numFeatures <- smallDataset$dataset$get_numerical_features()$len()
 
 modelParameters <- list(
   cat_features = catFeatures,
@@ -9,7 +9,7 @@ modelParameters <- list(
   num_layers = 2,
   hidden_factor = 2
 )
-
+modelType = "ResNet"
 estimatorSettings <-
   setEstimator(learningRate = 3e-4,
                weightDecay = 0.0,
@@ -23,18 +23,19 @@ estimatorSettings <-
                scheduler = list(fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
                                 params = list(patience = 1)),
-               earlyStopping = NULL)
+               earlyStopping = NULL,
+               modelType = modelType)
 
-modelType <- "ResNet"
-estimator <- createEstimator(modelType = modelType,
-                             modelParameters = modelParameters,
+estimator <- createEstimator(modelParameters = modelParameters,
                              estimatorSettings = estimatorSettings)
 
 test_that("Estimator initialization works", {
   # count parameters in both instances
   path <- system.file("python", package = "DeepPatientLevelPrediction")
-  resNet <- reticulate::import_from_path(modelType, path = path)[[modelType]]
+  resNet <-
+    reticulate::import_from_path(estimatorSettings$modelType,
+                                 path = path)[[estimatorSettings$modelType]]
 
   testthat::expect_equal(
     sum(reticulate::iterate(estimator$model$parameters(),
@@ -113,10 +114,9 @@ test_that("estimator fitting works", {
                   batchSize = 128,
                   epochs = 5,
                   device = "cpu",
-                  metric = "loss")
-
-  estimator <- createEstimator(modelType = modelType,
-                               modelParameters = modelParameters,
+                  metric = "loss",
+                  modelType = modelType)
+  estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)
 
   sink(nullfile())
@@ -216,10 +216,9 @@ test_that("Estimator without earlyStopping works", {
                   batchSize = 128,
                   epochs = 1,
                   device = "cpu",
-                  earlyStopping = NULL)
-
-  estimator2 <- createEstimator(modelType = modelType,
-                                modelParameters = modelParameters,
+                  earlyStopping = NULL,
+                  modelType = modelType)
+  estimator2 <- createEstimator(modelParameters = modelParameters,
                                 estimatorSettings = estimatorSettings)
   sink(nullfile())
   estimator2$fit(smallDataset, smallDataset)
@@ -241,10 +240,10 @@ test_that("Early stopper can use loss and stops early", {
                                        params = list(mode = c("min"),
                                                      patience = 1)),
                   metric = "loss",
-                  seed =
42) + seed = 42, + modelType = modelType) - estimator <- createEstimator(modelType = modelType, - modelParameters = modelParameters, + estimator <- createEstimator(modelParameters = modelParameters, estimatorSettings = estimatorSettings) sink(nullfile()) estimator$fit(smallDataset, smallDataset) @@ -270,9 +269,9 @@ test_that("Custom metric in estimator works", { epochs = 1, metric = list(fun = metricFun, name = "auprc", - mode = "max")) - estimator <- createEstimator(modelType = modelType, - modelParameters = modelParameters, + mode = "max"), + modelType = modelType) + estimator <- createEstimator(modelParameters = modelParameters, estimatorSettings = estimatorSettings) expect_true(is.function(estimator$metric$fun)) expect_true(is.character(estimator$metric$name)) @@ -334,13 +333,13 @@ test_that("device as a function argument works", { } estimatorSettings <- setEstimator(device = getDevice, - learningRate = 3e-4) + learningRate = 3e-4, + modelType = modelType) model <- setDefaultResNet(estimatorSettings = estimatorSettings) model$param[[1]]$catFeatures <- 10 - estimator <- createEstimator(modelType = modelType, - modelParameters = model$param[[1]], + estimator <- createEstimator(modelParameters = model$param[[1]], estimatorSettings = estimatorSettings) expect_equal(estimator$device, "cpu") @@ -348,13 +347,13 @@ test_that("device as a function argument works", { Sys.setenv("testDeepPLPDevice" = "meta") estimatorSettings <- setEstimator(device = getDevice, - learningRate = 3e-4) + learningRate = 3e-4, + modelType = modelType) model <- setDefaultResNet(estimatorSettings = estimatorSettings) model$param[[1]]$catFeatures <- 10 - estimator <- createEstimator(modelType = modelType, - modelParameters = model$param[[1]], + estimator <- createEstimator(modelParameters = model$param[[1]], estimatorSettings = estimatorSettings) expect_equal(estimator$device, "meta") @@ -385,7 +384,8 @@ test_that("evaluation works on predictDeepEstimator output", { cohort = trainData$Test$labels) prediction$evaluationType <- 'Validation' - evaluation <- evaluatePlp(prediction, "evaluationType") + evaluation <- + PatientLevelPrediction::evaluatePlp(prediction, "evaluationType") expect_length(evaluation, 5) expect_s3_class(evaluation, "plpEvaluation") diff --git a/tests/testthat/test-LRFinder.R b/tests/testthat/test-LRFinder.R index a42416b..8509586 100644 --- a/tests/testthat/test-LRFinder.R +++ b/tests/testthat/test-LRFinder.R @@ -30,19 +30,20 @@ test_that("LR scheduler that changes per batch works", { test_that("LR finder works", { + estimatorSettings <- setEstimator(batchSize = 32L, + seed = 42, + modelType = "ResNet") lrFinder <- - createLRFinder(modelType = "ResNet", - modelParameters = - list(cat_features = - dataset$get_cat_features()$max(), - num_features = - dataset$get_numerical_features()$max(), - size_embedding = 32L, - size_hidden = 64L, - num_layers = 1L, - hidden_factor = 1L), - estimatorSettings = setEstimator(batchSize = 32L, - seed = 42), + createLRFinder(modelParameters = + list(cat_features = + dataset$get_cat_features()$max(), + num_features = + dataset$get_numerical_features()$len(), + size_embedding = 32L, + size_hidden = 64L, + num_layers = 1L, + hidden_factor = 1L), + estimatorSettings = estimatorSettings, lrSettings = list(minLr = 3e-4, maxLr = 10.0, numLr = 20L, @@ -61,17 +62,17 @@ test_that("LR finder works with device specified by a function", { dev } lrFinder <- createLRFinder( - model = "ResNet", modelParameters = list(cat_features = dataset$get_cat_features()$max(), - num_features = 
dataset$get_numerical_features()$max(), + num_features = dataset$get_numerical_features()$len(), size_embedding = 8L, size_hidden = 16L, num_layers = 1L, hidden_factor = 1L), estimatorSettings = setEstimator(batchSize = 32L, seed = 42, - device = deviceFun), + device = deviceFun, + modelType = "ResNet"), lrSettings = list(minLr = 3e-4, maxLr = 10.0, numLr = 20L,
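
Not part of the patch: a minimal usage sketch of the reworked transfer-learning API, put together only from the signatures changed above (setEstimator gaining a modelType argument, the set* functions storing modelType inside estimatorSettings, and setFinetuner taking a modelPath plus estimatorSettings). The pre-trained model directory and the surrounding training workflow are assumptions for illustration, not something this diff defines.

# Sketch only: assumes `pretrainedModelPath` points at a previously fitted and
# saved plpModel directory; it is not defined anywhere in this patch.
library(DeepPatientLevelPrediction)

# 1. Fit a base model as usual. With this patch, setResNet() records
#    modelType = "ResNet" inside estimatorSettings instead of returning a
#    separate modelType field.
baseModelSettings <- setResNet(
  estimatorSettings = setEstimator(learningRate = "auto",
                                   device = "cpu",
                                   epochs = 10)
)

# 2. Fine-tune the saved model on new data. The estimator reloads the old
#    weights, freezes them, and re-initialises the prediction head via
#    reset_head(), so only the new head is trained.
finetuneSettings <- setFinetuner(
  modelPath = pretrainedModelPath,
  estimatorSettings = setEstimator(learningRate = 3e-4,
                                   epochs = 5,
                                   device = "cpu")
)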