
Commit

Finish refactor and fix test
egillax committed Oct 22, 2024
1 parent 560c6b0 commit d26adbc
Showing 18 changed files with 180 additions and 181 deletions.
23 changes: 9 additions & 14 deletions R/CustomEmbeddingModel.R
@@ -23,7 +23,6 @@
 #' should be a pytorch file including a dictionary with two two fields:
 #' `concept_ids`: a pytorch long tensor with the concept ids and `embeddings`:
 #' a pytorch float tensor with the embeddings
-#' @param estimatorSettings created with `setEstimator`
 #' @param modelSettings for the model to use, needs to have an embedding layer
 #' with a name `embedding` which will be replaced by the custom embeddings
 #'
@@ -32,15 +31,6 @@
 #' @export
 setCustomEmbeddingModel <- function(
     embeddingFilePath,
-    estimatorSettings =
-      setEstimator(
-        learningRate = "auto",
-        weightDecay = 1e-4,
-        batchSize = 256,
-        epochs = 2,
-        seed = NULL,
-        device = "cpu"
-      ),
     modelSettings = setTransformer(
       numBlocks = 3,
       dimToken = 16,
@@ -50,7 +40,12 @@ setCustomEmbeddingModel <- function(
       ffnDropout = 0.1,
       resDropout = 0.0,
       dimHidden = 32,
-      estimatorSettings = estimatorSettings,
+      estimatorSettings = setEstimator(learningRate = "auto",
+                                       weightDecay = 1e-4,
+                                       batchSize = 256,
+                                       epochs = 2,
+                                       seed = NULL,
+                                       device = "cpu"),
       hyperParamSearch = "random",
       randomSample = 1
     )
@@ -61,12 +56,12 @@ setCustomEmbeddingModel <- function(


   path <- system.file("python", package = "DeepPatientLevelPrediction")
-  estimatorSettings$initStrategy <-
+  modelSettings$estimatorSettings$initStrategy <-
     reticulate::import_from_path("InitStrategy",
                                  path = path)$CustomEmbeddingInitStrategy()
-  estimatorSettings$embeddingFilePath <- embeddingFilePath
+  modelSettings$estimatorSettings$embeddingFilePath <- embeddingFilePath
   transformerSettings <- modelSettings

-  attr(transformerSettings, "settings")$name <- "CustomEmbeddingTransformer"
+  attr(transformerSettings, "settings")$name <- "CustomEmbeddingModel"
   return(transformerSettings)
 }
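Note on the embedding file format: per the docstring above, embeddingFilePath must point to a saved pytorch file holding a dictionary with a `concept_ids` long tensor and a matching `embeddings` float tensor. A minimal sketch of producing such a file (file name, ids, and dimensions are illustrative):

import torch

custom_embeddings = {
    "concept_ids": torch.tensor([4321, 4322, 4323], dtype=torch.long),  # one id per embedding row
    "embeddings": torch.randn(3, 16, dtype=torch.float),                # 3 concepts x 16 dimensions
}
torch.save(custom_embeddings, "custom_embeddings.pt")  # pass this path as embeddingFilePath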
20 changes: 2 additions & 18 deletions R/Dataset.R
@@ -32,27 +32,11 @@ createDataset <- function(data, labels, plpModel = NULL) {
                     r_to_py(labels$outcomeCount),
                     numericalIndex)
   } else {
-    cat_1_mapping <- plpModel$covariateImportance %>%
-      dplyr::select(covariateId, cat1Idx) %>%
-      dplyr::rename(index = cat1Idx) %>%
-      dplyr::filter(!is.na(index)) %>%
-      as.data.frame() %>%
-      r_to_py()
-
-    cat_2_mapping <- plpModel$covariateImportance %>%
-      dplyr::select(covariateId, cat2Idx) %>%
-      dplyr::rename(index = cat2Idx) %>%
-      dplyr::filter(!is.na(index)) %>%
-      as.data.frame() %>%
-      r_to_py()
-
     numericalFeatures <-
       r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
     data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
-                    numerical_features = numericalFeatures,
-                    in_cat_2_mapping = cat_2_mapping,
-                    in_cat_1_mapping = cat_1_mapping
-    )
+                    numerical_features = numericalFeatures
+    )
   }

   return(data)
34 changes: 15 additions & 19 deletions R/Estimator.R
@@ -218,9 +218,7 @@ fitEstimator <- function(trainData,
       included = incs,
       covariateValue = 0,
       isNumeric = .data$columnId %in% cvResult$numericalIndex
-    ) %>%
-    left_join(cvResult$cat1Mapping %>% rename(cat1Idx = index), by = "covariateId") %>%
-    left_join(cvResult$cat2Mapping %>% rename(cat2Idx = index), by = "covariateId")
+    )

   comp <- start - Sys.time()
   modelSettings$estimatorSettings$initStrategy <- NULL
@@ -330,11 +328,12 @@ predictDeepEstimator <- function(plpModel,
     }
     model$estimator_settings$device <-
       plpModel$modelDesign$modelSettings$estimatorSettings$device
+    modelParameters <- snakeCaseToCamelCaseNames(model$model_parameters)
+    estimatorSettings <- snakeCaseToCamelCaseNames(model$estimator_settings)
+    parameters <- list(modelParameters = modelParameters,
+                       estimatorSettings = estimatorSettings)
     estimator <-
-      createEstimator(modelParameters =
-                        snakeCaseToCamelCaseNames(model$model_parameters),
-                      estimatorSettings =
-                        snakeCaseToCamelCaseNames(model$estimator_settings))
+      createEstimator(parameters = parameters)
     estimator$model$load_state_dict(model$model_state_dict)
     prediction$value <- estimator$predict_proba(data)
   } else {
@@ -433,9 +432,7 @@ gridCvDeep <- function(mappedData,
     dplyr::select(-"index")
   prediction$cohortStartDate <- as.Date(prediction$cohortStartDate,
                                         origin = "1970-01-01")
-  numericalIndex <- dataset$get_numerical_features()
-  cat1Mapping <- as.data.frame(dataset$get_cat_1_mapping())
-  cat2Mapping <- as.data.frame(dataset$get_cat_2_mapping())
+  numericalIndex <- dataset$numerical_features$to_list()

   # save torch code here
   if (!dir.exists(file.path(modelLocation))) {
@@ -448,9 +445,7 @@ gridCvDeep <- function(mappedData,
       prediction = prediction,
       finalParam = finalParam,
       paramGridSearch = paramGridSearch,
-      numericalIndex = numericalIndex$to_list(),
-      cat1Mapping = cat1Mapping,
-      cat2Mapping = cat2Mapping
+      numericalIndex = numericalIndex
     )
   )
 }
@@ -623,7 +618,7 @@ doCrossValidationImpl <- function(dataset,
     testDataset <- torch$utils$data$Subset(dataset,
                                            indices =
                                              as.integer(which(fold == i) - 1))
-    estimator <- createEstimator(parameters)
+    estimator <- createEstimator(currentParameters)
     fit_estimator(estimator, trainDataset, testDataset)

     ParallelLogger::logInfo("Calculating predictions on left out fold set...")
@@ -676,9 +671,7 @@ trainFinalModel <- function(dataset, finalParam, modelSettings, labels) {

  fitParams <- names(finalParam)[grepl("^estimator", names(finalParam))]

-  modelParams$catFeatures <- dataset$get_cat_features()$len()
-  modelParams$cat2Features <- dataset$get_cat_2_features()$len()
-  modelParams$numFeatures <- dataset$get_numerical_features()$len()
+  modelParams$featureInfo <- dataset$get_feature_info()
   modelParams$modelType <- modelSettings$modelType

   estimatorSettings <- fillEstimatorSettings(
@@ -687,8 +680,11 @@ trainFinalModel <- function(dataset, finalParam, modelSettings, labels) {
     finalParam
   )
   estimatorSettings$learningRate <- finalParam$learnSchedule$LRs[[1]]
-  estimator <- createEstimator(modelParameters = modelParams,
-                               estimatorSettings = estimatorSettings)
+  parameters <- list(
+    modelParameters = modelParams,
+    estimatorSettings = estimatorSettings
+  )
+  estimator <- createEstimator(parameters = parameters)
   estimator$fit_whole_training_set(dataset, finalParam$learnSchedule$LRs)

   ParallelLogger::logInfo("Calculating predictions on all train data...")
15 changes: 9 additions & 6 deletions inst/python/CustomEmbeddings.py
@@ -10,21 +10,24 @@ def __init__(self,
                  freeze: bool = True):
         super(CustomEmbeddings, self).__init__()

-        self.custom_embeddings = nn.Embedding.from_pretrained(custom_embedding_weights, freeze=freeze)
-        self.regular_embeddings = nn.Embedding(num_regular_embeddings, embedding_dim)
+        # make sure padding idx refers to all zero embeddings at position 0
+        custom_embedding_weights = torch.cat([torch.zeros(1, custom_embedding_weights.shape[1]), custom_embedding_weights])
+        self.custom_embeddings = nn.Embedding.from_pretrained(custom_embedding_weights, freeze=freeze,
+                                                              padding_idx=0)
+        self.regular_embeddings = nn.Embedding(num_regular_embeddings, embedding_dim, padding_idx=0)

         self.custom_indices = custom_indices

         self.linear_transform = nn.Linear(custom_embedding_weights.shape[1], embedding_dim)

     def forward(self, x):
-        custom_embeddings_mask = torch.isin(x, self.custom_indices)
-        custom_features = x[custom_embeddings_mask]
-        regular_features = x[~custom_embeddings_mask]
+        custom_embeddings_mask = torch.isin(x, self.custom_indices.to(x.device))
+        custom_features = torch.where(custom_embeddings_mask, x, torch.tensor(0))
+        regular_features = torch.where(~custom_embeddings_mask, x, torch.tensor(0))

         custom_embeddings = self.custom_embeddings(custom_features)
         regular_embeddings = self.regular_embeddings(regular_features)

         custom_embeddings = self.linear_transform(custom_embeddings)

-        return torch.cat([custom_embeddings, regular_embeddings], dim=-1)
+        return custom_embeddings + regular_embeddings
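Note on the forward-pass change: boolean indexing (x[mask]) returns a flattened 1-D tensor, so the old selection discarded the batch layout. The rewrite instead routes masked-out positions to index 0, where both tables now hold an all-zero row via padding_idx=0, so every position receives exactly one non-zero lookup and the two results can simply be added. A standalone sketch of the trick (shapes and ids are illustrative; the projection bias is disabled here so masked-out positions stay exactly zero, whereas the module above keeps the default bias):

import torch
import torch.nn as nn

custom_weights = torch.randn(3, 2)         # pretrained vectors for token ids 1..3
custom_indices = torch.tensor([1, 2, 3])

# prepend a zero row so padding_idx=0 maps to an all-zero embedding
custom_weights = torch.cat([torch.zeros(1, 2), custom_weights])
custom_emb = nn.Embedding.from_pretrained(custom_weights, freeze=True, padding_idx=0)
regular_emb = nn.Embedding(10, 4, padding_idx=0)
project = nn.Linear(2, 4, bias=False)      # maps custom embeddings to the regular dimension

x = torch.tensor([[1, 5, 0]])              # one custom id, one regular id, one pad
mask = torch.isin(x, custom_indices)
custom_out = project(custom_emb(torch.where(mask, x, torch.tensor(0))))   # (1, 3, 4)
regular_out = regular_emb(torch.where(~mask, x, torch.tensor(0)))         # (1, 3, 4)
out = custom_out + regular_out             # each position has exactly one non-zero term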
10 changes: 5 additions & 5 deletions inst/python/Dataset.py
@@ -47,8 +47,8 @@ def __init__(self, data, labels=None, numerical_features=None):
         # select rowId and columnId
         data_cat = (
             data.filter(~pl.col("columnId").is_in(self.numerical_features))
-            .select(pl.col("rowId"), pl.col("covariateId"))
-            .sort(["rowId", "covariateId"])
+            .select(pl.col("rowId"), pl.col("columnId"))
+            .sort(["rowId", "columnId"])
             .with_columns(pl.col("rowId") - 1)
             .collect()
         )
@@ -66,7 +66,7 @@ def __init__(self, data, labels=None, numerical_features=None):
         for i, i2 in enumerate(idx):
             total_list[i2] = tensor_list[i]
         self.cat = torch.nn.utils.rnn.pad_sequence(total_list, batch_first=True)
-        self.cat_features = data_cat["covariateId"].unique()
+        self.categorical_features = data_cat["columnId"].unique()

         # numerical data,
         # N x C, dense matrix with values for N patients/visits for C numerical features
@@ -110,8 +110,8 @@ def __init__(self, data, labels=None, numerical_features=None):

     def get_feature_info(self):
         return {
-            "numerical_features": torch.tensor(self.numerical_features),
-            "categorical_features": torch.tensor(self.cat_features),
+            "numerical_features": len(self.numerical_features),
+            "categorical_features": self.categorical_features.max(),
             "reference": self.data_ref
         }
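Note on the new feature_info contract: get_feature_info now returns plain sizes rather than index tensors (the count of numerical columns and the largest categorical columnId, plus a data reference), and the model constructors below read it directly. A minimal sketch of how a consumer sizes its embedding table from it (values are illustrative; the "reference" entry is omitted):

import torch.nn as nn

feature_info = {
    "numerical_features": 3,      # number of numeric columns
    "categorical_features": 120,  # largest categorical columnId in the data
}
cat_feature_size = int(feature_info["categorical_features"])
num_feature_size = int(feature_info.get("numerical_features", 0))
# one extra row so columnId 0 can serve as the padding index
embedding = nn.Embedding(cat_feature_size + 1, 16, padding_idx=0)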
7 changes: 3 additions & 4 deletions inst/python/MultiLayerPerceptron.py
@@ -6,8 +6,7 @@
 class MultiLayerPerceptron(nn.Module):
     def __init__(
         self,
-        cat_features: int,
-        num_features: int,
+        feature_info: dict,
         size_embedding: int,
         size_hidden: int,
         num_layers: int,
@@ -19,8 +18,8 @@ def __init__(
     ):
         super(MultiLayerPerceptron, self).__init__()
         self.name = model_type
-        cat_features = int(cat_features)
-        num_features = int(num_features)
+        cat_features = int(feature_info["categorical_features"])
+        num_features = int(feature_info.get("numerical_features", 0))
         size_embedding = int(size_embedding)
         size_hidden = int(size_hidden)
         num_layers = int(num_layers)
7 changes: 3 additions & 4 deletions inst/python/ResNet.py
@@ -7,8 +7,7 @@
 class ResNet(nn.Module):
     def __init__(
         self,
-        cat_features: int,
-        num_features: int = 0,
+        feature_info: dict,
         size_embedding: int = 256,
         size_hidden: int = 256,
         num_layers: int = 2,
@@ -23,8 +22,8 @@ def __init__(
     ):
         super(ResNet, self).__init__()
         self.name = model_type
-        cat_features = int(cat_features)
-        num_features = int(num_features)
+        cat_features = int(feature_info["categorical_features"])
+        num_features = int(feature_info.get("numerical_features", 0))
         size_embedding = int(size_embedding)
         size_hidden = int(size_hidden)
         num_layers = int(num_layers)
7 changes: 3 additions & 4 deletions inst/python/Transformer.py
@@ -1,4 +1,5 @@
 import math
+from idlelib.debugger_r import wrap_info

 import torch
 from torch import nn
@@ -51,10 +52,8 @@ def __init__(
         num_heads = int(num_heads)
         dim_hidden = int(dim_hidden)
         dim_out = int(dim_out)
-        cat_features = feature_info["categorical_features"]
-        num_features = feature_info["numerical_features"]
-        cat_feature_size = len(cat_features)
-        num_feature_size = len(num_features)
+        cat_feature_size = int(feature_info["categorical_features"])
+        num_feature_size = int(feature_info.get("numerical_features", 0))

         self.embedding = nn.Embedding(
             cat_feature_size + 1, dim_token, padding_idx=0
14 changes: 14 additions & 0 deletions man/checkFileExists.Rd

(generated file; diff not rendered by default)

20 changes: 11 additions & 9 deletions man/setCustomEmbeddingModel.Rd

(generated file; diff not rendered by default)

8 changes: 3 additions & 5 deletions tests/testthat/test-Dataset.R
@@ -1,7 +1,6 @@
 test_that("number of num and cat features sum correctly", {
   testthat::expect_equal(
-    length(dataset$get_numerical_features()) +
-      length(dataset$get_cat_features()),
+    dataset$get_feature_info()[["categorical_features"]],
     dplyr::n_distinct(mappedData$covariates %>%
                         dplyr::collect() %>%
                         dplyr::pull(covariateId))
@@ -56,7 +55,6 @@ test_that(".getbatch works", {

 test_that("Column order is preserved when features are missing", {
   # important for transfer learning and external validation
-
   reducedCovData <- Andromeda::copyAndromeda(trainData$Train$covariateData)

   # remove one numerical and one categorical
@@ -119,7 +117,7 @@ test_that("Column order is preserved when features are missing", {
   reducedCounts <- reducedCounts[-1]

   expect_false(isTRUE(all.equal(counts, reducedCounts)))
-  expect_equal(dataset$get_cat_features()$max(),
-               reducedDataset$get_cat_features()$max())
+  expect_equal(max(dataset$categorical_features$to_list()),
+               max(reducedDataset$categorical_features$to_list()))

 })
(remaining changed files not shown)
