diff --git a/src/transformer-test.cpp b/src/transformer-test.cpp index e7af09d..888631b 100644 --- a/src/transformer-test.cpp +++ b/src/transformer-test.cpp @@ -29,25 +29,26 @@ void testRopeSlice(const TransformerArchType archType, const int nSliceTests, co for (int j = 0; j < spec.kvDim; j++) k[j] = 1.0; for (slice_index_t sliceIndex = 0; sliceIndex < spec.nSlices; sliceIndex++) { - RopeSlice* slice; + RopeSlice slice(spec.dim, spec.kvDim, spec.nKvHeads, spec.nSlices, spec.seqLen, spec.headSize, spec.ropeTheta, sliceIndex); + RopeCommand* rope; if (archType == LLAMA) { - slice = new LlamaRopeSlice(&spec, sliceIndex); + rope = new LlamaRopeCommand(&slice); } else if (archType == MIXTRAL) { - slice = new FalconRopeSlice(&spec, sliceIndex); + rope = new FalconRopeCommand(&slice); } for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) { - slice->forward( + rope->forward( true, &q[(sliceIndex * spec.dim) / spec.nSlices], pos, nThreads, threadIndex); - slice->forward( + rope->forward( false, &k[(sliceIndex * spec.kvDim) / spec.nSlices], pos, nThreads, threadIndex); } - delete slice; + delete rope; } if (si == 0 && nThreads == 1) {