diff --git a/inst/python/Dataset.py b/inst/python/Dataset.py
index 4f7a2c6..aaf5e87 100644
--- a/inst/python/Dataset.py
+++ b/inst/python/Dataset.py
@@ -10,6 +10,7 @@
 import json
 import os
 
+# this one needs to have a parameter to not do the mapping again
 class Data(Dataset):
     def __init__(self, data, labels=None, numerical_features=None,
@@ -99,18 +100,6 @@ def __init__(self, data, labels=None, numerical_features=None,
             .collect()
         )
 
-        cat2_concepts = (
-            self.data_ref
-            .filter(pl.col("conceptId").is_in(cat2_feature_names))
-            .select("conceptId")
-            .collect()
-        )
-        concept_ids = cat2_concepts["conceptId"].to_list()
-        cat2_indices = [cat2_feature_names.index(concept_id) for concept_id in
-                        concept_ids]
-        with open(desktop_path / "cat2_indices.json", 'w') as f:
-            json.dump(cat2_indices, f)
-
         # Now, use 'cat2_ref' as a normal DataFrame and access "columnId"
         data_cat_1 = data_cat.filter(
             ~pl.col("covariateId").is_in(cat2_ref["covariateId"]))
@@ -145,10 +134,18 @@ def __init__(self, data, labels=None, numerical_features=None,
             "covariateId": data_cat_2["covariateId"].unique(),
             "index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
         })
+        cat_2_mapping = cat_2_mapping.lazy()
+        cat_2_mapping = (
+            self.data_ref
+            .filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
+            .select(pl.col("conceptId"), pl.col("covariateId"))
+            .join(cat_2_mapping, on="covariateId", how="left")
+            .collect()
+        )
         cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))
 
         data_cat_2 = data_cat_2.join(cat_2_mapping, on="covariateId", how="left") \
-            .select(pl.col("rowId"), pl.col("index").alias("covariateId"))  # maybe rename this to something else
+            .select(pl.col("rowId"), pl.col("index").alias("covariateId"))  # maybe rename this to something else
 
         cat_2_tensor = torch.tensor(data_cat_2.to_numpy())
         tensor_list_2 = torch.split(
@@ -227,9 +225,9 @@ def __getitem__(self, item):
         if batch["cat_2"].dim() == 1:
             batch["cat_2"] = batch["cat_2"].unsqueeze(0)
         if (batch["num"] is not None
-            and batch["num"].dim() == 1
-            and not isinstance(item, list)
-        ):
+                and batch["num"].dim() == 1
+                and not isinstance(item, list)
+                ):
             batch["num"] = batch["num"].unsqueeze(0)
         return [batch, self.target[item].squeeze()]
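For reference, here is a minimal, self-contained sketch of what the new cat_2_mapping block above produces: a three-column conceptId/covariateId/index table, where index is the contiguous 1-based id later used as the embedding row. All IDs and values below are made up; the only assumption is that data_ref maps covariateId to conceptId, as the patch relies on.

import polars as pl

# Toy stand-in for self.data_ref (normally a lazy covariate reference); IDs are made up
data_ref = pl.LazyFrame({
    "covariateId": [501, 502, 503],
    "conceptId": [1003, 1001, 1002],
})
covariate_ids = pl.Series("covariateId", [502, 501, 502])  # as in data_cat_2["covariateId"]

# Contiguous 1-based index per distinct covariateId, then joined back to conceptId
cat_2_mapping = pl.DataFrame({
    "covariateId": covariate_ids.unique(),
    "index": pl.Series(range(1, len(covariate_ids.unique()) + 1)),
}).lazy()

cat_2_mapping = (
    data_ref
    .filter(pl.col("covariateId").is_in(covariate_ids.unique()))
    .select(pl.col("conceptId"), pl.col("covariateId"))
    .join(cat_2_mapping, on="covariateId", how="left")
    .collect()
)
print(cat_2_mapping)  # columns: conceptId, covariateId, index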
diff --git a/inst/python/InitStrategy.py b/inst/python/InitStrategy.py
index 9c88c4e..14216a1 100644
--- a/inst/python/InitStrategy.py
+++ b/inst/python/InitStrategy.py
@@ -3,6 +3,8 @@
 import torch
 import os
 from torchviz import make_dot
+import json
+import polars as pl
 
 class InitStrategy(ABC):
     @abstractmethod
@@ -31,71 +33,86 @@ def initialize(self, model, model_parameters, estimator_settings):
         file_path = estimator_settings.get("embedding_file_path")
 
         # Ensure `cat_2_features` is added to `model_parameters`
-        cat_2_features_default = 20  # Set a default value if you don't have one
-        model_parameters['cat_2_features'] = model_parameters.get('cat_2_features', cat_2_features_default)
+        # cat_2_features_default = 20  # Set a default value if you don't have one
+        print(model_parameters['cat_2_features'])
+        print(model_parameters['cat_features'])
+        print(model_parameters['num_features'])
+
         # Instantiate the model with the provided parameters
         model_temp = model(**model_parameters)
 
-        # Create a dummy input batch that matches the model inputs
-        dummy_input = {
-            "cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(),
-            "cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(),
-            "num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None
-        }
-
-        # Ensure that the dummy input does not contain `None` values if num_features == 0
-        if model_parameters['num_features'] == 0:
-            del dummy_input["num"]
-
-        if hasattr(model_temp, 'forward'):
-            try:
-                output = model_temp(dummy_input)
-                initial_graph = make_dot(output, params=dict(model_temp.named_parameters()), show_attrs=False, show_saved=False)
-                initial_graph.render("initial_model_architecture", format="png")
-            except Exception as e:
-                print(f"Error during initial model visualization: {e}")
-
-        else:
-            raise AttributeError("The model does not have a forward method.")
-
         if file_path and os.path.exists(file_path):
             state = torch.load(file_path)
             state_dict = state["state_dict"]
-            embedding_key = "embedding.weight"  # Key in the state dict for the embedding
+            embedding_key = "embedding.weight"
 
             if embedding_key not in state_dict:
                 raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary")
 
             new_embeddings = state_dict[embedding_key].float()
+            print(f"new_embeddings: {new_embeddings}")
 
             # Ensure that model_temp.categorical_embedding_2 exists
             if not hasattr(model_temp, 'categorical_embedding_2'):
                 raise AttributeError("The model does not have an attribute 'categorical_embedding_2'")
 
-            # Replace the weights of `model_temp.categorical_embedding_2`
-            if isinstance(model_temp.categorical_embedding_2, torch.nn.Embedding):
-                with torch.no_grad():
-                    model_temp.categorical_embedding_2.weight = torch.nn.Parameter(new_embeddings)
-            else:
-                raise TypeError("The attribute 'categorical_embedding_2' is not of type `torch.nn.Embedding`")
-
+            # replace weights
+            # cat2_concept_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_concept_mapping.json"))
+            cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping.json"))
+            print(f"cat2_mapping: {cat2_mapping}")
+
+            concept_df = pl.DataFrame({"conceptId": state['names']})
+            print(f"concept_df: {concept_df}")
+
+            # Build a conceptId -> embedding-row lookup once instead of scanning a list per row
+            concept_rows = {cid: i for i, cid in enumerate(state['names'])}
+
+            # Initialize tensor for mapped embeddings; row 0 stays zero for unmapped/padding indices
+            mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1]))
+
+            # Map embeddings to their corresponding indices
+            for concept_id, covariate_id, index in cat2_mapping.iter_rows():
+                if concept_id in concept_rows:
+                    mapped_embeddings[index] = new_embeddings[concept_rows[concept_id]]
+
+            print(f"mapped_embeddings: {mapped_embeddings}")
+
+            # Assign the mapped embeddings to the model and freeze them
+            model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings)
+            model_temp.categorical_embedding_2.weight.requires_grad = False
+
             print("New Embeddings:")
             print(new_embeddings)
             print(f"Restored Epoch: {state['epoch']}")
             print(f"Restored Mean Rank: {state['mean_rank']}")
             print(f"Restored Loss: {state['loss']}")
-            print(f"Restored Names: {state['names']}")
+            print(f"Restored Names: {state['names'][:5]}")
+            print(f"Number of names: {len(state['names'])}")
         else:
             raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")
 
-        # Visualize the modified model architecture again
-        try:
-            output = model_temp(dummy_input)
-            modified_graph = make_dot(output, params=dict(model_temp.named_parameters()))
-            modified_graph.render("modified_model_architecture", format="png")
-            print("Modified model architecture rendered successfully.")
-        except Exception as e:
-            print(f"Error during modified model visualization: {e}")
modified model visualization: {e}") + # Create a dummy input batch that matches the model inputs + dummy_input = { + "cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(), + "cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(), + "num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None + } + + # Ensure that the dummy input does not contain `None` values if num_features == 0 + if model_parameters['num_features'] == 0: + del dummy_input["num"] + + if hasattr(model_temp, 'forward'): + try: + output = model_temp(dummy_input) + initial_graph = make_dot(output, params=dict(model_temp.named_parameters()), show_attrs=False, show_saved=False) + initial_graph.render("initial_model_architecture", format="png") + except Exception as e: + print(f"Error during initial model visualization: {e}") + + else: + raise AttributeError("The model does not have a forward method.") + return model_temp