104 changes: 13 additions & 91 deletions QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
@@ -8,14 +8,14 @@
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.models.granitemoe.modeling_granitemoe import (
GraniteMoeAttention,
GraniteMoeConfig,
GraniteMoeDecoderLayer,
GraniteMoeForCausalLM,
GraniteMoeModel,
GraniteMoeMoE,
@@ -196,88 +196,6 @@ def eager_attention_forward(
return attn_output, attn_weights


class QEffGraniteMoeDecoderLayer(GraniteMoeDecoderLayer):
"""
Copied from GraniteForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
The only differences are:
- adds a new batch idx arg for the CB models, although it is not supported yet.
"""

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = residual + hidden_states * self.residual_multiplier

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)

hidden_states = residual + hidden_states * self.residual_multiplier

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)

if output_router_logits:
outputs += (router_logits,)

return outputs


class QEffGraniteMoeModel(GraniteMoeModel):
"""Copied from GraniteMoeModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoe/modeling_granitemoe.py
The only differences are:
@@ -316,6 +234,12 @@ def forward(

inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama

# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
# if not isinstance(past_key_values, (type(None), Cache)):
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

# if use_cache and past_key_values is None:
# past_key_values = QEffDynamicCache()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
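For reference, the `return_legacy_cache` flag above follows the usual Transformers shim that wraps a legacy tuple-of-tuples cache into a `Cache` object on the way in and unwraps it on the way out. A minimal sketch of that pattern using the public `DynamicCache.from_legacy_cache` helper (the `maybe_wrap_cache` function below is invented for illustration and is not part of this PR):

```python
from transformers.cache_utils import Cache, DynamicCache

def maybe_wrap_cache(past_key_values, use_cache=True):
    # Illustrative helper, not from the PR: wrap a legacy tuple cache in a Cache object.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        # `from_legacy_cache` accepts None or the old tuple-of-tuples format.
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy_cache

cache, needs_unwrap = maybe_wrap_cache(None)
print(type(cache).__name__, needs_unwrap)  # DynamicCache True
# After the decoder runs, a caller that passed a legacy cache would get it back
# via `cache.to_legacy_cache()` when `needs_unwrap` is True.
```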
@@ -492,13 +416,7 @@ def forward(self, hidden_states):
logits = self.layer(hidden_states).float()
top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=1) # [num_tokens, top_k]
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]

B, K = top_k_indices.shape
E = int(self.num_experts)
flat = top_k_indices.reshape(-1)
mask = torch.zeros((B * K, E), dtype=torch.int64, device=top_k_indices.device)
mask[torch.arange(B * K, device=flat.device), flat] = 1
expert_mask = mask.view(B, K, E).permute(2, 1, 0)
expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts).permute(2, 1, 0)
return top_k_gates, expert_mask, logits, self.num_experts
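The deleted scatter-based mask and the new `F.one_hot` call produce identical routing masks; a small standalone check (toy token and expert counts chosen here for illustration) makes the equivalence explicit:

```python
import torch
import torch.nn.functional as F

top_k_indices = torch.randint(0, 8, (4, 2))  # [num_tokens, top_k], 8 experts

# Old path: scatter ones into a flat [num_tokens * top_k, num_experts] mask.
B, K = top_k_indices.shape
E = 8
flat = top_k_indices.reshape(-1)
mask = torch.zeros((B * K, E), dtype=torch.int64)
mask[torch.arange(B * K), flat] = 1
manual_mask = mask.view(B, K, E).permute(2, 1, 0)  # [num_experts, top_k, num_tokens]

# New path: a single one_hot call followed by the same permutation.
one_hot_mask = F.one_hot(top_k_indices, num_classes=E).permute(2, 1, 0)

assert torch.equal(manual_mask, one_hot_mask)
```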


@@ -574,9 +492,14 @@ def forward(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
# output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
# output_router_logits: Optional[bool] = None,
# return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
# logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
@@ -632,7 +555,6 @@ def forward(
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states).float()
# logits = logits / self.config.logits_scaling

return MoeCausalLMOutputWithPast(
loss=None,
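The last hunk keeps the prefill-friendly logit gathering: `position_ids.argmax(1)` locates the last real token in each row, and only that position's hidden state is pushed through the LM head. A toy sketch of the same indexing (shapes invented for the example):

```python
import torch

batch, seq_len, hidden = 2, 5, 8
last_hidden_state = torch.randn(batch, seq_len, hidden)
# Row 0 has 5 real tokens, row 1 has 3; padded positions carry position_id 0,
# so argmax picks the index of the last real token in each row.
position_ids = torch.tensor([[0, 1, 2, 3, 4],
                             [0, 1, 2, 0, 0]])

logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)           # [[4], [2]]
gathered = last_hidden_state[torch.arange(batch).view(-1, 1), logit_index]   # [batch, 1, hidden]
print(gathered.shape)  # torch.Size([2, 1, 8])
```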
14 changes: 13 additions & 1 deletion tests/configs/causal_model_configs.json
@@ -53,7 +53,19 @@
"rotary_dim": 16
}
},

{
"model_name": "ibm-granite/granite-3.1-1b-a400m-base",
"model_type": "granitemoe",
"additional_params": {
"max_position_embeddings": 128,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"hidden_size": 64,
"intermediate_size": 256,
"vocab_size": 49155,
"num_key_value_heads": 1
}
},
{
"model_name": "microsoft/Phi-3-mini-4k-instruct",
"model_type": "phi3",
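The new JSON entry gives the causal-LM test suite a heavily shrunk granitemoe configuration. A hedged sketch of how such an entry could be turned into a tiny config object (this loading snippet is illustrative, not code from the repo, and `AutoConfig.from_pretrained` fetches the base config from the Hub):

```python
from transformers import AutoConfig

entry = {
    "model_name": "ibm-granite/granite-3.1-1b-a400m-base",
    "additional_params": {
        "max_position_embeddings": 128,
        "num_hidden_layers": 1,
        "num_attention_heads": 2,
        "hidden_size": 64,
        "intermediate_size": 256,
        "vocab_size": 49155,
        "num_key_value_heads": 1,
    },
}

# Keyword overrides shrink the published config down to a unit-test-sized model.
config = AutoConfig.from_pretrained(entry["model_name"], **entry["additional_params"])
print(config.num_hidden_layers, config.hidden_size)  # 1 64
```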