@@ -88,6 +88,28 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
   // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
   ne_cgraph gf = {};
   gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
+  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA;
+  kv_cache_info_t kv_cache_info = {};
+  if (run_mha_reordered) {
+    NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_BTLA));
+    attn_shape_t attn_shape = {
+        /* .batch_size = */ 1,
+        /* .head_num = */ n_head,
+        /* .heads_kv = */ n_head,
+        /* .head_size = */ head_dim,
+        /* .sl_q = */ N,  // Note: make sure that bestla reordered attn supports next token inference
+        /* .sl_kv = */ n_past + N,
+    };
+
+    NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+               bestla_reordered_attn_fp32_support(&attn_shape)));
+    kv_shape_t kv_shape{
+        /* .heads_kv = */ static_cast<uint32_t>(n_head),
+        /* .head_size = */ static_cast<uint32_t>(head_dim),
+        /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+    };
+    bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+  }
 
   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
   ne_set_name(embd, "embd");
@@ -129,12 +151,13 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
     // cur = ggml_debug(ctx0, cur);
 
     // self-attention
-    {
-      struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
-      struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
-      struct ne_tensor* Vcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
 
-      // store key and value to memory
+    struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
+    struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
+    struct ne_tensor* Vcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
+    const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+    // store key and value to memory
+    if (!run_mha_reordered) {
       if (N >= 1) {
         struct ne_tensor* k =
             ne_view_1d(ctx0, kv_self.k, N * n_embd, (ne_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past));
@@ -193,11 +216,52 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
 
       // cur = KQV_merged.contiguous().view(n_embd, N)
       cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
+    } else {
+      const auto seq_kv = n_past + N;
+      const auto k_size = kv_cache_info.k_bytes;
+      const auto v_size = kv_cache_info.v_bytes;
+      // store key and value to memory
+      {
+        const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
+                                        head_dim, n_ctx, n_head,  // ne
+                                        0, 0,                     // nb (bestla managed)
+                                        il * k_size);             // offset
+        Kcur = ne_view_3d(ctx0, Kcur, head_dim, n_head, N, Kcur->nb[0] * head_dim, Kcur->nb[1], 0);
+        ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false));
+        const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
+                                        head_dim, n_ctx, n_head,  // ne
+                                        0, 0,                     // nb (bestla managed)
+                                        il * v_size);             // offset
+        Vcur = ne_view_3d(ctx0, Vcur, head_dim, n_head, N, Vcur->nb[0] * head_dim, Vcur->nb[1], 0);
+        ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false));
+      }
 
-      // projection
-      cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
-      cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], cur), cur);
+      struct ne_tensor* Q = ne_view_3d(ctx0, Qcur, head_dim, n_head, N, Qcur->nb[0] * head_dim, Qcur->nb[1], 0);
+      Q = ne_permute(ctx0, Q, 0, 2, 1, 3);
+      ne_set_name(Q, "Q");
+      struct ne_tensor* K =
+          ne_view_3d(ctx0, kv_self.k,                                             // tensor
+                     head_dim, seq_kv, n_head,                                    // ne
+                     kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (bestla managed)
+                     il * k_size);                                                // offset
+      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;    // use nb0 for layout
+      ne_set_name(K, "K");
+      struct ne_tensor* V =
+          ne_view_3d(ctx0, kv_self.v,                                                    // tensor
+                     seq_kv, head_dim, n_head,                                           // ne
+                     kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (bestla managed)
+                     il * v_size);                                                       // offset
+      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;    // use nb0 for layout
+      ne_set_name(V, "V");
+
+      ne_attn_flags_t attn_flags = NE_ATTN_FLAG_IS_ALIBI8;    // bloom uses the alibi operation
+      if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL;  // no causal mask on next-token cases
+      struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
+      cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
     }
+    // projection
+    cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+    cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], cur), cur);
 
     struct ne_tensor* inpFF = ne_add(ctx0, cur, inpSA);
 
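For context on how the values queried in the first hunk are meant to be consumed: bestla_reordered_attn_fp32_batch_kv_info() fills kv_cache_info with the per-layer byte sizes (k_bytes, v_bytes) plus the strides and layouts that the K/V views above read back. Since every layer is addressed at offset il * k_size / il * v_size, the whole bestla-managed cache presumably needs n_layer contiguous slices of each. Below is a minimal sketch of that sizing logic; reordered_kv_cache_bytes and its parameters are illustrative helpers, not part of this patch, and the bestla declarations are assumed to come from the repository header already included by the file being patched.

#include <cstddef>
#include <cstdint>
// plus the repository header that declares kv_shape_t, kv_cache_info_t and
// bestla_reordered_attn_fp32_batch_kv_info (assumed, as in the file above)

// Sketch only: per-layer K/V slice sizes reported by bestla, multiplied by the layer count.
size_t reordered_kv_cache_bytes(uint32_t n_head, uint32_t head_dim, uint32_t n_ctx, int n_layer) {
  kv_shape_t kv_shape{
      /* .heads_kv = */ n_head,
      /* .head_size = */ head_dim,
      /* .sl_kv_max = */ n_ctx,
  };
  kv_cache_info_t kv_cache_info = {};
  bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
  // Layer `il` lives at offset il * k_bytes (resp. il * v_bytes) in the patch,
  // so the full cache is n_layer contiguous slices of each.
  return static_cast<size_t>(n_layer) * (kv_cache_info.k_bytes + kv_cache_info.v_bytes);
}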
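The NE_ATTN_FLAG_IS_ALIBI8 flag asks the fused attention to apply an ALiBi bias, which BLOOM uses in place of positional embeddings. As a standalone illustration (not part of the patch, and whether the "8" in the flag name refers to this constant is an assumption): for a power-of-two head count, head h gets slope 2^(-8 * (h + 1) / n_head), and each attention logit has slope * -(query/key distance) added before the softmax.

#include <cmath>
#include <cstdio>
#include <vector>

// Standard ALiBi head slopes for a power-of-two number of heads (illustrative only).
std::vector<float> alibi_slopes(int n_head) {
  std::vector<float> slopes(n_head);
  for (int h = 0; h < n_head; ++h) {
    slopes[h] = std::pow(2.0f, -8.0f * static_cast<float>(h + 1) / static_cast<float>(n_head));
  }
  return slopes;
}

int main() {
  // For 8 heads this prints 0.5 0.25 0.125 ... 0.00390625
  for (float s : alibi_slopes(8)) std::printf("%g ", s);
  std::printf("\n");
  return 0;
}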