class _IPEXLlamaAttention(_IPEXAttention):
    """Llama attention whose q/k/v projections are fused into one ``nn.Linear``.

    ``__init__`` stacks the three projection weights (and biases, when present)
    into ``self.concat_qkv`` and removes the separate ``q_proj`` / ``k_proj`` /
    ``v_proj`` submodules. ``qkv_gemm`` runs the fused layer once and slices the
    result back into query, key and value.
    """

    def __init__(self, module, config) -> None:
        """Build the fused QKV projection from the wrapped attention module.

        Args:
            module: Attention module being wrapped. It is forwarded to the base
                class, and its ``o_proj`` is read on the CPU path.
            config: Model configuration, forwarded to the base class.

        NOTE(review): ``self.q_proj`` / ``k_proj`` / ``v_proj`` / ``o_proj`` and
        ``self.module_device`` are presumably populated by
        ``_IPEXAttention.__init__`` -- confirm against the base class.
        """
        super().__init__(module, config)
        projections = (self.q_proj, self.k_proj, self.v_proj)

        # The weights are laid out as (out_features, in_features) -- see the
        # shape indices used for nn.Linear below -- so stacking along dim 0
        # yields one matrix whose rows are [q | k | v].
        concat_weight = torch.concat([proj.weight for proj in projections], dim=0)

        # Compare against None explicitly: truth-testing a tensor with more
        # than one element raises RuntimeError.
        use_bias = any(proj.bias is not None for proj in projections)

        self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
        self.concat_qkv.weight = nn.Parameter(concat_weight)
        if use_bias:
            # A projection without a bias contributes zeros, so the fused bias
            # always has exactly one entry per output row of the fused weight.
            concat_bias = torch.concat(
                [
                    proj.bias if proj.bias is not None else proj.weight.new_zeros(proj.weight.shape[0])
                    for proj in projections
                ],
                dim=0,
            )
            self.concat_qkv.bias = nn.Parameter(concat_bias)

        # Cumulative end offsets of q, k and v along the fused output dimension.
        self.q_slice = self.q_proj.out_features
        self.k_slice = self.q_slice + self.k_proj.out_features
        self.v_slice = self.k_slice + self.v_proj.out_features

        # The separate projections are superseded by concat_qkv; unregister
        # them so their parameters are no longer held by this module.
        del self.__dict__["_modules"]["q_proj"]
        del self.__dict__["_modules"]["k_proj"]
        del self.__dict__["_modules"]["v_proj"]

        if self.module_device == "cpu":
            # NOTE(review): LinearAdd presumably fuses o_proj with the residual
            # add, and is skipped for the "LinearAllreduce" o_proj class --
            # confirm against its definition.
            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
                self.mha_linear_add = LinearAdd(module.o_proj)
                del self.__dict__["_modules"]["o_proj"]

    def qkv_gemm(self, hidden_states):
        """Project ``hidden_states`` to query, key and value with one fused matmul.

        Args:
            hidden_states: Input tensor whose last dimension matches the fused
                layer's ``in_features``. Any number of leading dimensions is
                accepted; they are flattened by the ``view(-1, ...)`` calls.

        Returns:
            Tuple ``(query, key, value)``. ``query`` is viewed as
            ``(-1, num_heads, head_dim)``; ``key`` and ``value`` are viewed as
            ``(-1, num_key_value_heads, head_dim)``.
        """
        qkv_out = self.concat_qkv(hidden_states)
        # Slice the last (feature) dimension with an ellipsis so the split is
        # correct regardless of how many leading dimensions the input has.
        query = qkv_out[..., : self.q_slice].view(-1, self.num_heads, self.head_dim)
        key = qkv_out[..., self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
        value = qkv_out[..., self.k_slice : self.v_slice].view(-1, self.num_key_value_heads, self.head_dim)
        return query, key, value