Load custom weights into embedding
lhjohn committed Aug 27, 2024
1 parent cf0a59c commit af5a3df
Showing 2 changed files with 72 additions and 57 deletions.
30 changes: 14 additions & 16 deletions inst/python/Dataset.py
@@ -10,6 +10,7 @@
import json
import os

+ # TODO: this needs a parameter to avoid re-doing the mapping

class Data(Dataset):
def __init__(self, data, labels=None, numerical_features=None,
@@ -99,18 +100,6 @@ def __init__(self, data, labels=None, numerical_features=None,
.collect()
)

- cat2_concepts = (
- self.data_ref
- .filter(pl.col("conceptId").is_in(cat2_feature_names))
- .select("conceptId")
- .collect()
- )
- concept_ids = cat2_concepts["conceptId"].to_list()
- cat2_indices = [cat2_feature_names.index(concept_id) for concept_id in concept_ids]
- with open(desktop_path / "cat2_indices.json", 'w') as f:
- json.dump(cat2_indices, f)

# Now, use 'cat2_ref' as a normal DataFrame and access "columnId"
data_cat_1 = data_cat.filter(
~pl.col("covariateId").is_in(cat2_ref["covariateId"]))
@@ -145,10 +134,19 @@ def __init__(self, data, labels=None, numerical_features=None,
"covariateId": data_cat_2["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
})
+ cat_2_mapping = cat_2_mapping.lazy()
+ cat_2_mapping = (
+ self.data_ref
+ .filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
+ .select(pl.col("conceptId"), pl.col("covariateId"))
+ .join(cat_2_mapping, on="covariateId", how="left")
+ .collect()
+ )
- cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))
+ # cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))
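# At this point cat_2_mapping holds one row per custom-embedding covariate:
# conceptId, covariateId, and a 1-based "index" used as its embedding row.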

data_cat_2 = data_cat_2.join(cat_2_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId")) # maybe rename this to something else

cat_2_tensor = torch.tensor(data_cat_2.to_numpy())
tensor_list_2 = torch.split(
@@ -227,9 +225,9 @@ def __getitem__(self, item):
if batch["cat_2"].dim() == 1:
batch["cat_2"] = batch["cat_2"].unsqueeze(0)
if (batch["num"] is not None
and batch["num"].dim() == 1
and not isinstance(item, list)
):
batch["num"] = batch["num"].unsqueeze(0)
return [batch, self.target[item].squeeze()]
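
For orientation, here is a minimal self-contained sketch of the remapping idiom introduced above, using hypothetical toy values (not from this commit): each distinct covariateId in the second categorical group gets a dense 1-based index, which then replaces the raw id so it can serve directly as an embedding row.

import polars as pl

# Toy stand-ins for data_cat_2; values are hypothetical
data_cat_2 = pl.DataFrame({
    "rowId": [1, 1, 2, 3],
    "covariateId": [1002, 2003, 1002, 3004],
})
# Dense 1-based index per distinct covariateId; 0 is left free for padding
cat_2_mapping = pl.DataFrame({
    "covariateId": data_cat_2["covariateId"].unique(maintain_order=True),
    "index": pl.Series(range(1, data_cat_2["covariateId"].n_unique() + 1)),
})
remapped = (
    data_cat_2
    .join(cat_2_mapping, on="covariateId", how="left")
    .select(pl.col("rowId"), pl.col("index").alias("covariateId"))
)
print(remapped)  # the covariateId column now holds 1, 2, 1, 3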

99 changes: 58 additions & 41 deletions inst/python/InitStrategy.py
@@ -3,6 +3,8 @@
import torch
import os
from torchviz import make_dot
+ import json
+ import polars as pl

class InitStrategy(ABC):
@abstractmethod
@@ -31,71 +33,86 @@ def initialize(self, model, model_parameters, estimator_settings):
file_path = estimator_settings.get("embedding_file_path")

# Ensure `cat_2_features` is added to `model_parameters`
- cat_2_features_default = 20 # Set a default value if you don't have one
- model_parameters['cat_2_features'] = model_parameters.get('cat_2_features', cat_2_features_default)
+ # cat_2_features_default = 20 # Set a default value if you don't have one
+ print(model_parameters['cat_2_features'])
+ print(model_parameters['cat_features'])
+ print(model_parameters['num_features'])


# Instantiate the model with the provided parameters
model_temp = model(**model_parameters)

# Create a dummy input batch that matches the model inputs
dummy_input = {
"cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(),
"cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(),
"num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None
}
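# Shapes above: (1, 10) simulates one sample with ten categorical indices
# per group; "num" is a single row of num_features continuous values.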

# Ensure that the dummy input does not contain `None` values if num_features == 0
if model_parameters['num_features'] == 0:
del dummy_input["num"]

if hasattr(model_temp, 'forward'):
try:
output = model_temp(dummy_input)
initial_graph = make_dot(output, params=dict(model_temp.named_parameters()), show_attrs=False, show_saved=False)
initial_graph.render("initial_model_architecture", format="png")
except Exception as e:
print(f"Error during initial model visualization: {e}")

else:
raise AttributeError("The model does not have a forward method.")

if file_path and os.path.exists(file_path):
state = torch.load(file_path)
state_dict = state["state_dict"]
- embedding_key = "embedding.weight" # Key in the state dict for the embedding
+ embedding_key = "embedding.weight"

if embedding_key not in state_dict:
raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary")

new_embeddings = state_dict[embedding_key].float()
print(f"new_embeddings: {new_embeddings}")

# Ensure that model_temp.categorical_embedding_2 exists
if not hasattr(model_temp, 'categorical_embedding_2'):
raise AttributeError("The model does not have an attribute 'categorical_embedding_2'")

# Replace the weights of `model_temp.categorical_embedding_2`
if isinstance(model_temp.categorical_embedding_2, torch.nn.Embedding):
with torch.no_grad():
model_temp.categorical_embedding_2.weight = torch.nn.Parameter(new_embeddings)
else:
raise TypeError("The attribute 'categorical_embedding_2' is not of type `torch.nn.Embedding`")
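
# Note: this wholesale copy is superseded below, where the weights are
# re-mapped row by row so each covariate's index receives its concept's
# pretrained vector.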

# replace weights
# cat2_concept_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_concept_mapping.json"))
cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping.json"))
print(f"cat2_mapping: {cat2_mapping}")

concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId"))
print(f"concept_df: {concept_df}")

# Initialize tensor for mapped embeddings
mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1]))
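# cat2_mapping indices start at 1 (see Dataset.py), so row 0 of this
# tensor stays zero and effectively serves as a padding/unknown slot.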

# Map embeddings to their corresponding indices
for row in cat2_mapping.iter_rows():
concept_id, covariate_id, index = row
if concept_id in concept_df["conceptId"]:
concept_idx = concept_df["conceptId"].to_list().index(concept_id)
mapped_embeddings[index] = new_embeddings[concept_idx]
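# Note: .to_list().index() rescans the list on every iteration; for large
# vocabularies a precomputed lookup, e.g.
# pos = {c: i for i, c in enumerate(concept_df["conceptId"].to_list())},
# would make this loop a single pass.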

print(f"mapped_embeddings: {mapped_embeddings}")

# Assign the mapped embeddings to the model
model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings)
model_temp.categorical_embedding_2.weight.requires_grad = False
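# Freezing (requires_grad = False) keeps the loaded vectors fixed during training.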

print("New Embeddings:")
print(new_embeddings)
print(f"Restored Epoch: {state['epoch']}")
print(f"Restored Mean Rank: {state['mean_rank']}")
print(f"Restored Loss: {state['loss']}")
- print(f"Restored Names: {state['names']}")
+ print(f"Restored Names: {state['names'][:5]}")
print(f"Number of names: {len(state['names'])}")
# print(f"Filtered Embeddings: {filtered_embeddings}")
else:
raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")

# Visualize the modified model architecture again
try:
output = model_temp(dummy_input)
modified_graph = make_dot(output, params=dict(model_temp.named_parameters()))
modified_graph.render("modified_model_architecture", format="png")
print("Modified model architecture rendered successfully.")
except Exception as e:
print(f"Error during modified model visualization: {e}")

- # Create a dummy input batch that matches the model inputs
- dummy_input = {
- "cat": torch.randint(0, model_parameters['cat_features'], (1, 10)).long(),
- "cat_2": torch.randint(0, model_parameters['cat_2_features'], (1, 10)).long(),
- "num": torch.randn(1, model_parameters['num_features']) if model_parameters['num_features'] > 0 else None
- }
-
- # Ensure that the dummy input does not contain `None` values if num_features == 0
- if model_parameters['num_features'] == 0:
- del dummy_input["num"]
-
- if hasattr(model_temp, 'forward'):
- try:
- output = model_temp(dummy_input)
- initial_graph = make_dot(output, params=dict(model_temp.named_parameters()), show_attrs=False, show_saved=False)
- initial_graph.render("initial_model_architecture", format="png")
- except Exception as e:
- print(f"Error during initial model visualization: {e}")
-
- else:
- raise AttributeError("The model does not have a forward method.")

return model_temp
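
For context, a minimal usage sketch of the strategy above. The class name CustomEmbeddingInitStrategy and the model class ResNet are placeholders (the concrete names are folded out of this diff), and all parameter values are hypothetical:

# Hypothetical usage; names, values, and the path are placeholders, not from this commit.
strategy = CustomEmbeddingInitStrategy()
model = strategy.initialize(
    model=ResNet,
    model_parameters={"cat_features": 100, "cat_2_features": 50, "num_features": 10},
    estimator_settings={"embedding_file_path": "/path/to/embedding_checkpoint.pt"},
)
# The second categorical embedding now holds the mapped weights and stays frozen.
assert model.categorical_embedding_2.weight.requires_grad is False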
