fix olmo accuracy for bf16, add sdpa for persimmon, support jais (huggingface#726)

* fix olmo accuracy for bf16, add sdpa for persimmon, support jais

* apply sdpa for jais

* snowflake export

* apply review comments
eaidova authored Jun 6, 2024
1 parent 6888c0a commit 1ab78d5
Showing 4 changed files with 274 additions and 2 deletions.
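
Editor's note: a minimal usage sketch (not part of this commit) of how one of the newly supported architectures can be exported through the OpenVINO integration after this change. The checkpoint name is the tiny test model listed in tests/openvino/utils_tests.py below; a real workload would use a full-size checkpoint, and trust_remote_code=True may be needed depending on the checkpoint.

# Editor's sketch, not part of the commit: export a newly supported model to OpenVINO.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "katuni4ka/tiny-random-jais"  # tiny test checkpoint from tests/openvino/utils_tests.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch model to OpenVINO IR on the fly, going through
# the JaisOpenVINOConfig / JaisModelPatcher path added in this commit
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Today is a nice day and", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
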
36 changes: 35 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -43,16 +43,19 @@

from .model_patcher import (
AquilaModelPatcher,
ArcticModelPatcher,
BaichuanModelPatcher,
ChatGLMModelPatcher,
CodeGenModelPatcher,
DBRXModelPatcher,
GemmaModelPatcher,
InternLM2Patcher,
InternLMModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
PersimmonModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
XverseModelPatcher,
@@ -473,7 +476,7 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):


@register_in_tasks_manager("olmo", *["text-generation", "text-generation-with-past"], library_name="transformers")
class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
class OlmoOpenVINOConfig(LlamaOpenVINOConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@@ -630,6 +633,11 @@ class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return PersimmonModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers")
class BioGPTOpenVINOConfig(TextDecoderOnnxConfig):
@@ -785,3 +793,29 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return DBRXModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"jais",
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class JaisOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return JaisModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("arctic", *["text-generation", "text-generation-with-past"], library_name="transformers")
class ArcticOpenVINOConfig(MixtralOpenVINOConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)
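
Editor's note: each of these configs overrides patch_model_for_export to return an architecture-specific patcher, and the patchers in model_patcher.py below all use the same mechanism: rebind a module's method with types.MethodType inside __enter__, keep the original on an _orig_* attribute, and restore it in __exit__. A standalone sketch of that rebinding pattern on a toy module (illustrative only, not the actual optimum classes):

# Editor's sketch of the method-rebinding pattern used by the patchers below.
import types
import torch

class ToyAttention(torch.nn.Module):
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # replacement behavior, active only while the patcher is "entered"
    return x + 2

module = ToyAttention()
module._orig_forward = module.forward                       # keep the bound original
module.forward = types.MethodType(patched_forward, module)  # bind the replacement
assert module(torch.zeros(1)).item() == 2.0                 # patched path
module.forward = module._orig_forward                       # restore, as __exit__ does
assert module(torch.zeros(1)).item() == 1.0                 # original path
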
230 changes: 230 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -153,6 +153,17 @@ def __exit__(self, exc_type, exc_value, traceback):
layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward


class ArcticModelPatcher(MixtralModelPatcher):
def __enter__(self):
# the model initializes some weights for matrix multiplication in bfloat16, which leads to dtype inconsistency
try:
self._model.to(torch.float32)
except Exception:
pass

super().__enter__()


def _chatglm_transformer_forward(
self,
input_ids,
@@ -1771,3 +1782,222 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
for block in self._model.transformer.blocks:
block.ffn.experts.forward = block.ffn.experts._orig_forward


def _persimmon_self_attn_sdpa_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb

if output_attentions:
return self._orig_forward(
hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
)

bsz, q_len, _ = hidden_states.size()

# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_states, key_states, value_states) = self._split_heads(fused_qkv)

if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)

# [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
query_states = query_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim],
query_states[..., self.rotary_emb.dim :],
)
key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim],
key_states[..., self.rotary_emb.dim :],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

# [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)

if past_key_value is not None:
# Specific to RoPE models with partial rotation
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
scale=1 / math.sqrt(self.head_dim),
dropout_p=self.attention_dropout.p,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.dense(attn_output)

return attn_output, None, past_key_value


class PersimmonModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


def _jais_attn_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
position_bias: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = (key, value)
else:
present = None

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query, key, value, attention_mask, head_mask, position_bias
)
else:
# Difference from the original: use the SDPA-based attention implementation when output_attentions is not requested
if not output_attentions:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias)
else:
attn_output, attn_weights = self._orig_attn(query, key, value, attention_mask, head_mask, position_bias)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs


def _jais_attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None):
scale = 1.0
if self.scale_attn_weights:
scale = 1 / self.head_dim**self.attn_scale_power

# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
scale = scale / float(self.layer_idx + 1)

query_length = query.size(-2)
attention_mask_sdpa = torch.ones(
(query.shape[0], query.shape[1], query.shape[2], key.shape[2]),
dtype=query.dtype,
)

if not self.is_cross_attention:
# only the "normal" (non-cross) attention layer applies the causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(torch.float16).min
attention_mask_sdpa.masked_fill_(~causal_mask, mask_value)

if attention_mask is not None:
# Apply the attention mask
attention_mask_sdpa = attention_mask_sdpa + attention_mask

if position_bias is not None:
attention_mask_sdpa += position_bias.type_as(attention_mask_sdpa).unsqueeze(0)

# Mask heads if we want to
if head_mask is not None:
attention_mask_sdpa = attention_mask_sdpa * head_mask

attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask_sdpa, dropout_p=self.attn_dropout.p, scale=scale
)

return attn_output, None


class JaisModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()

for layer in self._model.transformer.h:
if is_torch_version(">=", "2.1.0"):
orig_self_attn_fwd = layer.attn._attn
layer.attn._attn = types.MethodType(_jais_attn, layer.attn)
layer.attn._orig_attn = orig_self_attn_fwd
layer.attn._orig_forward = layer.attn.forward
layer.attn.forward = types.MethodType(_jais_attn_forward, layer.attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.transformer.h:
if hasattr(layer.attn, "_orig_attn"):
layer.attn._attn = layer.attn._orig_attn
layer.attn.forward = layer.attn._orig_forward
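
Editor's note: both new patchers route the eager attention math through torch.nn.functional.scaled_dot_product_attention, which for a float additive mask computes softmax(Q·Kᵀ·scale + mask)·V. A small self-check of that equivalence (assumes torch >= 2.1 for the scale keyword, matching the is_torch_version guard above):

# Editor's sketch: SDPA with an additive float mask matches the eager softmax attention
# that the patched Persimmon/JAIS forwards replace.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# additive mask: 0 where attention is allowed, a large negative value where it is masked
mask = torch.triu(torch.full((q_len, kv_len), torch.finfo(torch.float32).min), diagonal=1)
mask = mask[None, None]  # broadcast over batch and heads

scale = 1 / math.sqrt(head_dim)
eager = torch.softmax((q @ k.transpose(-2, -1)) * scale + mask, dim=-1) @ v
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale)
torch.testing.assert_close(eager, sdpa)
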
8 changes: 7 additions & 1 deletion tests/openvino/test_modeling.py
@@ -564,6 +564,8 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"internlm",
"dbrx",
"qwen2-moe",
"jais",
"arctic",
)
GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = (
@@ -581,6 +583,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"xverse",
"internlm",
"codegen2",
"arctic",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -622,7 +625,7 @@ def test_compare_to_transformers(self, model_arch):

set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
if model_arch == "qwen":
if model_arch in ["qwen", "arctic"]:
transformers_model.to(torch.float32)

with torch.no_grad():
@@ -869,6 +872,9 @@ def test_beam_search(self, model_arch):
model_id, export=True, use_cache=True, stateful=False, **model_kwargs
)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

if model_arch == "arctic":
transformers_model.to(torch.float32)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
tokens.pop("token_type_ids", None)
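
Editor's note: the float32 casts for "qwen" and "arctic" in these tests mirror the cast in ArcticModelPatcher above: weights initialized in bfloat16 cannot be multiplied directly with float32 activations. A quick illustration of the dtype clash (editor's sketch, plain PyTorch only):

# Editor's sketch of the bfloat16/float32 mismatch that the float32 casts work around.
import torch

x = torch.randn(2, 4)                        # activations default to float32
w = torch.randn(4, 4, dtype=torch.bfloat16)  # weights initialized in bfloat16

try:
    x @ w
except RuntimeError as err:
    print("mixed-dtype matmul fails:", err)

out = x @ w.to(torch.float32)                # casting the weights resolves it
print(out.dtype)                             # torch.float32
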
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -63,6 +63,7 @@
"ibert": "hf-internal-testing/tiny-random-ibert",
"internlm": "katuni4ka/tiny-random-internlm",
"internlm2": "katuni4ka/tiny-random-internlm2",
"jais": "katuni4ka/tiny-random-jais",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"longt5": "hf-internal-testing/tiny-random-longt5",
"llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
@@ -109,6 +110,7 @@
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "asapp/sew-d-tiny-100k-ft-ls100h",
"arctic": "katuni4ka/tiny-random-snowflake",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"t5": "hf-internal-testing/tiny-random-t5",
"trocr": "microsoft/trocr-small-handwritten",
