This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit c57d25f

[Fusion]enable bloom mha fusion (#286)
* enable bloom mha fusion
  Signed-off-by: intellinjun <[email protected]>
* update proxy
  Signed-off-by: intellinjun <[email protected]>
* update ci
  Signed-off-by: intellinjun <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Signed-off-by: intellinjun <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ef42ce1 commit c57d25f

4 files changed, +80 -9 lines changed

.github/workflows/cpp-graph-test.yml

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ jobs:
 
       - name: BF16 Benchmark
         run: |
+          export https_proxy=http://proxy.ims.intel.com:911
+          export http_proxy=http://proxy.ims.intel.com:911
           cd ${{ github.workspace }}/.github/workflows/scripts/models
           bash cpp_graph_inference.sh cpp-graph-test-neural-speed ${{ matrix.modelName }} ${{ env.INPUT_COMPILER_VERSION }}
 

.github/workflows/scripts/models/calculate_percertiles.py

Lines changed: 4 additions & 1 deletion
@@ -22,9 +22,12 @@ def parse_output_file_acc(file_path):
     accuracy = 0
     with open(file_path, 'r', encoding='UTF-8', errors='ignore') as file:
         for line in file:
-            accuracy_match = re.search(r"\|\s+\|\s+\|none\s+\|\s+0\|acc\s+\|\d\.\d+\|\±\s+\|\d\.\d+\|", line)
+            accuracy_match = re.search(r"\|\s+\|\s+\|none\s+\|\s+0\|acc\s+\|\d\.\d+\|\±\s+\|+\s+\d\.\d+\|", line)
+            if not accuracy_match:
+                accuracy_match = re.search(r"\|\s+\|\s+\|none\s+\|\s+0\|acc\s+\|\d\.\d+\|\±\s+\|\d\.\d+\|", line)
             if accuracy_match:
                 accuracy = float(re.search(r"\d+\.\d+", accuracy_match.group()).group())*100
+
     return accuracy
 
 def parse_memory_file(memory_file):

neural_speed/models/bloom/bloom.cpp

Lines changed: 72 additions & 8 deletions
@@ -88,6 +88,28 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
   // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
   ne_cgraph gf = {};
   gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
+  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA;
+  kv_cache_info_t kv_cache_info = {};
+  if (run_mha_reordered) {
+    NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_BTLA));
+    attn_shape_t attn_shape = {
+        /* .batch_size = */ 1,
+        /* .head_num = */ n_head,
+        /* .heads_kv = */ n_head,
+        /* .head_size = */ head_dim,
+        /* .sl_q = */ N,  // Note: make sure that bestla reordered attn supports next token inference
+        /* .sl_kv = */ n_past + N,
+    };
+
+    NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+               bestla_reordered_attn_fp32_support(&attn_shape)));
+    kv_shape_t kv_shape{
+        /* .heads_kv = */ static_cast<uint32_t>(n_head),
+        /* .head_size = */ static_cast<uint32_t>(head_dim),
+        /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+    };
+    bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+  }
 
   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
   ne_set_name(embd, "embd");
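
For context on the gate above: attn_shape describes one evaluation step (batch 1, N query tokens attending to n_past + N cached positions), and the support check decides whether the fused path can run at all. Below is a self-contained sketch of the same control flow, using stand-in struct definitions and a hypothetical support predicate in place of neural_speed's real attn_shape_t, kv_shape_t, and bestla_reordered_attn_fp32_support:

    #include <cstdint>
    #include <cstdio>

    // Stand-in definitions for illustration only; the real declarations come
    // from neural_speed's bestla integration.
    struct attn_shape_t { int batch_size, head_num, heads_kv, head_size, sl_q, sl_kv; };
    struct kv_shape_t { uint32_t heads_kv, head_size, sl_kv_max; };

    // Hypothetical predicate standing in for bestla_reordered_attn_fp32_support();
    // a real backend rejects head sizes its kernels cannot handle.
    static bool reordered_attn_fp32_support(const attn_shape_t* s) {
      return s->head_size % 32 == 0;
    }

    int main() {
      const int n_head = 32, head_dim = 64, n_ctx = 512;  // hypothetical sizes
      const int n_past = 17, N = 1;  // one next-token step after a 17-token prompt
      attn_shape_t attn_shape = {/* batch_size */ 1, n_head, n_head, head_dim,
                                 /* sl_q */ N, /* sl_kv */ n_past + N};
      if (!reordered_attn_fp32_support(&attn_shape)) {
        printf("shape unsupported: fall back to the unfused MHA path\n");
        return 0;
      }
      kv_shape_t kv_shape = {uint32_t(n_head), uint32_t(head_dim), uint32_t(n_ctx)};
      printf("fused path: %d of %u kv positions in use per layer\n",
             attn_shape.sl_kv, kv_shape.sl_kv_max);
    }
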
@@ -129,12 +151,13 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
     // cur = ggml_debug(ctx0, cur);
 
     // self-attention
-    {
-      struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
-      struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
-      struct ne_tensor* Vcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
 
-      // store key and value to memory
+    struct ne_tensor* Qcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
+    struct ne_tensor* Kcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
+    struct ne_tensor* Vcur = ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
+    const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+    // store key and value to memory
+    if (!run_mha_reordered) {
       if (N >= 1) {
         struct ne_tensor* k =
             ne_view_1d(ctx0, kv_self.k, N * n_embd, (ne_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past));
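
The three ne_view_2d calls above read Q, K, and V out of one fused QKV projection: every token row of cur packs the three vectors back to back, so the views share the row stride cur->nb[1] and differ only by a byte offset of 0, 1, or 2 times sizeof(float) * n_embd. A plain-array sketch of that layout follows (toy sizes of my choosing; the real views are strided, not copied):

    #include <cstdio>
    #include <vector>

    int main() {
      const int n_embd = 4, N = 2;  // toy sizes for illustration
      // One row per token packs [Q | K | V], as produced by the fused QKV matmul.
      std::vector<float> cur(N * 3 * n_embd);
      for (size_t i = 0; i < cur.size(); ++i) cur[i] = float(i);

      const int row = 3 * n_embd;  // elements per token row; cur->nb[1] in bytes
      for (int part = 0; part < 3; ++part) {  // 0:Q, 1:K, 2:V, one per ne_view_2d call
        for (int tok = 0; tok < N; ++tok) {
          const int first = tok * row + part * n_embd;  // offset = part * n_embd elements
          printf("%c[token %d] = cur[%d..%d]\n", "QKV"[part], tok,
                 first, first + n_embd - 1);
        }
      }
    }
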
@@ -193,11 +216,52 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
 
       // cur = KQV_merged.contiguous().view(n_embd, N)
       cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
+    } else {
+      const auto seq_kv = n_past + N;
+      const auto k_size = kv_cache_info.k_bytes;
+      const auto v_size = kv_cache_info.v_bytes;
+      // store key and value to memory
+      {
+        const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
+                                        head_dim, n_ctx, n_head,  // ne
+                                        0, 0,                     // nb (bestla managed)
+                                        il * k_size);             // offset
+        Kcur = ne_view_3d(ctx0, Kcur, head_dim, n_head, N, Kcur->nb[0] * head_dim, Kcur->nb[1], 0);
+        ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false));
+        const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
+                                        head_dim, n_ctx, n_head,  // ne
+                                        0, 0,                     // nb (bestla managed)
+                                        il * v_size);             // offset
+        Vcur = ne_view_3d(ctx0, Vcur, head_dim, n_head, N, Vcur->nb[0] * head_dim, Vcur->nb[1], 0);
+        ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false));
+      }
 
-      // projection
-      cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
-      cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], cur), cur);
+      struct ne_tensor* Q = ne_view_3d(ctx0, Qcur, head_dim, n_head, N, Qcur->nb[0] * head_dim, Qcur->nb[1], 0);
+      Q = ne_permute(ctx0, Q, 0, 2, 1, 3);
+      ne_set_name(Q, "Q");
+      struct ne_tensor* K =
+          ne_view_3d(ctx0, kv_self.k,                                             // tensor
+                     head_dim, seq_kv, n_head,                                    // ne
+                     kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (bestla managed)
+                     il * k_size);                                                // offset
+      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;  // us nb0 for layout
+      ne_set_name(K, "K");
+      struct ne_tensor* V =
+          ne_view_3d(ctx0, kv_self.v,                                                   // tensor
+                     seq_kv, head_dim, n_head,                                          // ne
+                     kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (bestla managed)
+                     il * v_size);                                                      // offset
+      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;  // us nb0 for layout
+      ne_set_name(V, "V");
+
+      ne_attn_flags_t attn_flags = NE_ATTN_FLAG_IS_ALIBI8;    // mpt uses alibi operation
+      if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL;  // no causal mask on next-token cases
+      struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
+      cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
     }
+    // projection
+    cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+    cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], cur), cur);
 
     struct ne_tensor* inpFF = ne_add(ctx0, cur, inpSA);
 
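
Two details of the fused call are worth spelling out: the scale passed to ne_flash_attn is 1/sqrt(head_dim), and the causal mask is only requested on the prompt pass (n_past == 0), since a single next token may attend to every cached position anyway. A standalone sketch of that flag logic, with made-up flag values standing in for the real NE_ATTN_FLAG_* constants:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    // Stand-in flag values for illustration; the real NE_ATTN_FLAG_* constants
    // live in neural_speed's core headers.
    using ne_attn_flags_t = uint32_t;
    constexpr ne_attn_flags_t NE_ATTN_FLAG_IS_ALIBI8 = 1u << 0;
    constexpr ne_attn_flags_t NE_ATTN_FLAG_IS_CAUSAL = 1u << 1;

    int main() {
      const int head_dim = 64;  // hypothetical head size
      const float attn_scale = 1.0f / sqrtf(float(head_dim));

      for (int n_past : {0, 17}) {  // prompt pass vs. next-token pass
        ne_attn_flags_t flags = NE_ATTN_FLAG_IS_ALIBI8;  // BLOOM attention uses ALiBi biases
        if (n_past == 0) flags |= NE_ATTN_FLAG_IS_CAUSAL;
        printf("n_past=%2d  scale=%.4f  causal=%d\n", n_past, attn_scale,
               int((flags & NE_ATTN_FLAG_IS_CAUSAL) != 0));
      }
    }
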

neural_speed/models/bloom/bloom_utils.cpp

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
   std::unique_ptr<BLOOM> ms(new BLOOM());
   ms->init(fname.c_str(), ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
   ms->load(ctx, progress_callback, progress_callback_user_data);
+  model_context& lctx = *ctx;
+  lctx.support_bestla_kv = true;
 }
 
 void BLOOM::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_,
