Skip to content

Commit

Permalink
Add poincare class
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Oct 28, 2024
1 parent 761059a commit 584d0ce
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 62 deletions.
14 changes: 10 additions & 4 deletions R/CustomEmbeddingModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#' a pytorch float tensor with the embeddings
#' @param modelSettings for the model to use, needs to have an embedding layer
#' with a name `embedding` which will be replaced by the custom embeddings
#' @param embeddingsClass the class of the custom embeddings, e.g. `CustomEmbeddings`
#' or `PoincareEmbeddings`
#'
#' @return settings for a model using custom embeddings
#'
Expand All @@ -48,18 +50,22 @@ setCustomEmbeddingModel <- function(
device = "cpu"),
hyperParamSearch = "random",
randomSample = 1
)
),
embeddingsClass = "CustomEmbeddings"
) {
embeddingFilePath <- normalizePath(embeddingFilePath)
checkIsClass(embeddingFilePath, "character")
checkFileExists(embeddingFilePath)

checkIsClass(embeddingsClass, "character")
checkInStringVector(embeddingsClass, c("CustomEmbeddings", "PoincareEmbeddings"))

path <- system.file("python", package = "DeepPatientLevelPrediction")
modelSettings$estimatorSettings$initStrategy <-
reticulate::import_from_path("InitStrategy",
path = path)$CustomEmbeddingInitStrategy()
modelSettings$estimatorSettings$embeddingFilePath <- embeddingFilePath
path = path)$CustomEmbeddingInitStrategy(
embedding_class = embeddingsClass,
embedding_file = embeddingFilePath
)
transformerSettings <- modelSettings

attr(transformerSettings, "settings")$name <- "CustomEmbeddingModel"
Expand Down
11 changes: 11 additions & 0 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,14 @@ checkFileExists <- function(file) {
}
return(TRUE)
}

