Skip to content

Commit

Permalink
multithread multihead.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Apr 28, 2024
1 parent 93f4beb commit 9dbfa9f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/llama2-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void llamaMultiheadAttJoin(TASK_ARGS) {
int nHeadsPerThread = spec->nHeads / nThreads;

int hStart = threadIndex * nHeadsPerThread;
int hEnd = threadIndex == nThreads - 1 ? spec->nHeads : hEnd + nHeadsPerThread;
int hEnd = (threadIndex == nThreads - 1) ? spec->nHeads : (hStart + nHeadsPerThread);

for (int h = hStart; h < hEnd; h++) {
// get the query vector for this head
Expand Down

0 comments on commit 9dbfa9f

Please sign in to comment.