Skip to content

Commit

Permalink
a next fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Jun 12, 2024
1 parent e11d166 commit 88666a5
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/transformer-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,26 @@ void testRopeSlice(const TransformerArchType archType, const int nSliceTests, co
for (int j = 0; j < spec.kvDim; j++) k[j] = 1.0;

for (slice_index_t sliceIndex = 0; sliceIndex < spec.nSlices; sliceIndex++) {
RopeSlice* slice;
RopeSlice slice(spec.dim, spec.kvDim, spec.nKvHeads, spec.nSlices, spec.seqLen, spec.headSize, spec.ropeTheta, sliceIndex);
RopeCommand* rope;
if (archType == LLAMA) {
slice = new LlamaRopeSlice(&spec, sliceIndex);
rope = new LlamaRopeCommand(&slice);
} else if (archType == MIXTRAL) {
slice = new FalconRopeSlice(&spec, sliceIndex);
rope = new FalconRopeCommand(&slice);
}

for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) {
slice->forward(
rope->forward(
true,
&q[(sliceIndex * spec.dim) / spec.nSlices],
pos, nThreads, threadIndex);
slice->forward(
rope->forward(
false,
&k[(sliceIndex * spec.kvDim) / spec.nSlices],
pos, nThreads, threadIndex);
}

delete slice;
delete rope;
}

if (si == 0 && nThreads == 1) {
Expand Down

0 comments on commit 88666a5

Please sign in to comment.