fix model_type

egillax committed Dec 14, 2023
1 parent b85799a · commit afaa815
Showing 11 changed files with 73 additions and 76 deletions.
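
The pattern is the same in every file: modelType is no longer an argument to setEstimator() and no longer lives on estimatorSettings; instead each set* function returns it as a top-level field of the modelSettings list and mirrors it into the param attribute. A minimal sketch of the new shape (field names are taken from the diffs below; param and estimatorSettings stand for objects built earlier inside each set* function):

    results <- list(
      fitFunction = "fitEstimator",
      param = param,                          # hyperparameter grid
      estimatorSettings = estimatorSettings,  # no longer carries modelType
      saveType = "file",
      modelParamNames = c("numLayers", "sizeHidden"),
      modelType = "ResNet"                    # new top-level field
    )
    attr(results$param, "settings")$modelType <- results$modelType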
44 changes: 22 additions & 22 deletions R/Estimator.R
@@ -54,8 +54,7 @@ setEstimator <- function(learningRate = "auto",
                          earlyStopping = list(useEarlyStopping = TRUE,
                                               params = list(patience = 4)),
                          metric = "auc",
-                         seed = NULL,
-                         modelType = NULL
+                         seed = NULL
                          ) {

   checkIsClass(learningRate, c("numeric", "character"))
@@ -90,8 +89,7 @@ setEstimator <- function(learningRate = "auto",
                               earlyStopping = earlyStopping,
                               findLR = findLR,
                               metric = metric,
-                              seed = seed[1],
-                              modelType = modelType)
+                              seed = seed[1])

   optimizer <- rlang::enquo(optimizer)
   estimatorSettings$optimizer <- function() rlang::eval_tidy(optimizer)
@@ -144,11 +142,11 @@ fitEstimator <- function(trainData,
   if (!is.null(trainData$folds)) {
     trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
   }
-  if (modelSettings$estimatorSettings$modelType == "Finetuner") {

+  if (modelSettings$modelType == "Finetuner") {
     # make sure to use same mapping from covariateIds to columns if finetuning
     path <- modelSettings$param[[1]]$modelPath
-    oldCovImportance <- utils::read.csv(file.path(path,
+    oldCovImportance <- utils::read.csv(file.path(path,
                                         "covariateImportance.csv"))
     mapping <- oldCovImportance %>% dplyr::select("columnId", "covariateId")
     numericalIndex <- which(oldCovImportance %>% dplyr::pull("isNumeric"))

(Here and below, a - / + pair with identical text differs only in whitespace.)
@@ -229,7 +227,7 @@ fitEstimator <- function(trainData,
       attrition = attr(trainData, "metaData")$attrition,
       trainingTime = paste(as.character(abs(comp)), attr(comp, "units")),
       trainingDate = Sys.Date(),
-      modelName = modelSettings$estimatorSettings$modelType,
+      modelName = modelSettings$modelType,
       finalModelParameters = cvResult$finalParam,
       hyperParamSearch = hyperSummary
     ),
@@ -278,9 +276,9 @@ predictDeepEstimator <- function(plpModel,
                                  map_location = "cpu")
     estimator <-
       createEstimator(modelParameters =
-                        snakeCaseToCamelCaseNames(model$model_parameters),
+                        snakeCaseToCamelCaseNames(model$model_parameters),
                       estimatorSettings =
-                        snakeCaseToCamelCaseNames(model$estimator_settings))
+                        snakeCaseToCamelCaseNames(model$estimator_settings))
     estimator$model$load_state_dict(model$model_state_dict)
     prediction$value <- estimator$predict_proba(data)
   } else {
@@ -311,7 +309,7 @@ gridCvDeep <- function(mappedData,
                        modelLocation,
                        analysisPath) {
   ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
-                                 modelSettings$estimatorSettings$modelType,
+                                 modelSettings$modelType,
                                  " model"))

   ###########################################################################
@@ -342,7 +340,9 @@ gridCvDeep <- function(mappedData,
     ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]),
                                   paramSearch[[gridId]], collapse = " | "))
     currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]
-    attr(currentModelParams, "metaData")$names <- modelSettings$modelParamNames
+    attr(currentModelParams, "metaData")$names <-
+      modelSettings$modelParamNames
+    currentModelParams$modelType <- modelSettings$modelType
     currentEstimatorSettings <-
       fillEstimatorSettings(modelSettings$estimatorSettings,
                             fitParams,
@@ -420,7 +420,7 @@ gridCvDeep <- function(mappedData,

   modelParams$catFeatures <- dataset$get_cat_features()$max()
   modelParams$numFeatures <- dataset$get_numerical_features()$len()
-
+  modelParams$modelType <- modelSettings$modelType

   estimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings,
                                              fitParams,
@@ -495,19 +495,19 @@ evalEstimatorSettings <- function(estimatorSettings) {
 createEstimator <- function(modelParameters,
                             estimatorSettings) {
   path <- system.file("python", package = "DeepPatientLevelPrediction")
-  if (estimatorSettings$modelType == "Finetuner") {

+  if (modelParameters$modelType == "Finetuner") {
     estimatorSettings$finetune <- TRUE
     plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
-    estimatorSettings$finetuneModelPath <-
+    estimatorSettings$finetuneModelPath <-
       file.path(normalizePath(plpModel$model), "DeepEstimatorModel.pt")
-    estimatorSettings$modelType <-
-      plpModel$modelDesign$modelSettings$estimatorSettings$modelType
-  }
+    modelParameters$modelType <-
+      plpModel$modelDesign$modelSettings$modelType
+  }

   model <-
-    reticulate::import_from_path(estimatorSettings$modelType,
-                                 path = path)[[estimatorSettings$modelType]]
+    reticulate::import_from_path(modelParameters$modelType,
+                                 path = path)[[modelParameters$modelType]]
   estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator

   modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
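
After this hunk, createEstimator() resolves the Python class from modelParameters$modelType rather than estimatorSettings$modelType, so callers must now put modelType into the model parameters. A hypothetical call under the new contract (values are placeholders, not package defaults):

    estimator <- createEstimator(
      modelParameters = list(catFeatures = 100L,
                             numFeatures = 10L,
                             sizeEmbedding = 16L,
                             sizeHidden = 16L,
                             numLayers = 2L,
                             hiddenFactor = 2L,
                             modelType = "ResNet"),  # now required here
      estimatorSettings = setEstimator(learningRate = 3e-4,
                                       device = "cpu"))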
7 changes: 4 additions & 3 deletions R/MLP.R
@@ -93,8 +93,6 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
                    {param <- param[sample(length(param),
                                           randomSample)]}))
   }
-  estimatorSettings$modelType <- "MLP"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
@@ -103,8 +101,11 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
     modelParamNames = c(
       "numLayers", "sizeHidden",
       "dropout", "sizeEmbedding"
-    )
+    ),
+    modelType = "MLP"
   )
+  attr(results$param, "settings")$modelType <- results$modelType
+

   class(results) <- "modelSettings"

8 changes: 4 additions & 4 deletions R/ResNet.R
@@ -137,17 +137,17 @@ setResNet <- function(numLayers = c(1:8),
                    {param <- param[sample(length(param),
                                           randomSample)]}))
   }
-  estimatorSettings$modelType <- "ResNet"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
     estimatorSettings = estimatorSettings,
     saveType = "file",
     modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor",
-                        "residualDropout", "hiddenDropout", "sizeEmbedding")
+                        "residualDropout", "hiddenDropout", "sizeEmbedding"),
+    modelType = "ResNet"
   )
+
+  attr(results$param, "settings")$modelType <- results$modelType

   class(results) <- "modelSettings"

   return(results)
10 changes: 5 additions & 5 deletions R/TransferLearning.R
@@ -38,17 +38,17 @@ setFinetuner <- function(modelPath,

   param <- list()
   param[[1]] <- list(modelPath = modelPath)

-  estimatorSettings$modelType <- "Finetuner"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType

   results <- list(
     fitFunction = "fitEstimator",
     param = param,
     estimatorSettings = estimatorSettings,
     saveType = "file",
-    modelParamNames = c("modelPath")
+    modelParamNames = c("modelPath"),
+    modelType = "Finetuner"
   )

+  attr(results$param, "settings")$modelType <- results$modelType

   class(results) <- "modelSettings"

   return(results)
7 changes: 3 additions & 4 deletions R/Transformer.R
@@ -180,8 +180,6 @@ setTransformer <- function(numBlocks = 3,
                    {param <- param[sample(length(param),
                                           randomSample)]}))
   }
-  estimatorSettings$modelType <- "Transformer"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
@@ -190,9 +188,10 @@ setTransformer <- function(numBlocks = 3,
     modelParamNames = c(
       "numBlocks", "dimToken", "dimOut", "numHeads",
       "attDropout", "ffnDropout", "resDropout", "dimHidden"
-    )
+    ),
+    modelType = "Transformer"
   )
-
+  attr(results$param, "settings")$modelType <- results$modelType
   class(results) <- "modelSettings"
   return(results)
 }
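
All four setters (setMultiLayerPerceptron, setResNet, setFinetuner, setTransformer) now follow the same pattern, so downstream code can rely on modelSettings$modelType regardless of model. A hedged usage sketch (argument values are illustrative; the setters supply their own defaults):

    modelSettings <- setResNet(numLayers = 2,
                               sizeHidden = 64,
                               estimatorSettings = setEstimator(learningRate = 3e-4))
    modelSettings$modelType                          # "ResNet"
    attr(modelSettings$param, "settings")$modelType  # "ResNet"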
3 changes: 2 additions & 1 deletion inst/python/MLP.py
@@ -15,9 +15,10 @@ def __init__(
         normalization=nn.BatchNorm1d,
         dropout=None,
         dim_out: int = 1,
+        model_type="MLP"
     ):
         super(MLP, self).__init__()
-        self.name = "MLP"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         size_embedding = int(size_embedding)
3 changes: 2 additions & 1 deletion inst/python/ResNet.py
@@ -19,9 +19,10 @@ def __init__(
         residual_dropout=0,
         dim_out: int = 1,
         concat_num=True,
+        model_type="ResNet"
     ):
         super(ResNet, self).__init__()
-        self.name = "ResNet"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         size_embedding = int(size_embedding)
3 changes: 2 additions & 1 deletion inst/python/Transformer.py
@@ -35,9 +35,10 @@ def __init__(
         ffn_norm=nn.LayerNorm,
         head_norm=nn.LayerNorm,
         att_norm=nn.LayerNorm,
+        model_type="Transformer"
     ):
         super(Transformer, self).__init__()
-        self.name = "Transformer"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         num_blocks = int(num_blocks)
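
On the Python side, each model class now receives its name as a model_type constructor argument instead of hard-coding it. The R layer converts camelCase names to snake_case before splatting the list into the constructor (see the createEstimator hunk above), so the R-side modelType is what ultimately lands in self.name. Roughly:

    # Sketch: what happens to the parameter list on its way into Python.
    modelParameters <- list(numLayers = 2, modelType = "ResNet")
    modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
    # -> list(num_layers = 2, model_type = "ResNet"); model_type = "ResNet"
    #    reaches ResNet.__init__ and becomes self.name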
3 changes: 1 addition & 2 deletions man/setEstimator.Rd

Some generated files are not rendered by default.

49 changes: 22 additions & 27 deletions tests/testthat/test-Estimator.R
@@ -2,14 +2,15 @@ catFeatures <- smallDataset$dataset$get_cat_features()$max()
 numFeatures <- smallDataset$dataset$get_numerical_features()$len()

 modelParameters <- list(
-  cat_features = catFeatures,
-  num_features = numFeatures,
-  size_embedding = 16,
-  size_hidden = 16,
-  num_layers = 2,
-  hidden_factor = 2
+  catFeatures = catFeatures,
+  numFeatures = numFeatures,
+  sizeEmbedding = 16,
+  sizeHidden = 16,
+  numLayers = 2,
+  hiddenFactor = 2,
+  modelType = "ResNet"
 )
-modelType = "ResNet"

 estimatorSettings <-
   setEstimator(learningRate = 3e-4,
                weightDecay = 0.0,
@@ -23,8 +24,7 @@ estimatorSettings <-
                scheduler =
                  list(fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
                       params = list(patience = 1)),
-               earlyStopping = NULL,
-               modelType = modelType)
+               earlyStopping = NULL)

 estimator <- createEstimator(modelParameters = modelParameters,
                              estimatorSettings = estimatorSettings)
@@ -34,19 +34,19 @@ test_that("Estimator initialization works", {
   # count parameters in both instances
   path <- system.file("python", package = "DeepPatientLevelPrediction")
   resNet <-
-    reticulate::import_from_path(estimatorSettings$modelType,
-                                 path = path)[[estimatorSettings$modelType]]
+    reticulate::import_from_path(modelParameters$modelType,
+                                 path = path)[[modelParameters$modelType]]

-  testthat::expect_equal(
+  testthat::expect_equal(
     sum(reticulate::iterate(estimator$model$parameters(),
                             function(x) x$numel())),
-    sum(reticulate::iterate(do.call(resNet, modelParameters)$parameters(),
+    sum(reticulate::iterate(do.call(resNet, camelCaseToSnakeCaseNames(modelParameters))$parameters(),
                             function(x) x$numel()))
   )

   testthat::expect_equal(
     estimator$model_parameters,
-    modelParameters
+    camelCaseToSnakeCaseNames(modelParameters)
   )
 })

@@ -114,8 +114,7 @@ test_that("estimator fitting works", {
                batchSize = 128,
                epochs = 5,
                device = "cpu",
-               metric = "loss",
-               modelType = modelType)
+               metric = "loss")
   estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)

@@ -216,8 +215,7 @@ test_that("Estimator without earlyStopping works", {
                batchSize = 128,
                epochs = 1,
                device = "cpu",
-               earlyStopping = NULL,
-               modelType = modelType)
+               earlyStopping = NULL)
   estimator2 <- createEstimator(modelParameters = modelParameters,
                                 estimatorSettings = estimatorSettings)
   sink(nullfile())
Expand All @@ -240,8 +238,7 @@ test_that("Early stopper can use loss and stops early", {
params = list(mode = c("min"),
patience = 1)),
metric = "loss",
seed = 42,
modelType = modelType)
seed = 42)

estimator <- createEstimator(modelParameters = modelParameters,
estimatorSettings = estimatorSettings)
@@ -269,8 +266,7 @@ test_that("Custom metric in estimator works", {
                epochs = 1,
                metric = list(fun = metricFun,
                              name = "auprc",
-                             mode = "max"),
-               modelType = modelType)
+                             mode = "max"))
   estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)
   expect_true(is.function(estimator$metric$fun))
@@ -333,12 +329,11 @@ test_that("device as a function argument works", {
   }

   estimatorSettings <- setEstimator(device = getDevice,
-                                    learningRate = 3e-4,
-                                    modelType = modelType)
+                                    learningRate = 3e-4)

   model <- setDefaultResNet(estimatorSettings = estimatorSettings)
   model$param[[1]]$catFeatures <- 10
-
+  model$param[[1]]$modelType <- "ResNet"
   estimator <- createEstimator(modelParameters = model$param[[1]],
                                estimatorSettings = estimatorSettings)

@@ -347,11 +342,11 @@ test_that("device as a function argument works", {
   Sys.setenv("testDeepPLPDevice" = "meta")

   estimatorSettings <- setEstimator(device = getDevice,
-                                    learningRate = 3e-4,
-                                    modelType = modelType)
+                                    learningRate = 3e-4)

   model <- setDefaultResNet(estimatorSettings = estimatorSettings)
   model$param[[1]]$catFeatures <- 10
+  model$param[[1]]$modelType <- "ResNet"

   estimator <- createEstimator(modelParameters = model$param[[1]],
                                estimatorSettings = estimatorSettings)
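
The tests now build modelParameters in camelCase and convert at the Python boundary with camelCaseToSnakeCaseNames(). For intuition, a converter of this kind can be approximated with a regex; this helper is a hypothetical stand-in, not the package's implementation:

    # Insert "_" between a lower-case/digit character and an upper-case one,
    # then lower-case everything.
    camelToSnake <- function(x) tolower(gsub("([a-z0-9])([A-Z])", "\\1_\\2", x))
    camelToSnake(c("sizeEmbedding", "modelType"))
    # -> "size_embedding" "model_type"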