[WIP] custom embeddings
egillax committed Oct 21, 2024
1 parent 95d883d commit d32db5c
Showing 12 changed files with 183 additions and 311 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -41,7 +41,7 @@ Remotes:
RoxygenNote: 7.3.2
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: TRUE
Config/testthat/parallel: FALSE
Config/reticulate:
list(
packages = list(
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -3,7 +3,7 @@
export(fitEstimator)
export(gridCvDeep)
export(predictDeepEstimator)
export(setCustomEmbeddingTransformer)
export(setCustomEmbeddingModel)
export(setDefaultResNet)
export(setDefaultTransformer)
export(setEstimator)
62 changes: 62 additions & 0 deletions R/CustomEmbeddingModel.R
@@ -0,0 +1,62 @@
# @file CustomEmbeddingModel.R
#
# Copyright 2024 Observational Health Data Sciences and Informatics
#
# This file is part of DeepPatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#' Create default settings for a model using custom embeddings
#'
#' @description A model that uses custom embeddings such as Poincare embeddings
#' @param embeddingFilePath path to the saved Poincare embedding
#' @param estimatorSettings settings for the estimator, created with `setEstimator`
#' @param modelSettings settings for the model to use; the model needs to have an embedding layer
#'
#' @return settings for a model using custom embeddings
#'
#' @export
setCustomEmbeddingModel <- function(
embeddingFilePath,
estimatorSettings =
setEstimator(
learningRate = "auto",
weightDecay = 1e-4,
batchSize = 256,
epochs = 2,
seed = NULL,
device = "cpu"
),
modelSettings = setTransformer(
numBlocks = 3,
dimToken = 16,
dimOut = 1,
numHeads = 4,
attDropout = 0.2,
ffnDropout = 0.1,
resDropout = 0.0,
dimHidden = 32,
estimatorSettings = estimatorSettings,
hyperParamSearch = "random",
randomSample = 1
)
) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
estimatorSettings$initStrategy <-
reticulate::import_from_path("InitStrategy",
path = path)$CustomEmbeddingInitStrategy()
estimatorSettings$embeddingFilePath <- embeddingFilePath
transformerSettings <- modelSettings

attr(transformerSettings, "settings")$name <- "CustomEmbeddingTransformer"
return(transformerSettings)
}
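
A minimal usage sketch (not part of this commit; the embedding file path is a placeholder, and the returned settings object would typically be passed as modelSettings to PatientLevelPrediction::runPlp):

# Hypothetical call; "path/to/custom_embeddings.pt" is a placeholder.
embeddingModel <- setCustomEmbeddingModel(
  embeddingFilePath = "path/to/custom_embeddings.pt"
)
# results <- PatientLevelPrediction::runPlp(plpData, modelSettings = embeddingModel, ...)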
33 changes: 16 additions & 17 deletions R/Estimator.R
@@ -482,22 +482,20 @@ evalEstimatorSettings <- function(estimatorSettings) {
estimatorSettings
}

createEstimator <- function(modelParameters,
estimatorSettings) {
createEstimator <- function(parameters) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
model <-
reticulate::import_from_path(modelParameters$modelType,
path = path)[[modelParameters$modelType]]
reticulate::import_from_path(parameters$modelParameters$modelType,
path = path)[[parameters$modelParameters$modelType]]
estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator

modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
estimatorSettings <- camelCaseToSnakeCaseNames(estimatorSettings)
estimatorSettings <- evalEstimatorSettings(estimatorSettings)

parameters$modelParameters <- camelCaseToSnakeCaseNames(parameters$modelParameters)
parameters$estimatorSettings <- camelCaseToSnakeCaseNames(parameters$estimatorSettings)
parameters$estimatorSettings <- evalEstimatorSettings(parameters$estimatorSettings)
parameters <- camelCaseToSnakeCaseNames(parameters)
estimator <- estimator(
model = model,
model_parameters = modelParameters,
estimator_settings = estimatorSettings
parameters = parameters
)
return(estimator)
}
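
The refactor above collapses the old modelParameters/estimatorSettings pair into a single parameters list. A sketch of the bundle createEstimator() now expects, inferred from the call sites in this diff (all field values are placeholders):

# Hedged sketch; real callers assemble this inside doCrossValidationImpl().
parameters <- list(
  modelParameters = list(modelType = "Transformer"),  # plus hyperparameters and feature_info
  estimatorSettings = setEstimator(learningRate = 3e-4, device = "cpu")
)
estimator <- createEstimator(parameters)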
@@ -588,17 +586,19 @@ doCrossValidationImpl <- function(dataset,
)]
currentModelParams <- parameters[modelSettings$modelParamNames]
attr(currentModelParams, "metaData")$names <-
modelSettings$modelParamNameCH
modelSettings$modelParamNames
currentModelParams$modelType <- modelSettings$modelType
currentEstimatorSettings <-
fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
parameters)
currentModelParams$catFeatures <- dataset$get_cat_features()$len()
currentModelParams$numFeatures <- dataset$get_numerical_features()$len()
currentModelParams$cat2Features <- dataset$get_cat_2_features()$len()
currentModelParams$feature_info <- dataset$get_feature_info()
currentParameters <- list(
modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings
)
if (currentEstimatorSettings$findLR) {
lr <- getLR(currentModelParams, currentEstimatorSettings, dataset)
lr <- getLR(currentParameters, dataset)
ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr))
currentEstimatorSettings$learningRate <- lr
currentParameters$estimatorSettings$learningRate <- lr # keep the bundled list in sync
}
@@ -623,8 +623,7 @@
testDataset <- torch$utils$data$Subset(dataset,
indices =
as.integer(which(fold == i) - 1))
estimator <- createEstimator(modelParameters = currentModelParams,
estimatorSettings = currentEstimatorSettings)
estimator <- createEstimator(currentParameters)
fit_estimator(estimator, trainDataset, testDataset)

ParallelLogger::logInfo("Calculating predictions on left out fold set...")
8 changes: 3 additions & 5 deletions R/LRFinder.R
@@ -15,17 +15,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
getLR <- function(modelParameters,
estimatorSettings,
getLR <- function(parameters,
dataset,
lrSettings = NULL) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
estimator <- createEstimator(modelParameters = modelParameters,
estimatorSettings = estimatorSettings)
estimator <- createEstimator(parameters = parameters)
if (!is.null(lrSettings)) {
lrSettings <- camelCaseToSnakeCaseNames(lrSettings)
}
get_lr <- reticulate::import_from_path("LrFinder", path)$get_lr
lr <- get_lr(estimator, dataset, lrSettings)
return(lr)
}
}
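
getLR() now takes the same bundled list; a hedged calling sketch (variable names are placeholders):

# Sketch only: parameters mirrors the bundle passed to createEstimator().
lr <- getLR(
  parameters = list(
    modelParameters = currentModelParams,
    estimatorSettings = currentEstimatorSettings
  ),
  dataset = trainDataset
)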
42 changes: 0 additions & 42 deletions R/Transformer.R
@@ -50,48 +50,6 @@ setDefaultTransformer <- function(estimatorSettings =
return(transformerSettings)
}

#' Create default settings for a non-temporal transformer
#'
#' @description A transformer model with a Poincare embedding of diseases
#' @details Hierarchical embedding of disease concept in SNOMED medical terms
#' @param estimatorSettings created with `setEstimator`
#' @param embeddingFilePath path to the saved Poincare embedding
#'
#' @export
setCustomEmbeddingTransformer <- function(
embeddingFilePath,
estimatorSettings =
setEstimator(
learningRate = "auto",
weightDecay = 1e-4,
batchSize = 256,
epochs = 2,
seed = NULL,
device = "cpu"
)
) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
estimatorSettings$initStrategy <-
reticulate::import_from_path("InitStrategy",
path = path)$CustomEmbeddingInitStrategy()
estimatorSettings$embeddingFilePath <- embeddingFilePath
transformerSettings <- setTransformer(
numBlocks = 3,
dimToken = 16,
dimOut = 1,
numHeads = 4,
attDropout = 0.2,
ffnDropout = 0.1,
resDropout = 0.0,
dimHidden = 32,
estimatorSettings = estimatorSettings,
hyperParamSearch = "random",
randomSample = 1
)
attr(transformerSettings, "settings")$name <- "customEmbeddingTransformer"
return(transformerSettings)
}
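
For downstream code, the removal amounts to swapping constructors; a migration sketch (the embedding path is a placeholder):

# Before this commit:
# model <- setCustomEmbeddingTransformer(embeddingFilePath = "embeddings.pt")
# After this commit:
model <- setCustomEmbeddingModel(embeddingFilePath = "embeddings.pt")

Unlike the removed function, setCustomEmbeddingModel() also exposes a modelSettings argument, so any model with an embedding layer can be plugged in rather than only the fixed transformer configuration.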

#' create settings for training a non-temporal transformer
#'
#' @description A transformer model
132 changes: 12 additions & 120 deletions inst/python/Dataset.py
@@ -6,32 +6,8 @@
import torch
from torch.utils.data import Dataset

from pathlib import Path
import json
import os

class Data(Dataset):
def __init__(self, data, labels=None, numerical_features=None,
in_cat_1_mapping=None, in_cat_2_mapping=None):
desktop_path = Path.home() / "Desktop"

desktop_path = Path.home() / "Desktop"
with open(desktop_path / "data_path.txt", 'w') as f:
f.write(data)
with open(desktop_path / "labels.json", 'w') as f:
json.dump(labels, f)

# desktop_path = Path.home() / "Desktop"
# with open(desktop_path / "data_path.txt", 'r') as f:
# data = f.read().strip()
# with open(desktop_path / "labels.json", 'r') as f:
# labels = json.load(f)

file_path = "/Users/henrikjohn/Desktop/poincare_model_dim_3.pt"
state = torch.load(file_path)
embed_names = state["names"]
# print(f"Restored Names: {state['names']}")

def __init__(self, data, labels=None, numerical_features=None):
"""
data: path to the covariates data, either an arrow dataset or a sqlite object
labels: a list of either 0 or 1, 1 if the patient got the outcome
@@ -67,9 +43,6 @@ def __init__(self, data, labels=None, numerical_features=None,
else:
self.target = torch.zeros(size=(observations,))

cat2_feature_names = []
cat2_feature_names += embed_names

# filter by categorical columns,
# select rowId and columnId
data_cat = (
@@ -80,35 +53,7 @@ def __init__(self, data, labels=None, numerical_features=None,
.collect()
)

# find concepts from the embedding that are available in the data
cat2_ref = (
self.data_ref
.filter(pl.col("conceptId").is_in(cat2_feature_names))
.select("covariateId")
.collect()
)

# Now, use 'cat2_ref' as a normal DataFrame and access "columnId"
data_cat_1 = data_cat.filter(
~pl.col("covariateId").is_in(cat2_ref["covariateId"]))

self.cat_1_mapping = None
if in_cat_1_mapping is None:
self.cat_1_mapping = pl.DataFrame({
"covariateId": data_cat_1["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_1["covariateId"].unique()) + 1))
})
# self.cat_1_mapping = pl.DataFrame(self.cat_1_mapping)
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping_train.json"))
else:
self.cat_1_mapping = pl.DataFrame(in_cat_1_mapping).with_columns(pl.col('index').cast(pl.Int64), pl.col('covariateId').cast(pl.Float64))
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping_test.json"))


data_cat_1 = data_cat_1.join(self.cat_1_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId"))

cat_tensor = torch.tensor(data_cat_1.to_numpy())
cat_tensor = torch.tensor(data_cat.to_numpy())
tensor_list = torch.split(
cat_tensor[:, 1],
torch.unique_consecutive(cat_tensor[:, 0], return_counts=True)[1].tolist(),
@@ -117,55 +62,12 @@ def __init__(self, data, labels=None, numerical_features=None,
# Some subjects have no categorical features, so build a list of zero
# tensors first and then insert tensor_list into it; this keeps the
# dataset indexable by row.
total_list = [torch.as_tensor((0,))] * observations
idx = data_cat_1["rowId"].unique().to_list()
idx = data_cat["rowId"].unique().to_list()
for i, i2 in enumerate(idx):
total_list[i2] = tensor_list[i]
self.cat = torch.nn.utils.rnn.pad_sequence(total_list, batch_first=True)
self.cat_features = data_cat_1["covariateId"].unique()

# process cat_2 features
data_cat_2 = data_cat.filter(
pl.col("covariateId").is_in(cat2_ref))
self.cat_features = data_cat["covariateId"].unique()

self.cat_2_mapping = None
if in_cat_2_mapping is None:
self.cat_2_mapping = pl.DataFrame({
"covariateId": data_cat_2["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
})
self.cat_2_mapping = self.cat_2_mapping.lazy()
self.cat_2_mapping = (
self.data_ref
.filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
.select(pl.col("conceptId"), pl.col("covariateId"))
.join(self.cat_2_mapping, on="covariateId", how="left")
.collect()
)
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping_train.json"))
else:
self.cat_2_mapping = pl.DataFrame(in_cat_2_mapping).with_columns(pl.col('index').cast(pl.Int64), pl.col('covariateId').cast(pl.Float64))
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping_test.json"))

# cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))

data_cat_2 = data_cat_2.join(self.cat_2_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId")) # maybe rename this to something else

cat_2_tensor = torch.tensor(data_cat_2.to_numpy())
tensor_list_2 = torch.split(
cat_2_tensor[:, 1],
torch.unique_consecutive(cat_2_tensor[:, 0], return_counts=True)[
1].tolist(),
)

total_list_2 = [torch.as_tensor((0,))] * observations
idx_2 = data_cat_2["rowId"].unique().to_list()
for i, i2 in enumerate(idx_2):
total_list_2[i2] = tensor_list_2[i]
self.cat_2 = torch.nn.utils.rnn.pad_sequence(total_list_2,
batch_first=True)
self.cat_2_features = data_cat_2["covariateId"].unique()

# numerical data,
# N x C, dense matrix with values for N patients/visits for C numerical features
if self.numerical_features.count() == 0:
@@ -206,33 +108,23 @@ def __init__(self, data, labels=None, numerical_features=None,
delta = time.time() - start
print(f"Processed data in {delta:.2f} seconds")

def get_numerical_features(self):
return self.numerical_features

def get_cat_features(self):
return self.cat_features

def get_cat_2_features(self):
return self.cat_2_features

def get_cat_2_mapping(self):
return self.cat_2_mapping

def get_cat_1_mapping(self):
return self.cat_1_mapping
def get_feature_info(self):
return {
"numerical_features": self.numerical_features,
"cat_features": self.cat_features,
"reference": self.data_ref
}

def __len__(self):
return self.target.size()[0]

def __getitem__(self, item):
if self.num is not None:
batch = {"cat": self.cat[item, :], "num": self.num[item, :], "cat_2": self.cat_2[item, :]}
batch = {"cat": self.cat[item, :], "num": self.num[item, :]}
else:
batch = {"cat": self.cat[item, :].squeeze(), "num": None, "cat_2": self.cat_2[item, :].squeeze(), }
batch = {"cat": self.cat[item, :].squeeze(), "num": None}
if batch["cat"].dim() == 1:
batch["cat"] = batch["cat"].unsqueeze(0)
if batch["cat_2"].dim() == 1:
batch["cat_2"] = batch["cat_2"].unsqueeze(0)
if (batch["num"] is not None
and batch["num"].dim() == 1
and not isinstance(item, list)