[WIP]
egillax committed Oct 21, 2024
1 parent d32db5c commit 560c6b0
Showing 6 changed files with 94 additions and 61 deletions.
18 changes: 14 additions & 4 deletions R/CustomEmbeddingModel.R
@@ -17,12 +17,17 @@
# limitations under the License.
#' Create default settings for a model using custom embeddings
#'
-#' @description A model that uses custom embeddings such as Poincare embeddings
-#' @param embeddingFilePath path to the saved Poincare embedding
+#' @description A model that uses custom embeddings such as Poincare embeddings or
+#' embeddings from a foundation model
+#' @param embeddingFilePath path to the saved embeddings. The embeddings file
+#' should be a pytorch file including a dictionary with two fields:
+#' `concept_ids`: a pytorch long tensor with the concept ids and `embeddings`:
+#' a pytorch float tensor with the embeddings
#' @param estimatorSettings created with `setEstimator`
-#' @param modelSettings for the model to use, needs to have an embedding layer
+#' @param modelSettings for the model to use, needs to have an embedding layer
+#' named `embedding`, which will be replaced by the custom embeddings
#'
-#' @return settings for a model using custom embeddings:w
+#' @return settings for a model using custom embeddings
#'
#' @export
setCustomEmbeddingModel <- function(
@@ -50,6 +55,11 @@ setCustomEmbeddingModel <- function(
randomSample = 1
)
) {
+  embeddingFilePath <- normalizePath(embeddingFilePath)
+  checkIsClass(embeddingFilePath, "character")
+  checkFileExists(embeddingFilePath)
+
+
path <- system.file("python", package = "DeepPatientLevelPrediction")
estimatorSettings$initStrategy <-
reticulate::import_from_path("InitStrategy",
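Note on the documented file format: it is straightforward to produce. A minimal sketch of writing a compatible embeddings file with pytorch (file name and values are illustrative, not from this commit):

```python
import torch

# Illustrative concept ids and one embedding row per concept,
# in the same row order as concept_ids
concept_ids = torch.tensor([4329847, 381591, 4148154], dtype=torch.long)
embeddings = torch.randn(3, 128, dtype=torch.float)

# The file must hold a dictionary with exactly these two fields
torch.save(
    {"concept_ids": concept_ids, "embeddings": embeddings},
    "custom_embeddings.pt",  # pass this path as embeddingFilePath
)
```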
10 changes: 10 additions & 0 deletions R/HelperFunctions.R
@@ -105,3 +105,13 @@ checkHigherEqual <- function(parameter, value) {
  }
  return(TRUE)
}

+#' helper function to check if a file exists
+#' @param file the file to check
+checkFileExists <- function(file) {
+  if (!file.exists(file)) {
+    ParallelLogger::logError(paste0("File ", file, " does not exist"))
+    stop(paste0("File ", file, " does not exist"))
+  }
+  return(TRUE)
+}
30 changes: 30 additions & 0 deletions inst/python/CustomEmbeddings.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+
+class CustomEmbeddings(nn.Module):
+    def __init__(self,
+                 custom_embedding_weights: torch.Tensor,
+                 embedding_dim: int,
+                 num_regular_embeddings: int,
+                 custom_indices: torch.Tensor,
+                 freeze: bool = True):
+        super(CustomEmbeddings, self).__init__()
+
+        self.custom_embeddings = nn.Embedding.from_pretrained(custom_embedding_weights, freeze=freeze)
+        self.regular_embeddings = nn.Embedding(num_regular_embeddings, embedding_dim)
+
+        self.custom_indices = custom_indices
+
+        self.linear_transform = nn.Linear(custom_embedding_weights.shape[1], embedding_dim)
+
+    def forward(self, x):
+        custom_embeddings_mask = torch.isin(x, self.custom_indices)
+        custom_features = x[custom_embeddings_mask]
+        regular_features = x[~custom_embeddings_mask]
+
+        custom_embeddings = self.custom_embeddings(custom_features)
+        regular_embeddings = self.regular_embeddings(regular_features)
+
+        custom_embeddings = self.linear_transform(custom_embeddings)
+
+        return torch.cat([custom_embeddings, regular_embeddings], dim=-1)
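Reviewer note on the forward pass above: as committed, `x[custom_embeddings_mask]` yields vocabulary indices rather than row offsets into the pretrained weight matrix, and `torch.cat(..., dim=-1)` requires equal counts of custom and regular tokens while discarding their original order. Since the commit is marked [WIP], here is a sketch of one order-preserving alternative, assuming `custom_indices` lists vocabulary ids in the same row order as `custom_embedding_weights`:

```python
import torch
import torch.nn as nn

class OrderPreservingCustomEmbeddings(nn.Module):
    def __init__(self, custom_embedding_weights, embedding_dim,
                 num_regular_embeddings, custom_indices, freeze=True):
        super().__init__()
        self.custom = nn.Embedding.from_pretrained(custom_embedding_weights, freeze=freeze)
        self.regular = nn.Embedding(num_regular_embeddings, embedding_dim)
        self.project = nn.Linear(custom_embedding_weights.shape[1], embedding_dim)
        # Lookup table: vocabulary id -> row in the pretrained matrix (-1 = no custom row)
        vocab_to_row = torch.full((num_regular_embeddings,), -1, dtype=torch.long)
        vocab_to_row[custom_indices] = torch.arange(custom_indices.numel())
        self.register_buffer("vocab_to_row", vocab_to_row)

    def forward(self, x):
        out = self.regular(x)                # (batch, seq, embedding_dim)
        mask = self.vocab_to_row[x] >= 0     # tokens that have a custom embedding
        if mask.any():
            rows = self.vocab_to_row[x[mask]]            # remap vocab ids to weight rows
            out[mask] = self.project(self.custom(rows))  # write back in place, order kept
        return out
```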
4 changes: 2 additions & 2 deletions inst/python/Dataset.py
@@ -110,8 +110,8 @@ def __init__(self, data, labels=None, numerical_features=None):

    def get_feature_info(self):
        return {
-            "numerical_features": self.numerical_features,
-            "cat_features": self.cat_features,
+            "numerical_features": torch.tensor(self.numerical_features),
+            "categorical_features": torch.tensor(self.cat_features),
            "reference": self.data_ref
        }

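The change matters downstream: get_feature_info now returns tensors of column ids, and consumers size their layers with len() instead of treating the entries as raw counts. A toy illustration (values are made up):

```python
import torch

feature_info = {
    "numerical_features": torch.tensor([2, 5]),       # column ids of numeric covariates
    "categorical_features": torch.tensor([1, 3, 4]),  # column ids of categorical covariates
    "reference": None,  # in practice a polars LazyFrame mapping conceptId to columnId
}

# len() on a 1-D tensor gives the number of features
cat_feature_size = len(feature_info["categorical_features"])  # 3
```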
81 changes: 31 additions & 50 deletions inst/python/InitStrategy.py
@@ -4,6 +4,8 @@
import torch
import polars as pl

+from CustomEmbeddings import CustomEmbeddings
+
class InitStrategy(ABC):
    @abstractmethod
    def initialize(self, model, parameters):
@@ -28,57 +30,36 @@ def initialize(self, model, parameters):

class CustomEmbeddingInitStrategy(InitStrategy):
    def initialize(self, model, parameters):
-        file_path = pathlib.Path(parameters["estimator_settings"].get("embedding_file_path")).expanduser()
+        file_path = pathlib.Path(parameters["estimator_settings"].get("embedding_file_path"))
+        data_reference = parameters["model_parameters"]["feature_info"]["reference"]

        # Instantiate the model with the provided parameters
        model = model(**parameters["model_parameters"])

-        if file_path.exists():
-            state = torch.load(file_path)
-            state_dict = state["state_dict"]
-            embedding_key = "embedding.weight"
-
-            if embedding_key not in state_dict:
-                raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary")
-
-            custom_embeddings = state_dict[embedding_key].float()
-
-            # Ensure that model_temp.categorical_embedding_2 exists
-            if not hasattr(model, 'embedding'):
-                raise AttributeError("The model does not have an embedding layer named 'embedding'")
-
-            # # replace weights
-            cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping_train.json"))
-
-            concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId"))
-
-            # Initialize tensor for mapped embeddings
-            mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1]))
-
-            # Map embeddings to their corresponding indices
-            for row in cat2_mapping.iter_rows():
-                concept_id, covariate_id, index = row
-                if concept_id in concept_df["conceptId"]:
-                    concept_idx = concept_df["conceptId"].to_list().index(concept_id)
-                    mapped_embeddings[index] = new_embeddings[concept_idx]
-
-            # Assign the mapped embeddings to the model
-            model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings)
-            model_temp.categorical_embedding_2.weight.requires_grad = False
-
-        else:
-            raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")
-
-
-        # Create a dummy input batch that matches the model inputs
-        dummy_input = {
-            "cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(),
-            "cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(),
-            "num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None
-        }
-
-        # Ensure that the dummy input does not contain `None` values if num_features == 0
-        if model_parameters['num_features'] == 0:
-            del dummy_input["num"]
-
-        return model_temp
+        embeddings = torch.load(file_path, weights_only=True)
+
+        if "concept_ids" not in embeddings.keys():
+            raise KeyError("The embeddings file does not contain the required 'concept_ids' key")
+        if "embeddings" not in embeddings.keys():
+            raise KeyError("The embeddings file does not contain the required 'embeddings' key")
+        if embeddings["concept_ids"].dtype != torch.long:
+            raise TypeError("The 'concept_ids' key in the embeddings file must be of type torch.long")
+        if embeddings["embeddings"].dtype != torch.float:
+            raise TypeError("The 'embeddings' key in the embeddings file must be of type torch.float")
+
+        # Ensure that the model has an embedding layer
+        if not hasattr(model, 'embedding'):
+            raise AttributeError(f"The model: {model.name} does not have an embedding layer named 'embedding' as "
+                                 f"required for custom embeddings")
+
+        # Get the indices of the custom embeddings from embeddings["concept_ids"]:
+        # select the rows of data_reference whose conceptId appears in embeddings["concept_ids"];
+        # data_reference is a polars LazyFrame mapping conceptId to columnId
+        custom_indices = data_reference.filter(pl.col("conceptId").is_in(embeddings["concept_ids"].tolist())).select("columnId").collect() - 1
+        custom_indices = custom_indices.to_torch().squeeze()
+
+        model.embedding = CustomEmbeddings(custom_embedding_weights=embeddings["embeddings"],
+                                           embedding_dim=model.embedding.embedding_dim,
+                                           num_regular_embeddings=model.embedding.num_embeddings,
+                                           custom_indices=custom_indices)
+        return model
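The columnId arithmetic above is the subtle step. A self-contained toy example of deriving custom_indices (assumes a recent polars with DataFrame.to_torch(), and 1-based columnId values in the covariate reference, which is what the `- 1` shift implies):

```python
import polars as pl
import torch

# Toy reference mapping three concepts to 1-based column ids
data_reference = pl.LazyFrame({
    "conceptId": [101, 202, 303],
    "columnId": [1, 2, 3],
})
embeddings = {
    "concept_ids": torch.tensor([202, 303], dtype=torch.long),
    "embeddings": torch.randn(2, 16, dtype=torch.float),
}

# Same logic as the strategy above: keep rows whose conceptId has a custom
# embedding, then shift columnId to the 0-based indexing used by torch
custom_indices = (
    data_reference
    .filter(pl.col("conceptId").is_in(embeddings["concept_ids"].tolist()))
    .select("columnId")
    .collect()
    - 1
)
custom_indices = custom_indices.to_torch().squeeze()
print(custom_indices)  # tensor([1, 2])
```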
12 changes: 7 additions & 5 deletions inst/python/Transformer.py
@@ -51,15 +51,17 @@ def __init__(
        num_heads = int(num_heads)
        dim_hidden = int(dim_hidden)
        dim_out = int(dim_out)
-        cat_features = feature_info["cat_features"]
-        num_features = feature_info["num_features"]
+        cat_features = feature_info["categorical_features"]
+        num_features = feature_info["numerical_features"]
+        cat_feature_size = len(cat_features)
+        num_feature_size = len(num_features)

        self.embedding = nn.Embedding(
-            cat_features + 1, dim_token, padding_idx=0
+            cat_feature_size + 1, dim_token, padding_idx=0
        )

-        if num_features != 0 and num_features is not None:
-            self.numerical_embedding = NumericalEmbedding(num_features, dim_token)
+        if num_feature_size != 0 and num_feature_size is not None:
+            self.numerical_embedding = NumericalEmbedding(num_feature_size, dim_token)
            self.use_numerical = True
        else:
            self.use_numerical = False
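A quick sanity check of the sizing after the rename (toy numbers): the embedding table gets one extra row so index 0 can act as padding:

```python
import torch
import torch.nn as nn

categorical_features = torch.tensor([1, 2, 3, 4, 5])  # toy column ids
cat_feature_size = len(categorical_features)
dim_token = 8

# One row per categorical feature plus row 0 reserved for padding
embedding = nn.Embedding(cat_feature_size + 1, dim_token, padding_idx=0)
assert embedding.weight.shape == (6, 8)
assert torch.all(embedding(torch.tensor([0])) == 0)  # the padding row starts as zeros
```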
