
multithread multihead.
b4rtaz committed Apr 28, 2024
1 parent dc61185 commit 93f4beb
Showing 1 changed file with 36 additions and 35 deletions.
71 changes: 36 additions & 35 deletions src/llama2-tasks.cpp
@@ -103,42 +103,43 @@ void llamaMultiheadAttRope(TASK_ARGS) {

void llamaMultiheadAttJoin(TASK_ARGS) {
    TASK_VARIABLES;
-    if (threadIndex == 0) {
-        float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q);
-        float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB);
-        int kvMul = spec->nHeads / spec->nKvHeads; // integer multiplier of the kv sharing in multiquery
-
-        // multihead attention. iterate over all heads
-        int h;
-        for (h = 0; h < spec->nHeads; h++) {
-            // get the query vector for this head
-            float* _q = q + h * spec->headSize;
-            // attention scores for this head
-            float* _att = block->att + h * spec->seqLen;
-            // iterate over all timesteps, including the current one
-            for (int t = 0; t <= transformer->pos; t++) {
-                // get the key vector for this head and at this timestep
-                float* k = block->keyCache + t * spec->kvDim + (h / kvMul) * spec->headSize;
-                // calculate the attention score as the dot product of q and k
-                float score = dotProduct(_q, k, spec->headSize) / sqrtf(spec->headSize);
-                _att[t] = score;
-            }
+    float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q);
+    float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB);
+
+    int kvMul = spec->nHeads / spec->nKvHeads; // integer multiplier of the kv sharing in multiquery
+    int nHeadsPerThread = spec->nHeads / nThreads;
+
+    int hStart = threadIndex * nHeadsPerThread;
+    int hEnd = threadIndex == nThreads - 1 ? spec->nHeads : hStart + nHeadsPerThread;
+
+    for (int h = hStart; h < hEnd; h++) {
+        // get the query vector for this head
+        float* _q = q + h * spec->headSize;
+        // attention scores for this head
+        float* _att = block->att + h * spec->seqLen;
+        // iterate over all timesteps, including the current one
+        for (int t = 0; t <= transformer->pos; t++) {
+            // get the key vector for this head and at this timestep
+            float* k = block->keyCache + t * spec->kvDim + (h / kvMul) * spec->headSize;
+            // calculate the attention score as the dot product of q and k
+            float score = dotProduct(_q, k, spec->headSize) / sqrtf(spec->headSize);
+            _att[t] = score;
+        }

-            // softmax the scores to get attention weights, from 0..pos inclusively
-            softmax(_att, transformer->pos + 1);
-
-            // weighted sum of the values, store back into xb
-            float* _xb = xb + h * spec->headSize;
-            memset(_xb, 0, spec->headSize * sizeof(float));
-            for (int t = 0; t <= transformer->pos; t++) {
-                // get the value vector for this head and at this timestep
-                float* _v = block->valueCache + t * spec->kvDim + (h / kvMul) * spec->headSize;
-                // get the attention weight for this timestep
-                float a = _att[t];
-                // accumulate the weighted value into xb
-                for (int i = 0; i < spec->headSize; i++) {
-                    _xb[i] += a * _v[i];
-                }
+        // softmax the scores to get attention weights, from 0..pos inclusively
+        softmax(_att, transformer->pos + 1);
+
+        // weighted sum of the values, store back into xb
+        float* _xb = xb + h * spec->headSize;
+        memset(_xb, 0, spec->headSize * sizeof(float));
+        for (int t = 0; t <= transformer->pos; t++) {
+            // get the value vector for this head and at this timestep
+            float* _v = block->valueCache + t * spec->kvDim + (h / kvMul) * spec->headSize;
+            // get the attention weight for this timestep
+            float a = _att[t];
+            // accumulate the weighted value into xb
+            for (int i = 0; i < spec->headSize; i++) {
+                _xb[i] += a * _v[i];
+            }
+        }
+    }
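
For reference, below is a minimal standalone sketch of the head-partitioning scheme this commit introduces: each thread handles a contiguous block of nHeads / nThreads heads, and the last thread also picks up the remainder when the division is not exact. The helper name headRange and the concrete sizes are illustrative only, not part of the repository.

#include <cstdio>

// Hypothetical helper mirroring the hStart/hEnd computation in the diff:
// thread `threadIndex` out of `nThreads` processes heads [hStart, hEnd).
static void headRange(int nHeads, int nThreads, int threadIndex, int* hStart, int* hEnd) {
    int nHeadsPerThread = nHeads / nThreads;
    *hStart = threadIndex * nHeadsPerThread;
    // the last thread also covers the remainder when nHeads % nThreads != 0
    *hEnd = (threadIndex == nThreads - 1) ? nHeads : *hStart + nHeadsPerThread;
}

int main() {
    const int nHeads = 32;  // e.g. 32 query heads
    const int nThreads = 5; // deliberately not a divisor of nHeads
    for (int t = 0; t < nThreads; t++) {
        int hStart, hEnd;
        headRange(nHeads, nThreads, t, &hStart, &hEnd);
        printf("thread %d -> heads [%d, %d)\n", t, hStart, hEnd);
    }
    return 0;
}

With these numbers, threads 0-3 each take 6 heads and thread 4 takes 8. The split needs no synchronization inside the loop: each head h writes only to its own slices of att (offset h * seqLen) and xb (offset h * headSize), so the per-thread ranges never overlap.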
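Separately, the kvMul factor and the (h / kvMul) * headSize offset in the key/value cache lookups implement multiquery (grouped-query) key/value sharing: several query heads read the same cached kv head. A small sketch of just that mapping, assuming nHeads is a multiple of nKvHeads (the concrete numbers are illustrative):

#include <cstdio>

int main() {
    const int nHeads = 8;   // query heads
    const int nKvHeads = 2; // key/value heads stored in the cache
    const int headSize = 4; // floats per head

    const int kvMul = nHeads / nKvHeads; // query heads per kv head

    for (int h = 0; h < nHeads; h++) {
        int kvHead = h / kvMul;         // which cached kv head this query head reads
        int offset = kvHead * headSize; // float offset inside one timestep's cache row
        printf("query head %d -> kv head %d (offset %d)\n", h, kvHead, offset);
    }
    return 0;
}

Here heads 0-3 map to kv head 0 and heads 4-7 to kv head 1. One cache row per timestep holds nKvHeads * headSize floats, which matches the t * spec->kvDim stride used in the code above.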
