From 560c6b0d6055ce79d9775da95a17e68aa50f1b6e Mon Sep 17 00:00:00 2001
From: egillax
Date: Mon, 21 Oct 2024 11:06:01 +0200
Subject: [PATCH] [WIP]

---
 R/CustomEmbeddingModel.R        | 18 ++++++--
 R/HelperFunctions.R             | 10 ++++
 inst/python/CustomEmbeddings.py | 30 ++++++++++++
 inst/python/Dataset.py          |  4 +-
 inst/python/InitStrategy.py     | 81 +++++++++++++--------------------
 inst/python/Transformer.py      | 12 +++--
 6 files changed, 94 insertions(+), 61 deletions(-)
 create mode 100644 inst/python/CustomEmbeddings.py

diff --git a/R/CustomEmbeddingModel.R b/R/CustomEmbeddingModel.R
index 332a173..7772b21 100644
--- a/R/CustomEmbeddingModel.R
+++ b/R/CustomEmbeddingModel.R
@@ -17,12 +17,17 @@
 # limitations under the License.
 #' Create default settings a model using custom embeddings
 #'
-#' @description A model that uses custom embeddings such as Poincare embeddings
-#' @param embeddingFilePath path to the saved Poincare embedding
+#' @description A model that uses custom embeddings such as Poincare embeddings or
+#' embeddings from a foundation model
+#' @param embeddingFilePath path to the saved embeddings. The embeddings file
+#' should be a PyTorch file containing a dictionary with two fields:
+#' `concept_ids`, a PyTorch long tensor with the concept IDs, and `embeddings`,
+#' a PyTorch float tensor with the embeddings
 #' @param estimatorSettings created with `setEstimator`
-#' @param modelSettings for the model to use, needs to have an embedding layer
+#' @param modelSettings for the model to use; it needs to have an embedding layer
+#' named `embedding`, which will be replaced by the custom embeddings
 #'
-#' @return settings for a model using custom embeddings:w
+#' @return settings for a model using custom embeddings
 #'
 #' @export
 setCustomEmbeddingModel <- function(
@@ -50,6 +55,11 @@ setCustomEmbeddingModel <- function(
     randomSample = 1
   )
 ) {
+  checkIsClass(embeddingFilePath, "character")
+  checkFileExists(embeddingFilePath)
+  embeddingFilePath <- normalizePath(embeddingFilePath)
+
+
   path <- system.file("python", package = "DeepPatientLevelPrediction")
   estimatorSettings$initStrategy <-
     reticulate::import_from_path("InitStrategy",
diff --git a/R/HelperFunctions.R b/R/HelperFunctions.R
index 08a18db..da7e925 100644
--- a/R/HelperFunctions.R
+++ b/R/HelperFunctions.R
@@ -105,3 +105,13 @@ checkHigherEqual <- function(parameter, value) {
   }
   return(TRUE)
 }
+
+#' Helper function to check if a file exists
+#' @param file the file to check
+checkFileExists <- function(file) {
+  if (!file.exists(file)) {
+    ParallelLogger::logError(paste0("File ", file, " does not exist"))
+    stop(paste0("File ", file, " does not exist"))
+  }
+  return(TRUE)
+}
diff --git a/inst/python/CustomEmbeddings.py b/inst/python/CustomEmbeddings.py
new file mode 100644
index 0000000..da60288
--- /dev/null
+++ b/inst/python/CustomEmbeddings.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+
+class CustomEmbeddings(nn.Module):
+    def __init__(self,
+                 custom_embedding_weights: torch.Tensor,
+                 embedding_dim: int,
+                 num_regular_embeddings: int,
+                 custom_indices: torch.Tensor,
+                 freeze: bool = True):
+        super(CustomEmbeddings, self).__init__()
+
+        self.custom_embeddings = nn.Embedding.from_pretrained(custom_embedding_weights, freeze=freeze)
+        self.regular_embeddings = nn.Embedding(num_regular_embeddings, embedding_dim)
+
+        self.custom_indices = custom_indices
+
+        self.linear_transform = nn.Linear(custom_embedding_weights.shape[1], embedding_dim)
+
+    def forward(self, x):
+        custom_embeddings_mask = torch.isin(x, self.custom_indices)
+        custom_features = x[custom_embeddings_mask]
+        regular_features = x[~custom_embeddings_mask]
+
+        custom_embeddings = self.custom_embeddings(custom_features)
+        regular_embeddings = self.regular_embeddings(regular_features)
+
+        custom_embeddings = self.linear_transform(custom_embeddings)
+
+        return torch.cat([custom_embeddings, regular_embeddings], dim=-1)
diff --git a/inst/python/Dataset.py b/inst/python/Dataset.py
index 84cc24a..eb41be2 100644
--- a/inst/python/Dataset.py
+++ b/inst/python/Dataset.py
@@ -110,8 +110,8 @@ def __init__(self, data, labels=None, numerical_features=None):
 
     def get_feature_info(self):
         return {
-            "numerical_features": self.numerical_features,
-            "cat_features": self.cat_features,
+            "numerical_features": torch.tensor(self.numerical_features),
+            "categorical_features": torch.tensor(self.cat_features),
             "reference": self.data_ref
         }
 
diff --git a/inst/python/InitStrategy.py b/inst/python/InitStrategy.py
index 36ad189..489a16b 100644
--- a/inst/python/InitStrategy.py
+++ b/inst/python/InitStrategy.py
@@ -4,6 +4,8 @@
 import torch
 import polars as pl
 
+from CustomEmbeddings import CustomEmbeddings
+
 class InitStrategy(ABC):
     @abstractmethod
     def initialize(self, model, parameters):
@@ -28,57 +30,36 @@ def initialize(self, model, parameters):
 
 class CustomEmbeddingInitStrategy(InitStrategy):
     def initialize(self, model, parameters):
-        file_path = pathlib.Path(parameters["estimator_settings"].get("embedding_file_path")).expanduser()
+        file_path = pathlib.Path(parameters["estimator_settings"].get("embedding_file_path"))
+        data_reference = parameters["model_parameters"]["feature_info"]["reference"]
 
         # Instantiate the model with the provided parameters
        model = model(**parameters["model_parameters"])
 
-        if file_path.exists():
-            state = torch.load(file_path)
-            state_dict = state["state_dict"]
-            embedding_key = "embedding.weight"
-
-            if embedding_key not in state_dict:
-                raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary")
-
-            custom_embeddings = state_dict[embedding_key].float()
-
-            # Ensure that model_temp.categorical_embedding_2 exists
-            if not hasattr(model, 'embedding'):
-                raise AttributeError("The model does not have an embedding layer named 'embedding'")
-
-            # # replace weights
-            cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping_train.json"))
-
-            concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId"))
-
-            # Initialize tensor for mapped embeddings
-            mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1]))
-
-            # Map embeddings to their corresponding indices
-            for row in cat2_mapping.iter_rows():
-                concept_id, covariate_id, index = row
-                if concept_id in concept_df["conceptId"]:
-                    concept_idx = concept_df["conceptId"].to_list().index(concept_id)
-                    mapped_embeddings[index] = new_embeddings[concept_idx]
-
-            # Assign the mapped embeddings to the model
-            model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings)
-            model_temp.categorical_embedding_2.weight.requires_grad = False
-
-        else:
-            raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")
-
-
-        # Create a dummy input batch that matches the model inputs
-        dummy_input = {
-            "cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(),
-            "cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(),
-            "num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None
-        }
-
-        # Ensure that the dummy input does not contain `None` values if num_features == 0
-        if model_parameters['num_features'] == 0:
-            del dummy_input["num"]
-
-        return model_temp
+        embeddings = torch.load(file_path, weights_only=True)
+
+        if "concept_ids" not in embeddings:
+            raise KeyError("The embeddings file does not contain the required 'concept_ids' key")
+        if "embeddings" not in embeddings:
+            raise KeyError("The embeddings file does not contain the required 'embeddings' key")
+        if embeddings["concept_ids"].dtype != torch.long:
+            raise TypeError("The 'concept_ids' tensor in the embeddings file must be of type torch.long")
+        if embeddings["embeddings"].dtype != torch.float:
+            raise TypeError("The 'embeddings' tensor in the embeddings file must be of type torch.float")
+
+        # Ensure that the model has an embedding layer named 'embedding'
+        if not hasattr(model, 'embedding'):
+            raise AttributeError(f"The model {model.name} does not have an embedding layer named 'embedding' as "
+                                 f"required for custom embeddings")
+
+        # Find the column indices of the custom embeddings: select the rows of
+        # data_reference (a polars LazyFrame) whose conceptId appears in
+        # embeddings["concept_ids"], converting the 1-based columnId to 0-based.
+        custom_indices = data_reference.filter(pl.col("conceptId").is_in(embeddings["concept_ids"].tolist())).select("columnId").collect() - 1
+        custom_indices = custom_indices.to_torch().squeeze()
+
+        model.embedding = CustomEmbeddings(custom_embedding_weights=embeddings["embeddings"],
+                                           embedding_dim=model.embedding.embedding_dim,
+                                           num_regular_embeddings=model.embedding.num_embeddings,
+                                           custom_indices=custom_indices)
+        return model
diff --git a/inst/python/Transformer.py b/inst/python/Transformer.py
index 63acdc6..16f7951 100644
--- a/inst/python/Transformer.py
+++ b/inst/python/Transformer.py
@@ -51,15 +51,17 @@ def __init__(
         num_heads = int(num_heads)
         dim_hidden = int(dim_hidden)
         dim_out = int(dim_out)
-        cat_features = feature_info["cat_features"]
-        num_features = feature_info["num_features"]
+        cat_features = feature_info["categorical_features"]
+        num_features = feature_info["numerical_features"]
+        cat_feature_size = len(cat_features)
+        num_feature_size = len(num_features)
 
         self.embedding = nn.Embedding(
-            cat_features + 1, dim_token, padding_idx=0
+            cat_feature_size + 1, dim_token, padding_idx=0
        )
-        if num_features != 0 and num_features is not None:
-            self.numerical_embedding = NumericalEmbedding(num_features, dim_token)
+        if num_feature_size != 0 and num_feature_size is not None:
+            self.numerical_embedding = NumericalEmbedding(num_feature_size, dim_token)
             self.use_numerical = True
         else:
             self.use_numerical = False
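
Note: a minimal sketch of the embeddings file format documented above for
embeddingFilePath (a dictionary saved with torch.save() holding a long tensor
of concept IDs and a float tensor with one embedding row per concept). The
concept IDs, the number of concepts, and the embedding dimension below are
hypothetical placeholders.

    import torch

    # Three hypothetical concepts, each with a 16-dimensional embedding.
    concept_ids = torch.tensor([1001, 1002, 1003], dtype=torch.long)
    embeddings = torch.randn(3, 16, dtype=torch.float)

    # Save in the format read by CustomEmbeddingInitStrategy via
    # torch.load(file_path, weights_only=True).
    torch.save({"concept_ids": concept_ids, "embeddings": embeddings},
               "custom_embeddings.pt")

The resulting file path is what setCustomEmbeddingModel() expects as
embeddingFilePath.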