Skip to content

Commit

Permalink
enable qkv
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Oct 21, 2024
1 parent 80e8071 commit 07b5058
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,30 @@ def forward(
class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, config) -> None:
super().__init__(module, config)
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight])
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0)
self.concat_linear.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.out_features
self.k_slice = self.q_slice + self.k_proj.out_features
self.v_slice = self.k_slice + self.v_proj.out_features
del self.__dict__["_modules"]["q_proj"]
del self.__dict__["_modules"]["k_proj"]
del self.__dict__["_modules"]["v_proj"]
if self.module_device == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
del self.__dict__["_modules"]["o_proj"]

def qkv_gemm(self, hidden_states):
query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)

return query, key, value

Expand Down

0 comments on commit 07b5058

Please sign in to comment.