Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Universal Assisted Model #90

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/tutorials/generator/speculative_decoding.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ generator_model_platform_args:
assistant_model: meta-llama/Llama-3.2-1B
```

### Universal Assisted Decoding
You may also use assistant model with different tokeniers from the target model. All you need to do is to explicitly specify the assistant tokenizer:

```yaml
generator_model_name: google/gemma-2-9b
generator_model_platform: huggingface
generator_model_platform_args:
hf_generate_params:
assistant_model: double7/vicuna-68m
assistant_tokenizer: double7/vicuna-68m
```

Note: Transformers `v4.46.0` or above is required to support universal assisted decoding.


## Speculative Decoding with vLLM Models

Speculative decoding with vLLM is also straightforward. Here is an example configuration that sets up vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time:
Expand Down
7 changes: 7 additions & 0 deletions src/agrag/modules/generator/generators/hf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,17 @@ def __init__(

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **self.hf_tokenizer_init_params)
self.assistant_model_name = self.hf_generate_params.get("assistant_model", None)
self.assistant_tokenizer_name = self.hf_generate_params.get("assistant_tokenizer", None)
logger.info(f"Using Huggingface Model {self.assistant_model_name} as the Assistant Model")
if self.assistant_model_name:
assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_model_name).to(self.device)
self.hf_generate_params["assistant_model"] = assistant_model
# To support Universal Assisted Decoding
# See https://huggingface.co/docs/transformers/main/en/generation_strategies#universal-assisted-decoding
if self.assistant_tokenizer_name:
assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_tokenizer_name)
self.hf_generate_params["assistant_tokenizer"] = assistant_tokenizer
self.hf_generate_params["tokenizer"] = self.tokenizer
Copy link
Collaborator

@suzhoum suzhoum Nov 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be more logical to move Line 93 out of the if statement and place after Line 81?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line is usually not needed, but it is compulsory if there is an assistant tokenizer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK that makes sense.

self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **self.hf_model_params).to(self.device)

if self.num_gpus > 1:
Expand Down
Loading