# Check that `parameter` only holds values listed in `values`.
#
# parameter: the value(s) to validate; its calling expression is used as the
#            name in the log and error messages.
# values:    vector of allowed values.
#
# Returns TRUE when the check passes. Otherwise logs the allowed values via
# ParallelLogger and stops with "<name> has incorrect value".
checkInStringVector <- function(parameter, values) {
  # Capture the caller's expression so the messages name the offending argument.
  name <- deparse(substitute(parameter))
  # `if` needs a single TRUE/FALSE: an empty or multi-element `parameter` would
  # otherwise fail inside `if` itself rather than with the message below.
  if (length(parameter) == 0 || !all(parameter %in% values)) {
    ParallelLogger::logError(paste0(name, " should be ",
                                    paste0(as.character(values),
                                           collapse = " or ")))
    stop(paste0(name, " has incorrect value"))
  }
  return(TRUE)
}
19 changes: 19 additions & 0 deletions inst/python/CustomEmbeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,33 @@ def __init__(self,

self.linear_transform = nn.Linear(custom_embedding_weights.shape[1], embedding_dim)

@staticmethod
def process_custom_embeddings(embeddings: torch.Tensor) -> torch.Tensor:
    """Hook applied to the looked-up custom embeddings.

    This implementation returns ``embeddings`` unchanged; a subclass can
    override it to transform the embeddings before they are used further.
    """
    return embeddings

def forward(self, x):
    """Embed ``x`` as the sum of its custom and regular embeddings.

    Entries of ``x`` that appear in ``self.custom_indices`` are looked up in
    ``self.custom_embeddings``, passed through ``process_custom_embeddings``
    and then through ``self.linear_transform``. All other entries are looked
    up in ``self.regular_embeddings``. In each lookup, the entries that
    belong to the other table are replaced by index 0.

    Args:
        x: tensor of feature indices.

    Returns:
        The element-wise sum of the transformed custom embeddings and the
        regular embeddings.
    """
    is_custom = torch.isin(x, self.custom_indices.to(x.device))

    # Blank out (set to index 0) the positions each table does not own.
    custom_ids = x.masked_fill(~is_custom, 0)
    regular_ids = x.masked_fill(is_custom, 0)

    from_custom = self.process_custom_embeddings(self.custom_embeddings(custom_ids))
    from_regular = self.regular_embeddings(regular_ids)

    projected = self.linear_transform(from_custom)
    return projected + from_regular

def logmap0(input_tensor: torch.Tensor, curvature: float = 1.0) -> torch.Tensor:
    """Rescale each vector along the last dimension by artanh(sqrt(c)*|v|) / (sqrt(c)*|v|).

    With ``c = curvature`` this has the form of the logarithmic map at the
    origin of a Poincare ball.

    Args:
        input_tensor: tensor whose last dimension holds the vectors to map.
        curvature: the value ``c`` in the formula above. Defaults to 1.0,
            the value that was previously hard-coded.

    Returns:
        A new tensor with the same shape as ``input_tensor``. Wherever the
        scale factor evaluates to NaN (for example a zero-norm vector, where
        it is 0/0, or a vector for which ``arctanh`` is undefined) the
        factor falls back to 1.0, so that vector is returned unchanged.
    """
    norm_input = torch.norm(input_tensor, dim=-1, keepdim=True)
    sqrt_c = torch.sqrt(
        torch.tensor(curvature, dtype=input_tensor.dtype, device=input_tensor.device)
    )
    scaled_norm = sqrt_c * norm_input
    scale = torch.arctanh(scaled_norm) / scaled_norm
    # Replace NaN factors out-of-place instead of mutating `scale` by index.
    scale = torch.where(torch.isnan(scale), torch.ones_like(scale), scale)
    return scale * input_tensor


class PoincareEmbeddings(CustomEmbeddings):
    """CustomEmbeddings variant that passes the custom embeddings through ``logmap0``."""

    @staticmethod
    def process_custom_embeddings(embeddings: torch.Tensor) -> torch.Tensor:
        """Return ``logmap0(embeddings)``."""
        return logmap0(embeddings)
19 changes: 14 additions & 5 deletions inst/python/InitStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import polars as pl

from CustomEmbeddings import CustomEmbeddings
from CustomEmbeddings import CustomEmbeddings, PoincareEmbeddings

class InitStrategy(ABC):
@abstractmethod
Expand All @@ -29,6 +29,14 @@ def initialize(self, model, parameters):


class CustomEmbeddingInitStrategy(InitStrategy):
def __init__(self, embedding_class: str, embedding_file: str):
    """Record which embedding class to use and the embedding file to use.

    Args:
        embedding_class: name of the embedding class, used as a key into
            ``class_names_to_class`` ("CustomEmbeddings" or
            "PoincareEmbeddings"). It is stored as given and not validated
            here.
        embedding_file: path of the embedding file, stored as given.
    """
    self.embedding_class = embedding_class
    self.embedding_file = embedding_file
    # Maps the class name received as a string to the class object itself.
    self.class_names_to_class = {
        "CustomEmbeddings": CustomEmbeddings,
        "PoincareEmbeddings": PoincareEmbeddings
    }

def initialize(self, model, parameters):
file_path = pathlib.Path(parameters["estimator_settings"].get("embedding_file_path"))
data_reference = parameters["model_parameters"]["feature_info"]["reference"]
Expand Down Expand Up @@ -58,8 +66,9 @@ def initialize(self, model, parameters):
custom_indices = data_reference.filter(pl.col("conceptId").is_in(embeddings["concept_ids"].tolist())).select("columnId").collect() - 1
custom_indices = custom_indices.to_torch().squeeze()

model.embedding = CustomEmbeddings(custom_embedding_weights=embeddings["embeddings"],
embedding_dim=model.embedding.embedding_dim,
num_regular_embeddings=model.embedding.num_embeddings,
custom_indices=custom_indices)
embedding_class = self.class_names_to_class[self.embedding_class]
model.embedding = embedding_class(custom_embedding_weights=embeddings["embeddings"],
embedding_dim=model.embedding.embedding_dim,
num_regular_embeddings=model.embedding.num_embeddings,
custom_indices=custom_indices)
return model
10 changes: 0 additions & 10 deletions inst/python/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,6 @@

from ResNet import NumericalEmbedding

class LogMap0(nn.Module):
    """Module that rescales vectors by artanh(sqrt(c)*|y|) / (sqrt(c)*|y|) with c = 1."""

    def forward(self, y):
        """Apply the rescaling along the last dimension of ``y``.

        Wherever the scale factor evaluates to NaN it is replaced by 1.0, so
        the corresponding vector is returned unchanged.
        """
        c = 1.0
        y_norm = y.norm(dim=-1, keepdim=True)
        root_c = torch.tensor(c, dtype=y.dtype, device=y.device).sqrt()
        ratio = torch.arctanh(root_c * y_norm) / (root_c * y_norm)
        # NaN factors (e.g. 0/0 for a zero-norm vector) fall back to 1.0.
        ratio = ratio.masked_fill(torch.isnan(ratio), 1.0)
        return ratio * y

def reglu(x):
    """ReGLU: split ``x`` in two along the last dimension and gate the first half.

    Returns the first half multiplied element-wise by the ReLU of the second half.
    """
    value, gate = torch.chunk(x, 2, dim=-1)
    activated_gate = F.relu(gate)
    return value * activated_gate
Expand Down Expand Up @@ -100,7 +91,6 @@ def __init__(
self.head_activation = head_activation
self.head_normalization = head_norm
self.dim_out = dim_out
self.logmap0 = LogMap0()

def forward(self, x):
mask = torch.where(x["cat"] == 0, True, False)
Expand Down
6 changes: 5 additions & 1 deletion man/setCustomEmbeddingModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 0 additions & 42 deletions tests/testthat/test-PoincareTransformer.R

This file was deleted.

0 comments on commit 584d0ce

Please sign in to comment.