ipex 2.4 supports assisted decoding
jiqing-feng committed Sep 4, 2024
1 parent 0d37ae9 commit e85240c
Showing 3 changed files with 142 additions and 19 deletions.
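
What the change enables, as a minimal usage sketch (the model IDs, dtype, and load arguments below are illustrative assumptions, not taken from this commit): with ipex >= 2.4 installed, a patched IPEX model can be given an assistant_model in generate() instead of raising.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"                  # assumed target model
assistant_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"    # assumed smaller draft model

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)
assistant = AutoModelForCausalLM.from_pretrained(assistant_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Assisted decoding with IPEX:", return_tensors="pt")
# With ipex >= 2.4 this no longer raises for patched (exported) models.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))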
19 changes: 14 additions & 5 deletions optimum/exporters/ipex/model_config.py
@@ -14,13 +14,13 @@

from typing import Optional, Tuple

from optimum.exporters.onnx.model_configs import LlamaOnnxConfig
from optimum.exporters.onnx.model_configs import FalconOnnxConfig, LlamaOnnxConfig, TextDecoderOnnxConfig
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


class IPEXDummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
@@ -39,7 +39,7 @@ def __init__(
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.num_key_value_heads = normalized_config.num_key_value_heads
self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1)
self.max_position_embeddings = normalized_config.max_position_embeddings

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
@@ -63,8 +63,17 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


class LlamaIPEXConfig(LlamaOnnxConfig):
DEFAULT_ONNX_OPSET = 14
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


ipex_onnx_config = {"llama": LlamaIPEXConfig}
class FalconIPEXConfig(FalconOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
IPEXDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


ipex_onnx_config = {"llama": LlamaIPEXConfig, "falcon": FalconIPEXConfig}
103 changes: 103 additions & 0 deletions optimum/exporters/ipex/modeling_utils.py
@@ -156,6 +156,109 @@ def _llama_model_forward(
)


# # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
# def _llama_model_forward(
# self,
# input_ids: torch.LongTensor = None,
# attention_mask: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.LongTensor] = None,
# past_key_values: Optional[List[torch.FloatTensor]] = None,
# inputs_embeds: Optional[torch.FloatTensor] = None,
# use_cache: Optional[bool] = None,
# output_attentions: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# return_dict: Optional[bool] = None,
# **kwargs,
# ) -> Union[Tuple, BaseModelOutputWithPast]:
# output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# output_hidden_states = (
# output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
# )
# use_cache = use_cache if use_cache is not None else self.config.use_cache

# return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# # retrieve input_ids and inputs_embeds
# if input_ids is not None and inputs_embeds is not None:
# raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
# elif input_ids is not None:
# batch_size, seq_length = input_ids.shape[:2]
# elif inputs_embeds is not None:
# batch_size, seq_length = inputs_embeds.shape[:2]
# else:
# raise ValueError("You have to specify either input_ids or inputs_embeds")

# past_key_values_length = 0
# if past_key_values is not None:
# past_key_values_length = past_key_values[0][0].shape[2]

# if position_ids is None:
# device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids = torch.arange(
# past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
# )
# position_ids = position_ids.unsqueeze(0)

# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)

# if getattr(self.config, "_flash_attn_2_enabled", False):
# # 2d mask is passed through the layers
# attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
# else:
# # 4d mask is passed through the layers
# attention_mask = _prepare_4d_causal_attention_mask(
# attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
# )

# # embed positions
# hidden_states = inputs_embeds

# # decoder layers
# all_hidden_states = () if output_hidden_states else None
# all_self_attns = () if output_attentions else None
# next_decoder_cache = () if use_cache else None

# for idx, decoder_layer in enumerate(self.layers):
# if output_hidden_states:
# all_hidden_states += (hidden_states,)

# past_key_value = past_key_values[idx] if past_key_values is not None else None

# layer_outputs = decoder_layer(
# hidden_states,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_value=past_key_value,
# output_attentions=output_attentions,
# use_cache=use_cache,
# )

# hidden_states = layer_outputs[0]

# if use_cache:
# next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

# if output_attentions:
# all_self_attns += (layer_outputs[1],)

# hidden_states = self.norm(hidden_states)

# # add hidden states from the last decoder layer
# if output_hidden_states:
# all_hidden_states += (hidden_states,)

# next_cache = next_decoder_cache if use_cache else None
# if not return_dict:
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# return BaseModelOutputWithPast(
# last_hidden_state=hidden_states,
# past_key_values=next_cache,
# hidden_states=all_hidden_states,
# attentions=all_self_attns,
# )


def _gpt2_block_forward(self, hidden_states, *args, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
39 changes: 25 additions & 14 deletions optimum/intel/ipex/modeling_base.py
@@ -48,7 +48,9 @@
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME

import optimum
from optimum.exporters import TasksManager
from optimum.exporters.tasks import make_backend_config_constructor_for_task
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

@@ -91,8 +93,14 @@ def _is_patched_with_ipex(model, task):
def _prepare_inputs_for_ipex_model(model, task, use_cache):
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
use_ipex_dummy_input = False
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
onnx_config_class = ipex_onnx_config[model.config.model_type]
use_ipex_dummy_input = True
origin_dummy_batch_size = optimum.utils.input_generators.DEFAULT_DUMMY_SHAPES["batch_size"]
optimum.utils.input_generators.DEFAULT_DUMMY_SHAPES["batch_size"] = 1
onnx_config_class = make_backend_config_constructor_for_task(
ipex_onnx_config[model.config.model_type], task=task
)
else:
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.dtype)
@@ -105,6 +113,17 @@ def _prepare_inputs_for_ipex_model(model, task, use_cache):

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

# Check attention_mask shape
if use_ipex_dummy_input:
past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
input_len = dummy_inputs["input_ids"].shape[-1]
attention_len = dummy_inputs["attention_mask"].shape[-1]
if attention_len != input_len + past_len:
dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
dummy_inputs["input_ids"].dtype
)
optimum.utils.input_generators.DEFAULT_DUMMY_SHAPES["batch_size"] = origin_dummy_batch_size

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
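
The attention_mask length check above enforces a simple shape invariant: the traced model expects the mask to cover the cached tokens plus the newly fed tokens. A small self-contained illustration (shapes are made up for the example):

import torch

batch_size, input_len, past_len = 1, 4, 28   # illustrative sizes
input_ids = torch.ones([batch_size, input_len], dtype=torch.long)
# The mask must span the cached positions plus the current input positions.
attention_mask = torch.ones([batch_size, input_len + past_len], dtype=input_ids.dtype)
assert attention_mask.shape[-1] == input_ids.shape[-1] + past_len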


@@ -666,19 +685,16 @@ def _prepare_generation_config(
return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if is_ipex_version("<", "2.5.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models if ipex < 2.5, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
# Patch functions to support IAKV cache
if self._is_ipex_exported:
_patch_crop_past_key_values()
try:
result = super().generate(*args, **kwargs)
except Exception as e:
_unpatch_crop_past_key_values()
raise e
_unpatch_crop_past_key_values()

return result


@@ -687,11 +703,6 @@ def _patch_crop_past_key_values():
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values


def _unpatch_crop_past_key_values():
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
transformers.generation.utils._crop_past_key_values = _crop_past_key_values


def _ipex_prepare_inputs_for_generation(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
@@ -772,7 +783,7 @@ def _ipex_reorder_cache(


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
new_past_key_values = []
for i in range(len(past_key_values)):
pkv = []
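
Assisted decoding rolls the key/value cache back whenever draft tokens are rejected, which is why _crop_past_key_values is patched to _ipex_crop_past_key_values for patched models. As a point of reference only, here is a minimal sketch of what cropping means for the standard tuple-of-tuples cache layout; the IAKV cache used by patched IPEX models is laid out differently, so this is an analogy, not the patched implementation:

from typing import Tuple
import torch

def crop_tuple_cache(past_key_values: Tuple[Tuple[torch.Tensor, ...], ...], max_length: int):
    # Keep only the first max_length positions along the sequence axis of each
    # key/value tensor shaped (batch, num_heads, seq_len, head_dim).
    return tuple(tuple(t[:, :, :max_length, :] for t in layer) for layer in past_key_values)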
