Commit

split.
b4rtaz committed Apr 28, 2024
1 parent 45c3e5b commit dc61185
Showing 3 changed files with 13 additions and 19 deletions.
30 changes: 11 additions & 19 deletions src/llama2-tasks.cpp
@@ -179,17 +179,17 @@ void llamaDequantizeAtt(TASK_ARGS) {
     dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2_QUANTIZED, TB_SLICED_XB2);
 }
 
-void llamaRmfFfn(TASK_ARGS) {
+void llamaMergeAtt(TASK_ARGS) {
+    TASK_VARIABLES;
+    float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2);
+    float* x = (float*)transformer->x;
+    add(x, xb2, spec->dim, nThreads, threadIndex);
+}
+
+void llamaRmfFfn(TASK_ARGS) {
     TASK_VARIABLES;
     if (threadIndex == 0) {
-        float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2);
-        float* x = (float*)transformer->x;
-
-        for (int i = 0; i < spec->dim; i++) {
-            x[i] += xb2[i];
-        }
-        transformer->rms = rms(x, spec->dim);
+        transformer->rms = rms(transformer->x, spec->dim);
     }
 }
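
The hunk above splits the old llamaRmfFfn into two tasks: llamaMergeAtt accumulates the attention output slice (TB_SLICED_XB2) into the residual stream x via the shared add() helper, which runs on every worker thread, while llamaRmfFfn is left with only the RMS computation on thread 0. The new task is then registered between llamaDequantizeAtt and llamaRmfFfn in both buildLlama2Arch and buildMixtralArch (see the hunks below). The add() helper is not part of this diff; a minimal sketch, assuming it slices the range evenly across threads the same way the removed inline loop did, could look like this:

// Hypothetical sketch of the add() helper called by llamaMergeAtt above; the real
// definition lives elsewhere in the repository and is not part of this diff.
// Assumes n divides evenly by nThreads, as the removed loop did.
void add(float* a, float* b, const int n, const unsigned int nThreads, const unsigned int threadIndex) {
    const int slice = n / nThreads;          // elements handled by each thread
    const int start = slice * threadIndex;   // this thread's offset into the vectors
    for (int i = start; i < start + slice; i++)
        a[i] += b[i];                        // accumulate xb2 into the residual stream x
}

With this split, the residual addition is parallelised across threads instead of being serialised on thread 0.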

@@ -220,17 +220,8 @@ void llamaFfn(TASK_ARGS) {
     matmul(spec->weightsFloatType, spec->bufferFloatType, hb0, xb, block->w10, block->w10Slice->n, block->w10Slice->d0, nThreads, threadIndex);
     matmul(spec->weightsFloatType, spec->bufferFloatType, block->hb20, xb, block->w30, block->w30Slice->n, block->w30Slice->d0, nThreads, threadIndex);
 
-    // SwiGLU non-linearity
-    int d00 = block->w10Slice->d0 / nThreads;
-    int d0Offset = d00 * threadIndex;
-    for (int i = 0; i < d00; i++) {
-        float val = hb0[i + d0Offset];
-        // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
-        val *= (1.0f / (1.0f + expf(-val)));
-        // elementwise multiply with w3(x)
-        val *= block->hb20[i + d0Offset];
-        hb0[i + d0Offset] = val;
-    }
+    silu(hb0, block->w10Slice->d0, nThreads, threadIndex);
+    mul(hb0, block->hb20, block->w10Slice->d0, nThreads, threadIndex);
 }
 
 void llamaQuantizeFfnA(TASK_ARGS) {
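
Here the hand-written SwiGLU loop is replaced by the shared silu() and mul() helpers, each called with the slice size and the thread coordinates. Their definitions are not part of this diff; a minimal sketch, assuming they use the same per-thread slicing as the removed loop, might be:

#include <cmath>  // expf

// Hypothetical sketches only; the actual helpers are defined elsewhere in the
// repository. Both assume n divides evenly by nThreads, as the removed loop did.
void silu(float* hb, const int n, const unsigned int nThreads, const unsigned int threadIndex) {
    const int slice = n / nThreads;
    const int start = slice * threadIndex;
    for (int i = start; i < start + slice; i++) {
        float val = hb[i];
        // silu(x) = x * σ(x), where σ(x) is the logistic sigmoid
        hb[i] = val * (1.0f / (1.0f + expf(-val)));
    }
}

void mul(float* a, float* b, const int n, const unsigned int nThreads, const unsigned int threadIndex) {
    const int slice = n / nThreads;
    const int start = slice * threadIndex;
    for (int i = start; i < start + slice; i++)
        a[i] *= b[i];  // elementwise multiply with the w3(x) projection
}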
@@ -331,6 +322,7 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
     a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
     a.I(llamaSyncAtt, TASK_TYPE_TRANSFER);
     a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
+    a.I(llamaMergeAtt, TASK_TYPE_INFERENCE);
     a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
     a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
     a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
1 change: 1 addition & 0 deletions src/llama2-tasks.hpp
@@ -20,6 +20,7 @@ void llamaAtt(TASK_ARGS);
 void llamaQuantizeAtt(TASK_ARGS);
 void llamaSyncAtt(TASK_ARGS);
 void llamaDequantizeAtt(TASK_ARGS);
+void llamaMergeAtt(TASK_ARGS);
 void llamaRmfFfn(TASK_ARGS);
 void llamaRmfFfnNorm(TASK_ARGS);
 void llamaNextBlock(TASK_ARGS);
1 change: 1 addition & 0 deletions src/mixtral-tasks.cpp
@@ -26,6 +26,7 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
     a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
     a.I(llamaSyncAtt, TASK_TYPE_TRANSFER);
     a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
+    a.I(llamaMergeAtt, TASK_TYPE_INFERENCE);
     a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
     a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);

