6 changes: 5 additions & 1 deletion README.md
@@ -197,6 +197,7 @@ The *total-token* is the number of draft tokens. For smaller models and advanced
 ### With Code
 You can use our provided "eagenerate" to speed up generation, just like using 'generate' from Hugging Face. Here is an example.
 ```python
+import torch
 from eagle.model.ea_model import EaModel
 from fastchat.model import get_conversation_template
 model = EaModel.from_pretrained(
@@ -205,7 +206,10 @@ model = EaModel.from_pretrained(
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     device_map="auto",
-    total_token=-1
+    total_token=-1,
+    # If using EAGLE (not EAGLE 3), please uncomment the following line
+    # use_eagle3=False
+
 )
 model.eval()
 your_message="Hello"
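The README example is cut off here by the collapsed diff. For orientation, the rest of the example plausibly continues as sketched below; the conversation-template flow and the `eagenerate` arguments (`temperature`, `max_new_tokens`) are assumptions based on the surrounding README, not part of this PR.

```python
# Sketch of how the truncated README example likely continues
# (prompt construction and eagenerate arguments are assumptions).
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
print(model.tokenizer.decode(output_ids[0]))
```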
27 changes: 25 additions & 2 deletions eagle/model/cnets1.py
@@ -777,8 +777,30 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
 tree_position_ids = torch.sum(tree_mask, dim=1) - 1

 tree_mask = tree_mask.float()[None, None]
+########################################
+# Change
+########################################
+# top_scores_index = torch.sort(top_scores.indices).values
+
+# get the cumulative log-probs (log-densities) for the selected drafts
+draft_log_scores = scores_list[top_scores_index]  # shape [total_tokens]
+# convert to probabilities if you prefer:
+# draft_probs = torch.exp(draft_log_scores)  # shape [total_tokens]
+
+# then build draft_tokens as before
+draft_tokens = ss_token_list[top_scores_index]
+draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
+
+# make sure device is consistent
+draft_log_scores = draft_log_scores.to(draft_tokens.device)[None]  # shape [1, total_tokens]
+# draft_probs = draft_probs.to(draft_tokens.device)[None]
+########################################
+# Change
+########################################
+
 draft_tokens = draft_tokens[None]
+

 del parents_list, scores_list, ss_token, ss_token_list, draft_parents

 # with Timer("retrieve"):
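To make the intent of the inserted lines concrete, here is a minimal, self-contained sketch of the gather the change performs. All names and shapes (`scores_list` as a flat tensor, 64 candidates, `k=8`) are stand-ins for the real tensors built inside `topK_genrate`, not code from this PR.

```python
import torch

# Stand-in for the cumulative log-probabilities of all expanded draft
# candidates (built during tree expansion in the real code).
scores_list = torch.log(torch.rand(64))            # [num_candidates]

# Keep the k best candidates, index-sorted as in the commented-out line.
top_scores_index = torch.sort(torch.topk(scores_list, k=8).indices).values

# The gather added by this PR: one cumulative log-prob per kept draft.
draft_log_scores = scores_list[top_scores_index]   # [k]
draft_probs = torch.exp(draft_log_scores)          # optional: probabilities
```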
@@ -819,7 +841,7 @@ def custom_sort(lst):
 del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid
 tree_position_ids = tree_position_ids.to(hidden_states.device)

-return draft_tokens, retrieve_indices, tree_mask, tree_position_ids
+return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, draft_log_scores
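Because `topK_genrate` now returns a 5-tuple, call sites must unpack the extra value. A hypothetical call site is sketched below; the `ea_layer` handle and the re-exponentiation of the scores are illustrative assumptions, not changes made by this PR.

```python
# Hypothetical caller: unpack the new 5-tuple instead of the old 4-tuple.
draft_tokens, retrieve_indices, tree_mask, tree_position_ids, draft_log_scores = \
    ea_layer.topK_genrate(hidden_states, input_ids, head, logits_processor)

# draft_log_scores has shape [1, total_tokens]: the draft model's cumulative
# log-probability for each selected draft token.
draft_probs = draft_log_scores.exp()
```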



@@ -832,4 +854,5 @@ def count_parameters(model):
 if __name__ == "__main__":
     config = EConfig.from_pretrained('config.json')
     model = Model(config, load_emb=False)
-    print(model)
+
+    print(model)