Skip to content

Commit

Permalink
feat: accelerator structure. (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Jun 12, 2024
1 parent 802d70a commit 1b6024e
Show file tree
Hide file tree
Showing 21 changed files with 754 additions and 618 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
Expand All @@ -40,8 +40,8 @@ jobs:
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: commands-test
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
Expand All @@ -64,7 +64,7 @@ jobs:
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
Expand All @@ -73,8 +73,8 @@ jobs:
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: commands-test
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
Expand Down
26 changes: 14 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ funcs: src/funcs.cpp
$(CXX) $(CXXFLAGS) -c src/funcs.cpp -o funcs.o
funcs-test: src/funcs-test.cpp funcs
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o
commands: src/commands.cpp
$(CXX) $(CXXFLAGS) -c src/commands.cpp -o commands.o
socket: src/socket.cpp
$(CXX) $(CXXFLAGS) -c src/socket.cpp -o socket.o
transformer: src/utils.cpp
Expand All @@ -33,20 +35,20 @@ tokenizer: src/tokenizer.cpp
app: src/app.cpp
$(CXX) $(CXXFLAGS) -c src/app.cpp -o app.o

dllama: src/apps/dllama/dllama.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)

funcs-test: src/funcs-test.cpp funcs utils quants
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS)
quants-test: src/quants.cpp utils quants
$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS)
tokenizer-test: src/tokenizer-test.cpp tokenizer funcs utils quants
$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o utils.o quants.o $(LIBS)
transformer-test: src/transformer-test.cpp funcs utils quants transformer socket
$(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks tokenizer
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)
tokenizer-test: src/tokenizer-test.cpp tokenizer funcs commands utils quants
$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o commands.o utils.o quants.o $(LIBS)
commands-test: src/commands-test.cpp funcs commands utils quants transformer socket
$(CXX) $(CXXFLAGS) src/commands-test.cpp -o commands-test funcs.o commands.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks tokenizer
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks tokenizer
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)
2 changes: 1 addition & 1 deletion examples/macbeth.sh
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ Macbeth. Thou seest the moon"

echo "Generating, it can take a while..."

OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 8 --steps 2048 --model converter/dllama_meta-llama-3-8b_q40.bin --tokenizer converter/dllama_meta-llama3-tokenizer.t ) 2>&1)
OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1)
echo "$OUTPUT"
Expand Down
7 changes: 4 additions & 3 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
exit(EXIT_FAILURE);
}

void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)) {
void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc)) {
if (args->modelPath == NULL) {
throw std::runtime_error("Model is required");
}
Expand All @@ -119,14 +119,15 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
args->steps = spec.seqLen;
}

Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
AcceleratorContext acc(0, 1, NULL);
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool, &acc);
socketPool->setTurbo(true);

Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed);

program(&inference, socketPool, &tokenizer, &sampler, args, &spec);
program(&inference, socketPool, &tokenizer, &sampler, args, &spec, &acc);

delete socketPool;
}
8 changes: 4 additions & 4 deletions src/app.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#ifndef FUNCS_HPP
#define FUNCS_HPP
#ifndef APP_HPP
#define APP_HPP

#include "quants.hpp"
#include "transformer.hpp"
#include "utils.hpp"
#include "socket.hpp"
#include "utils.hpp"
#include "app.hpp"
#include "transformer.hpp"
#include "tasks.hpp"
Expand Down Expand Up @@ -46,7 +46,7 @@ class TransformerArchFactory {

class App {
public:
static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec));
static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc));
};

#endif
2 changes: 1 addition & 1 deletion src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void handleModelsRequest(HttpRequest& request) {
"] }");
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
SocketServer* server = new SocketServer(args->port);

TokenizerChatStops stops(tokenizer);
Expand Down
7 changes: 4 additions & 3 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "../../tokenizer.hpp"
#include "../../app.hpp"

void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
if (args->prompt == NULL)
throw std::runtime_error("Prompt is required");

Expand Down Expand Up @@ -193,7 +193,7 @@ class Chat {
}
};

void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
TokenizerChatStops stops(tokenizer);
ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
Expand All @@ -210,7 +210,8 @@ void worker(AppArgs* args) {
SocketServer server(args->port);
Socket socket = server.accept();
TransformerSpec spec;
Transformer transformer = Transformer::loadSlice(&spec, &socket);
AcceleratorContext acc(0, 1, NULL);
Transformer transformer = Transformer::loadSlice(&spec, &socket, &acc);
TransformerArch arch = TransformerArchFactory::create(&spec);

Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
Expand Down
85 changes: 85 additions & 0 deletions src/commands-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "commands.hpp"
#include <cmath>
#include <cstdio>
#include <cstring>

void testRopeSlice(int arch, const int nSliceTests, const int nPosTests, const int nThreadTests) {
int dim = 4096;
int headSize = 128;
int nKvHeads = 8;
int seqLen = 2048;
int nHeads = dim / headSize;
int kvDim = (dim * nKvHeads) / nHeads;
int ropeTheta = 10000.0f;

float* q = new float[dim];
float* k = new float[kvDim];
float* correctQ = new float[dim];
float* correctK = new float[kvDim];

for (int pos = 0; pos < seqLen; pos += seqLen / nPosTests) {
for (int si = 0; si < nSliceTests; si++) {
int nSlices = pow(2, si);

for (int nThreads = 1; nThreads <= nThreadTests; nThreads++) {
printf("pos=%d nSlices=%d threads=%d\n", pos, nSlices, nThreads);

for (int j = 0; j < dim; j++) q[j] = 1.0;
for (int j = 0; j < kvDim; j++) k[j] = 1.0;

for (slice_index_t sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) {
RopeSlice slice(dim, kvDim, nKvHeads, nSlices, seqLen, headSize, ropeTheta, sliceIndex);
RopeCommand* rope;
if (arch == 1) {
rope = new LlamaRopeCommand(&slice);
} else if (arch == 2) {
rope = new FalconRopeCommand(&slice);
}

for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) {
rope->forward(
true,
&q[(sliceIndex * dim) / nSlices],
pos, nThreads, threadIndex);
rope->forward(
false,
&k[(sliceIndex * kvDim) / nSlices],
pos, nThreads, threadIndex);
}

delete rope;
}

if (si == 0 && nThreads == 1) {
memcpy(correctQ, q, dim * sizeof(float));
memcpy(correctK, k, kvDim * sizeof(float));
} else {
for (int j = 0; j < dim; j++) {
if (fabs(q[j] - correctQ[j]) > 1e-6) {
printf("q[%d] mismatch: %f != %f (arch=%d)\n", j, q[j], correctQ[j], arch);
exit(EXIT_FAILURE);
}
}
for (int j = 0; j < kvDim; j++) {
if (fabs(k[j] - correctK[j]) > 1e-6) {
printf("k[%d] mismatch: %f != %f (arch=%d)\n", j, k[j], correctK[j], arch);
exit(EXIT_FAILURE);
}
}
}
}
}
}

delete[] q;
delete[] k;
delete[] correctQ;
delete[] correctK;
printf("✅ ropeSlice (arch=%d)\n", arch);
}

int main() {
testRopeSlice(2, 4, 6, 3);
testRopeSlice(1, 6, 4, 3);
return 0;
}
Loading

0 comments on commit 1b6024e

Please sign in to comment.