diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
index 2fa7305c0..d055a4454 100644
--- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
@@ -8,6 +8,7 @@
 from typing import List, Optional, Tuple, Type, Union

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -15,7 +16,6 @@ from transformers.models.granitemoe.modeling_granitemoe import (
     GraniteMoeAttention,
     GraniteMoeConfig,
-    GraniteMoeDecoderLayer,
     GraniteMoeForCausalLM,
     GraniteMoeModel,
     GraniteMoeMoE,
@@ -196,88 +196,6 @@ def eager_attention_forward(
     return attn_output, attn_weights


-class QEffGraniteMoeDecoderLayer(GraniteMoeDecoderLayer):
-    """
-    Copied from GraniteForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
-    The only differences are:
-    - add new args batch idx for the CB models although its not supported yet.
-    """
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        output_router_logits: Optional[bool] = False,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-                Indices depicting the position of the input sequence tokens in the sequence
-            output_router_logits (`bool`, *optional*):
-                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-                should not be returned during inference.
-            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
-                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
-                with `head_dim` being the embedding dimension of each attention head.
-            kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
-        """
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-            **kwargs,
-        )
-
-        hidden_states = residual + hidden_states * self.residual_multiplier
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-
-        hidden_states = residual + hidden_states * self.residual_multiplier
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if output_router_logits:
-            outputs += (router_logits,)
-
-        return outputs
-
-
 class QEffGraniteMoeModel(GraniteMoeModel):
     """Copied from GraniteMoeModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
     The only differences are:
@@ -492,13 +410,7 @@ def forward(self, hidden_states):
         logits = self.layer(hidden_states).float()
         top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=1)  # [num_tokens, top_k]
         top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
-
-        B, K = top_k_indices.shape
-        E = int(self.num_experts)
-        flat = top_k_indices.reshape(-1)
-        mask = torch.zeros((B * K, E), dtype=torch.int64, device=top_k_indices.device)
-        mask[torch.arange(B * K, device=flat.device), flat] = 1
-        expert_mask = mask.view(B, K, E).permute(2, 1, 0)
+        expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts).permute(2, 1, 0)

         return top_k_gates, expert_mask, logits, self.num_experts

@@ -632,7 +544,6 @@ def forward(
         logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
         hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         logits = self.lm_head(hidden_states).float()
-        # logits = logits / self.config.logits_scaling

         return MoeCausalLMOutputWithPast(
             loss=None,
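Note on the gating hunk above: the `F.one_hot` rewrite produces exactly the tensor the deleted manual scatter built, just in one call, and a single `one_hot` op is likely friendlier to ONNX export, which is QEfficient's compilation path. A minimal standalone sketch (illustrative sizes, not part of the patch) that checks the equivalence:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes; shapes follow the diff: top_k_indices is
# [num_tokens, top_k] and the mask is permuted to [num_experts, top_k, num_tokens].
num_tokens, top_k, num_experts = 5, 2, 8
top_k_indices = torch.randint(0, num_experts, (num_tokens, top_k))

# Old path: build the one-hot expert mask by hand with a scatter-style write.
B, K = top_k_indices.shape
flat = top_k_indices.reshape(-1)
mask = torch.zeros((B * K, num_experts), dtype=torch.int64)
mask[torch.arange(B * K), flat] = 1
old_mask = mask.view(B, K, num_experts).permute(2, 1, 0)

# New path: a single F.one_hot call, identical layout after the permute.
new_mask = F.one_hot(top_k_indices, num_classes=num_experts).permute(2, 1, 0)

assert torch.equal(old_mask, new_mask)
```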
diff --git a/tests/configs/causal_model_configs.json b/tests/configs/causal_model_configs.json
index d6183a7fb..bf0fd642d 100644
--- a/tests/configs/causal_model_configs.json
+++ b/tests/configs/causal_model_configs.json
@@ -53,7 +53,19 @@
             "rotary_dim": 16
         }
     },
-
+    {
+        "model_name": "ibm-granite/granite-3.1-1b-a400m-base",
+        "model_type": "granitemoe",
+        "additional_params": {
+            "max_position_embeddings": 128,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 2,
+            "hidden_size": 64,
+            "intermediate_size": 256,
+            "vocab_size": 49155,
+            "num_key_value_heads": 1
+        }
+    },
     {
         "model_name": "microsoft/Phi-3-mini-4k-instruct",
         "model_type": "phi3",
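The new JSON entry registers a deliberately tiny GraniteMoe variant (one layer, 64-dim hidden states) so the causal-LM test suite stays fast. The exact loader in `tests/` is not shown in this patch; a sketch of the intent, assuming the harness applies `additional_params` as overrides on top of the hub config:

```python
from transformers import AutoConfig

# Overrides taken verbatim from the new config entry above.
overrides = {
    "max_position_embeddings": 128,
    "num_hidden_layers": 1,
    "num_attention_heads": 2,
    "hidden_size": 64,
    "intermediate_size": 256,
    "vocab_size": 49155,
    "num_key_value_heads": 1,
}

# from_pretrained applies unrecognized kwargs as attribute overrides,
# yielding a single-layer GraniteMoe config suitable for quick tests.
config = AutoConfig.from_pretrained("ibm-granite/granite-3.1-1b-a400m-base", **overrides)
assert config.num_hidden_layers == 1
```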