From d446a1b31873fc671ebf287fa633948deacb3eb3 Mon Sep 17 00:00:00 2001 From: Henrik John Date: Thu, 29 Aug 2024 11:32:32 +0200 Subject: [PATCH] Return mappings during training --- R/Estimator.R | 21 ++++++++++++++++----- R/Transformer.R | 10 +++++----- inst/python/Dataset.py | 26 +++++++++++++++----------- inst/python/InitStrategy.py | 32 ++++++++++++++------------------ 4 files changed, 50 insertions(+), 39 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 850eab1..84e0875 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -219,7 +219,9 @@ fitEstimator <- function(trainData, covariateValue = 0, isNumeric = .data$columnId %in% cvResult$numericalIndex # get mapping maybe here - ) + ) %>% + left_join(cvResult$cat1Mapping %>% rename(cat1Idx = index), by = "covariateId") %>% + left_join(cvResult$cat2Mapping %>% rename(cat2Idx = index), by = "covariateId") comp <- start - Sys.time() result <- list( @@ -298,7 +300,14 @@ predictDeepEstimator <- function(plpModel, plpModel <- list(model = plpModel) attr(plpModel, "modelType") <- "binary" } - if ("plpData" %in% class(data)) { + + if (!is.null(plpModel$covariateImportance)) { + # this means that the model finished training since only in the end covariateImportance is added + browser() + + # data <- createDataset(mappedData, plpModel = plpModel) + + } else if ("plpData" %in% class(data)) { mappedData <- PatientLevelPrediction::MapIds(data$covariateData, cohort = cohort, mapping = plpModel$covariateImportance %>% @@ -424,7 +433,8 @@ gridCvDeep <- function(mappedData, prediction$cohortStartDate <- as.Date(prediction$cohortStartDate, origin = "1970-01-01") numericalIndex <- dataset$get_numerical_features() - # get mapping as above + cat1Mapping <- as.data.frame(dataset$get_cat_1_mapping()) + cat2Mapping <- as.data.frame(dataset$get_cat_2_mapping()) # save torch code here if (!dir.exists(file.path(modelLocation))) { @@ -437,8 +447,9 @@ gridCvDeep <- function(mappedData, prediction = prediction, finalParam = finalParam, paramGridSearch = paramGridSearch, - numericalIndex = numericalIndex$to_list() - # add mapping here, two columns [covariateId, columnId] + numericalIndex = numericalIndex$to_list(), + cat1Mapping = cat1Mapping, + cat2Mapping = cat2Mapping ) ) } diff --git a/R/Transformer.R b/R/Transformer.R index efa92f6..5f70805 100644 --- a/R/Transformer.R +++ b/R/Transformer.R @@ -64,8 +64,8 @@ setCustomEmbeddingTransformer <- function( setEstimator( learningRate = "auto", weightDecay = 1e-4, - batchSize = 512, - epochs = 10, + batchSize = 256, + epochs = 2, seed = NULL, device = "cpu" ) @@ -77,13 +77,13 @@ setCustomEmbeddingTransformer <- function( estimatorSettings$embeddingFilePath <- embeddingFilePath transformerSettings <- setTransformer( numBlocks = 3, - dimToken = 192, + dimToken = 16, dimOut = 1, - numHeads = 8, + numHeads = 4, attDropout = 0.2, ffnDropout = 0.1, resDropout = 0.0, - dimHidden = 256, + dimHidden = 32, estimatorSettings = estimatorSettings, hyperParamSearch = "random", randomSample = 1 diff --git a/inst/python/Dataset.py b/inst/python/Dataset.py index aaf5e87..04cd972 100644 --- a/inst/python/Dataset.py +++ b/inst/python/Dataset.py @@ -78,8 +78,6 @@ def __init__(self, data, labels=None, numerical_features=None, if cat2_feature_names is None: cat2_feature_names = [] - self.feature_mapping = {} - cat2_feature_names += embed_names # filter by categorical columns, @@ -103,13 +101,13 @@ def __init__(self, data, labels=None, numerical_features=None, # Now, use 'cat2_ref' as a normal DataFrame and access "columnId" data_cat_1 = data_cat.filter( ~pl.col("covariateId").is_in(cat2_ref["covariateId"])) - cat_1_mapping = pl.DataFrame({ + self.cat_1_mapping = pl.DataFrame({ "covariateId": data_cat_1["covariateId"].unique(), "index": pl.Series(range(1, len(data_cat_1["covariateId"].unique()) + 1)) }) - cat_1_mapping.write_json(str(desktop_path / "cat1_mapping.json")) + self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping.json")) - data_cat_1 = data_cat_1.join(cat_1_mapping, on="covariateId", how="left") \ + data_cat_1 = data_cat_1.join(self.cat_1_mapping, on="covariateId", how="left") \ .select(pl.col("rowId"), pl.col("index").alias("covariateId")) cat_tensor = torch.tensor(data_cat_1.to_numpy()) @@ -130,22 +128,22 @@ def __init__(self, data, labels=None, numerical_features=None, # process cat_2 features data_cat_2 = data_cat.filter( pl.col("covariateId").is_in(cat2_ref)) - cat_2_mapping = pl.DataFrame({ + self.cat_2_mapping = pl.DataFrame({ "covariateId": data_cat_2["covariateId"].unique(), "index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1)) }) - cat_2_mapping = cat_2_mapping.lazy() - cat_2_mapping = ( + self.cat_2_mapping = self.cat_2_mapping.lazy() + self.cat_2_mapping = ( self.data_ref .filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique())) .select(pl.col("conceptId"), pl.col("covariateId")) - .join(cat_2_mapping, on="covariateId", how="left") + .join(self.cat_2_mapping, on="covariateId", how="left") .collect() ) - cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json")) + self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json")) # cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json")) - data_cat_2 = data_cat_2.join(cat_2_mapping, on="covariateId", how="left") \ + data_cat_2 = data_cat_2.join(self.cat_2_mapping, on="covariateId", how="left") \ .select(pl.col("rowId"), pl.col("index").alias("covariateId")) # maybe rename this to something else cat_2_tensor = torch.tensor(data_cat_2.to_numpy()) @@ -211,6 +209,12 @@ def get_cat_features(self): def get_cat_2_features(self): return self.cat_2_features + + def get_cat_2_mapping(self): + return self.cat_2_mapping + + def get_cat_1_mapping(self): + return self.cat_1_mapping def __len__(self): return self.target.size()[0] diff --git a/inst/python/InitStrategy.py b/inst/python/InitStrategy.py index 14216a1..c0f72ba 100644 --- a/inst/python/InitStrategy.py +++ b/inst/python/InitStrategy.py @@ -32,12 +32,9 @@ class CustomEmbeddingInitStrategy(InitStrategy): def initialize(self, model, model_parameters, estimator_settings): file_path = estimator_settings.get("embedding_file_path") - # Ensure `cat_2_features` is added to `model_parameters` - # cat_2_features_default = 20 # Set a default value if you don't have one - print(model_parameters['cat_2_features']) - print(model_parameters['cat_features']) - print(model_parameters['num_features']) - + # print(model_parameters['cat_2_features']) + # print(model_parameters['cat_features']) + # print(model_parameters['num_features']) # Instantiate the model with the provided parameters model_temp = model(**model_parameters) @@ -51,7 +48,7 @@ def initialize(self, model, model_parameters, estimator_settings): raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary") new_embeddings = state_dict[embedding_key].float() - print(f"new_embeddings: {new_embeddings}") + # print(f"new_embeddings: {new_embeddings}") # Ensure that model_temp.categorical_embedding_2 exists if not hasattr(model_temp, 'categorical_embedding_2'): @@ -60,10 +57,10 @@ def initialize(self, model, model_parameters, estimator_settings): # # replace weights # cat2_concept_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_concept_mapping.json")) cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping.json")) - print(f"cat2_mapping: {cat2_mapping}") + # print(f"cat2_mapping: {cat2_mapping}") concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId")) - print(f"concept_df: {concept_df}") + # print(f"concept_df: {concept_df}") # Initialize tensor for mapped embeddings mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1])) @@ -75,20 +72,19 @@ def initialize(self, model, model_parameters, estimator_settings): concept_idx = concept_df["conceptId"].to_list().index(concept_id) mapped_embeddings[index] = new_embeddings[concept_idx] - print(f"mapped_embeddings: {mapped_embeddings}") + # print(f"mapped_embeddings: {mapped_embeddings}") # Assign the mapped embeddings to the model model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings) model_temp.categorical_embedding_2.weight.requires_grad = False - print("New Embeddings:") - print(new_embeddings) - print(f"Restored Epoch: {state['epoch']}") - print(f"Restored Mean Rank: {state['mean_rank']}") - print(f"Restored Loss: {state['loss']}") - print(f"Restored Names: {state['names'][:5]}") - print(f"Number of names: {len(state['names'])}") - # print(f"Filtered Embeddings: {filtered_embeddings}") + # print("New Embeddings:") + # print(new_embeddings) + # print(f"Restored Epoch: {state['epoch']}") + # print(f"Restored Mean Rank: {state['mean_rank']}") + # print(f"Restored Loss: {state['loss']}") + # print(f"Restored Names: {state['names'][:5]}") + # print(f"Number of names: {len(state['names'])}") else: raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")