diff --git a/docs/tutorials/generator/speculative_decoding.md b/docs/tutorials/generator/speculative_decoding.md
index d416b39..4100eea 100644
--- a/docs/tutorials/generator/speculative_decoding.md
+++ b/docs/tutorials/generator/speculative_decoding.md
@@ -33,6 +33,21 @@ generator_model_platform_args:
     assistant_model: meta-llama/Llama-3.2-1B
 ```
 
+### Universal Assisted Decoding
+You may also use an assistant model whose tokenizer differs from the target model's. All you need to do is explicitly specify the assistant tokenizer:
+
+```yaml
+generator_model_name: google/gemma-2-9b
+generator_model_platform: huggingface
+generator_model_platform_args:
+  hf_generate_params:
+    assistant_model: double7/vicuna-68m
+    assistant_tokenizer: double7/vicuna-68m
+```
+
+Note: Transformers `v4.46.0` or later is required for universal assisted decoding.
+
+
 ## Speculative Decoding with vLLM Models
 
 Speculative decoding with vLLM is also straightforward. Here is an example configuration that sets up vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time:
diff --git a/src/agrag/modules/generator/generators/hf_generator.py b/src/agrag/modules/generator/generators/hf_generator.py
index 2108ead..290dfa7 100644
--- a/src/agrag/modules/generator/generators/hf_generator.py
+++ b/src/agrag/modules/generator/generators/hf_generator.py
@@ -80,10 +80,17 @@ def __init__(
 
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **self.hf_tokenizer_init_params)
         self.assistant_model_name = self.hf_generate_params.get("assistant_model", None)
+        self.assistant_tokenizer_name = self.hf_generate_params.get("assistant_tokenizer", None)
         logger.info(f"Using Huggingface Model {self.assistant_model_name} as the Assistant Model")
         if self.assistant_model_name:
             assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_model_name).to(self.device)
             self.hf_generate_params["assistant_model"] = assistant_model
+            # To support Universal Assisted Decoding
+            # See https://huggingface.co/docs/transformers/main/en/generation_strategies#universal-assisted-decoding
+            if self.assistant_tokenizer_name:
+                assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_tokenizer_name)
+                self.hf_generate_params["assistant_tokenizer"] = assistant_tokenizer
+                self.hf_generate_params["tokenizer"] = self.tokenizer
 
         self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **self.hf_model_params).to(self.device)
         if self.num_gpus > 1:
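
For reference, here is a minimal standalone sketch (not part of the patch) of the `transformers` `generate()` call that the populated `hf_generate_params` roughly correspond to: the assistant model plus the target and assistant tokenizers. The checkpoints are the ones from the documentation example above, and the call assumes Transformers `v4.46.0` or later:

```python
# Sketch of universal assisted decoding with transformers >= 4.46.0.
# Checkpoints follow the documentation example; assistant_tokenizer is only
# needed because the draft model's tokenizer differs from the target model's.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Target model and its tokenizer (analogous to self.model / self.tokenizer).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b").to(device)

# Assistant (draft) model with a different tokenizer.
assistant_model = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m").to(device)
assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")

inputs = tokenizer("Speculative decoding works by", return_tensors="pt").to(device)

# With mismatched tokenizers, generate() needs the target tokenizer and the
# assistant tokenizer in addition to the assistant model.
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```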