Commit 56b4060
refactor. (#95)
b4rtaz authored Jun 29, 2024
1 parent 1b6024e commit 56b4060
Showing 12 changed files with 47 additions and 107 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -66,7 +66,7 @@ Inference, Chat, API
| `--model <path>` | Path to model. | `dllama_model_meta-llama-3-8b_q40.m` |
| `--tokenizer <path>` | Tokenizer to model. | `dllama_tokenizer_llama3.t` |
| `--buffer-float-type <type>` | Float precision of synchronization. | `q80` |
-| `--workers <workers>` | Addresses of workers (ip:port), separated by space. | `0.0.0.1:9991 10.0.0.2:9991` |
+| `--workers <workers>` | Addresses of workers (ip:port), separated by space. | `10.0.0.1:9991 10.0.0.2:9991` |
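
For reference, the options above combine into a root-node invocation along these lines (a sketch; the model path, prompt, step count, and worker addresses are illustrative):

```sh
./dllama inference \
  --model dllama_model_meta-llama-3-8b_q40.m \
  --tokenizer dllama_tokenizer_llama3.t \
  --buffer-float-type q80 \
  --prompt "Hello" --steps 16 --nthreads 4 \
  --workers 10.0.0.1:9991 10.0.0.2:9991
```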

Inference, Chat, Worker, API

@@ -158,6 +158,7 @@ sudo apt install git
```sh
git clone https://github.com/b4rtaz/distributed-llama.git
make dllama
+make dllama-api
```
6. Transfer weights and the tokenizer file to the root device.
7. Optional: assign static IP addresses.
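
With both binaries in place on every device, each worker can be started before the root node connects to it (a sketch; the port and thread count are illustrative):

```sh
./dllama worker --port 9991 --nthreads 4
```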
@@ -196,6 +197,7 @@ sudo apt install git build-essential
```sh
git clone https://github.com/b4rtaz/distributed-llama.git
make dllama
+make dllama-api
```

Continue to point 3.
@@ -210,6 +212,7 @@ choco install mingw
```sh
git clone https://github.com/b4rtaz/distributed-llama.git
make dllama
+make dllama-api
```

Continue to point 3.
2 changes: 2 additions & 0 deletions docs/LLAMA.md
@@ -21,6 +21,7 @@ wget https://huggingface.co/b4rtaz/llama-2-distributed-llama/resolve/main/dllama
6. Build the project:
```bash
make dllama
+make dllama-api
```
7. Run:
```bash
@@ -61,6 +62,7 @@ python converter/convert-tokenizer-llama3.py path/to/tokenizer.model
10. Build the project:
```bash
make dllama
+make dllama-api
```
11. Run the Distributed Llama:
```bash
7 changes: 3 additions & 4 deletions src/app.cpp
@@ -100,7 +100,7 @@ TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
    exit(EXIT_FAILURE);
}

-void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc)) {
+void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)) {
    if (args->modelPath == NULL) {
        throw std::runtime_error("Model is required");
    }
@@ -119,15 +119,14 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
        args->steps = spec.seqLen;
    }

-    AcceleratorContext acc(0, 1, NULL);
-    Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool, &acc);
+    Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
    socketPool->setTurbo(true);

    Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

    Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed);

-    program(&inference, socketPool, &tokenizer, &sampler, args, &spec, &acc);
+    program(&inference, socketPool, &tokenizer, &sampler, args, &spec);

    delete socketPool;
}
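
For orientation, a program callback now matches the slimmed-down signature; a minimal sketch (the name exampleProgram is hypothetical; generate, chat, and server in src/apps are the real callbacks):

```cpp
void exampleProgram(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer,
                    Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
    // Drive inference here; no AcceleratorContext is threaded through anymore.
}

// Registered the same way as before:
// App::run(&args, exampleProgram);
```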
2 changes: 1 addition & 1 deletion src/app.hpp
@@ -46,7 +46,7 @@ class TransformerArchFactory {

class App {
public:
-    static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc));
+    static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec));
};

#endif
2 changes: 1 addition & 1 deletion src/apps/dllama-api/dllama-api.cpp
@@ -392,7 +392,7 @@ void handleModelsRequest(HttpRequest& request) {
        "] }");
}

-void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
+void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
    SocketServer* server = new SocketServer(args->port);

    TokenizerChatStops stops(tokenizer);
7 changes: 3 additions & 4 deletions src/apps/dllama/dllama.cpp
@@ -14,7 +14,7 @@
#include "../../tokenizer.hpp"
#include "../../app.hpp"

-void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
+void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
    if (args->prompt == NULL)
        throw std::runtime_error("Prompt is required");

@@ -193,7 +193,7 @@ class Chat {
    }
};

-void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
+void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
    TokenizerChatStops stops(tokenizer);
    ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
    EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
@@ -210,8 +210,7 @@ void worker(AppArgs* args) {
    SocketServer server(args->port);
    Socket socket = server.accept();
    TransformerSpec spec;
-    AcceleratorContext acc(0, 1, NULL);
-    Transformer transformer = Transformer::loadSlice(&spec, &socket, &acc);
+    Transformer transformer = Transformer::loadSlice(&spec, &socket);
    TransformerArch arch = TransformerArchFactory::create(&spec);

    Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
39 changes: 4 additions & 35 deletions src/commands.cpp
@@ -104,35 +104,13 @@ MultiHeadAttSlice::MultiHeadAttSlice(unsigned int nHeads, unsigned int seqLen, u
    attSize = seqLen * nHeads0 * sizeof(float);
}

-AcceleratorContext::AcceleratorContext(unsigned int nominator, unsigned int denominator, Accelerator* accelerator) {
-    this->nominator = nominator;
-    this->denominator = denominator;
-    this->accelerator = accelerator;
-}
-
-unsigned int AcceleratorContext::divCpu(unsigned int value) {
-    return value - divAcc(value);
-}
-
-unsigned int AcceleratorContext::divAcc(unsigned int value) {
-    return (nominator * value) / denominator;
-}
-
-MatmulCommand::MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType, AcceleratorContext* acc) {
+MatmulCommand::MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType) {
    this->n = n;
    this->d = d;
    this->inputFloatType = inputFloatType;
    this->weightsFloatType = weightsFloatType;
-    this->acc = acc;
-    this->accD = acc->divAcc(d);
-    this->accSize = getBatchBytes(weightsFloatType, n, this->accD);
-    this->cpuD = acc->divCpu(d);
-    this->cpuSize = getBatchBytes(weightsFloatType, n, this->cpuD);
+    this->cpuSize = getBatchBytes(weightsFloatType, n, d);
    this->cpuWeights = newBuffer(this->cpuSize);
-
-    if (this->accD != 0) {
-        this->accMatmulIndex = acc->accelerator->allocateMatmul(weightsFloatType, n, this->accD);
-    }
};

MatmulCommand::~MatmulCommand() {
@@ -141,20 +119,11 @@ MatmulCommand::~MatmulCommand() {

size_t MatmulCommand::loadWeights(const void* source) {
    memcpy(cpuWeights, source, cpuSize);
-    if (this->accD != 0) {
-        acc->accelerator->loadMatmulWeights(this->accMatmulIndex, &((char*)source)[cpuSize]);
-    }
-    return cpuSize + accSize;
+    return cpuSize;
}

void MatmulCommand::forward(const void* input, float* output, const unsigned int nThreads, const unsigned int threadIndex) {
-    if (this->accD != 0 && threadIndex == 0) {
-        acc->accelerator->beginForwardMatmul(this->accMatmulIndex, input);
-    }
-    matmul(weightsFloatType, inputFloatType, output, input, cpuWeights, n, cpuD, nThreads, threadIndex);
-    if (this->accD != 0 && threadIndex == nThreads - 1) {
-        acc->accelerator->endForwardMatmul(this->accMatmulIndex, &output[cpuD]);
-    }
+    matmul(weightsFloatType, inputFloatType, output, input, cpuWeights, n, d, nThreads, threadIndex);
}

LlamaRopeCommand::LlamaRopeCommand(RopeSlice *slice) {
28 changes: 1 addition & 27 deletions src/commands.hpp
@@ -75,42 +75,16 @@ class MultiHeadAttSlice {
    MultiHeadAttSlice(unsigned int nHeads, unsigned int seqLen, unsigned int nSlices, slice_index_t sliceIndex);
};

-class Accelerator {
-public:
-    virtual const unsigned int allocateMatmul(const FloatType floatType, const unsigned int n, const unsigned int d) = 0;
-    virtual void loadMatmulWeights(const unsigned int matmulIndex, const void* weights) = 0;
-    virtual void beginForwardMatmul(const unsigned int matmulIndex, const void* input) = 0;
-    virtual void endForwardMatmul(const unsigned int matmulIndex, float* output) = 0;
-    virtual void closeMatmul(const unsigned int matmulIndex) = 0;
-};
-
-class AcceleratorContext {
-public:
-    // ratio
-    unsigned int nominator;
-    unsigned int denominator;
-    Accelerator* accelerator;
-
-    AcceleratorContext(unsigned int nominator, unsigned int denominator, Accelerator* accelerator);
-    unsigned int divCpu(unsigned int value);
-    unsigned int divAcc(unsigned int value);
-};

class MatmulCommand {
private:
    FloatType inputFloatType;
    FloatType weightsFloatType;
    unsigned int n;
    unsigned int d;
-    unsigned int cpuD;
-    unsigned int accD;
    size_t cpuSize;
-    size_t accSize;
    void* cpuWeights;
-    unsigned int accMatmulIndex;
-    AcceleratorContext* acc;
public:
-    MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType, AcceleratorContext* acc);
+    MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType);
    ~MatmulCommand();
    size_t loadWeights(const void* source);
    void forward(const void* input, float* output, const unsigned int nThreads, const unsigned int threadIndex);
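
After this change a MatmulCommand always runs entirely on the CPU: the constructor sizes a single weight buffer for all d output rows, loadWeights copies exactly that many bytes, and forward computes one matmul. A minimal usage sketch (dimensions and float types are illustrative):

```cpp
MatmulCommand cmd(4096, 4096, F32, F32);   // n = 4096 inputs, d = 4096 outputs, CPU only
// cmd.loadWeights(weights);               // weights: getBatchBytes(F32, 4096, 4096) bytes
// cmd.forward(input, output, nThreads, threadIndex);  // full matmul over all d rows
```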
3 changes: 1 addition & 2 deletions src/grok1-tasks-test.cpp
@@ -60,8 +60,7 @@ int main() {
    for (int f = 0; f < nFloats; f++) block[f] = randomF32(&state) / 100.0;

    SocketPool socketPool(0, NULL);
-    AcceleratorContext acc(0, 1, NULL);
-    Transformer transformer = Transformer::loadRoot(weights, &spec, &socketPool, &acc);
+    Transformer transformer = Transformer::loadRoot(weights, &spec, &socketPool);
    transformer.pos = 0;

    float* x = transformer.x;
3 changes: 1 addition & 2 deletions src/llama2-tasks-test.cpp
@@ -562,8 +562,7 @@ int main() {
    for (int i = 0; i < mm; i++) mmData[i] = randomF32(&state) / 120.0;

    SocketPool socketPool(0, NULL);
-    AcceleratorContext acc(0, 1, NULL);
-    Transformer transformer = Transformer::loadRoot((char*)data, &spec, &socketPool, &acc);
+    Transformer transformer = Transformer::loadRoot((char*)data, &spec, &socketPool);
    transformer.pos = 0;

    float* x = transformer.x;
44 changes: 21 additions & 23 deletions src/transformer.cpp
@@ -199,15 +199,14 @@ size_t TransformerBuffer::getSlicedBytes(uint8_t bufferIndex) {
    return bufferBytes[bufferIndex] / nSlices;
}

-Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex, AcceleratorContext* acc) {
+Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex) {
    this->spec = spec;
    this->sliceIndex = sliceIndex;
-    this->acc = acc;

    buffer = new TransformerBuffer(spec);
    blocks = new TransformerBlock*[spec->nLayers];
    for (int i = 0; i < spec->nLayers; i++) {
-        blocks[i] = new TransformerBlock(spec, sliceIndex, acc);
+        blocks[i] = new TransformerBlock(spec, sliceIndex);
    }

    if (IS_ROOT_SLICE(sliceIndex)) {
@@ -217,7 +216,7 @@ Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex, Accele
        tokenEmbeddingTable = (float*)newBuffer(tokenEmbeddingTableBytes);
        rmsFinal = (float*)newBuffer(rmsFinalBytes);

-        wclsMm = new MatmulCommand(spec->dim, spec->vocabSize, F32, spec->weightsFloatType, acc);
+        wclsMm = new MatmulCommand(spec->dim, spec->vocabSize, F32, spec->weightsFloatType);

        x = (float*)newBuffer(spec->dim * sizeof(float));
        logits = (float*)newBuffer(spec->vocabSize * sizeof(float));
@@ -258,10 +257,9 @@ Transformer::~Transformer() {
    delete rope;
}

-TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceIndex, AcceleratorContext* acc) {
+TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceIndex) {
    this->sliceIndex = sliceIndex;
    this->spec = spec;
-    this->acc = acc;

    if (IS_ROOT_SLICE(sliceIndex)) {
        rmsAttBytes = spec->dim * sizeof(float);
@@ -289,10 +287,10 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceInd
    v0Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->kvDim);
    wo0Slice = new ColMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->dim);

-    q0mm = new MatmulCommand(q0Slice->n, q0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-    k0mm = new MatmulCommand(k0Slice->n, k0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-    v0mm = new MatmulCommand(v0Slice->n, v0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-    wo0mm = new MatmulCommand(wo0Slice->n0, wo0Slice->d, spec->bufferFloatType, spec->weightsFloatType, acc);
+    q0mm = new MatmulCommand(q0Slice->n, q0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+    k0mm = new MatmulCommand(k0Slice->n, k0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+    v0mm = new MatmulCommand(v0Slice->n, v0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+    wo0mm = new MatmulCommand(wo0Slice->n0, wo0Slice->d, spec->bufferFloatType, spec->weightsFloatType);

    qo0 = (float*)newBuffer(q0Slice->d0 * sizeof(float));

@@ -305,12 +303,12 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceInd
        moeUpMm = new MatmulCommand*[spec->nExperts];
        moeGateMm = new MatmulCommand*[spec->nExperts];
        moeDownMm = new MatmulCommand*[spec->nExperts];
-        moeRouterMm = new MatmulCommand(spec->dim, spec->nExperts, F32, spec->weightsFloatType, acc);
+        moeRouterMm = new MatmulCommand(spec->dim, spec->nExperts, F32, spec->weightsFloatType);

        for (int e = 0; e < spec->nExperts; e++) {
-            moeUpMm[e] = new MatmulCommand(moeUpAndGate0Slice->n, moeUpAndGate0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-            moeGateMm[e] = new MatmulCommand(moeUpAndGate0Slice->n, moeUpAndGate0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-            moeDownMm[e] = new MatmulCommand(moeDown0Slice->n, moeDown0Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
+            moeUpMm[e] = new MatmulCommand(moeUpAndGate0Slice->n, moeUpAndGate0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+            moeGateMm[e] = new MatmulCommand(moeUpAndGate0Slice->n, moeUpAndGate0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+            moeDownMm[e] = new MatmulCommand(moeDown0Slice->n, moeDown0Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
        }

        expertGate = (float*)newBuffer(moeUpAndGate0Slice->d0 * spec->nExperts * sizeof(float));
@@ -320,9 +318,9 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceInd
        w20Slice = new ColMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->hiddenDim, spec->dim);
        w30Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->hiddenDim);

-        w10mm = new MatmulCommand(w10Slice->n, w10Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
-        w20mm = new MatmulCommand(w20Slice->n0, w20Slice->d, spec->bufferFloatType, spec->weightsFloatType, acc);
-        w30mm = new MatmulCommand(w30Slice->n, w30Slice->d0, spec->bufferFloatType, spec->weightsFloatType, acc);
+        w10mm = new MatmulCommand(w10Slice->n, w10Slice->d0, spec->bufferFloatType, spec->weightsFloatType);
+        w20mm = new MatmulCommand(w20Slice->n0, w20Slice->d, spec->bufferFloatType, spec->weightsFloatType);
+        w30mm = new MatmulCommand(w30Slice->n, w30Slice->d0, spec->bufferFloatType, spec->weightsFloatType);

        hb20 = (float*)newBuffer(w30Slice->d0 * sizeof(float));
    }
@@ -413,23 +411,23 @@ static size_t readSlicedMatmulWeights(MatmulSlice* slice, char* weights0, Socket
    return slice->sliceBytes;
}

-Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spec, SocketPool* socketPool, AcceleratorContext* acc) {
+Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spec, SocketPool* socketPool) {
    MmapFile file;
    openMmapFile(&file, path, spec->fileSize);

    char* weights = ((char*)file.data) + spec->headerSize;
-    Transformer transformer = Transformer::loadRoot((char*)weights, spec, socketPool, acc);
+    Transformer transformer = Transformer::loadRoot((char*)weights, spec, socketPool);

    closeMmapFile(&file);

    return transformer;
}

-Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* socketPool, AcceleratorContext* acc) {
+Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* socketPool) {
    assert(socketPool->nSockets == spec->nSlices - 1);

    const slice_index_t sliceIndex = 0; // Root slice
-    Transformer transformer(spec, sliceIndex, acc);
+    Transformer transformer(spec, sliceIndex);

    if (spec->nSlices > 1) {
        for (slice_index_t sliceIndex = 1; sliceIndex < spec->nSlices; sliceIndex++) {
@@ -486,7 +484,7 @@ Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool*
    return transformer;
}

-Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket, AcceleratorContext* acc) {
+Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket) {
    slice_index_t sliceIndex;
    socket->read((char*)&sliceIndex, sizeof(uint8_t));
    socket->read((char*)spec, sizeof(TransformerSpec));
@@ -495,7 +493,7 @@ Transformer Transformer::loadSlice(TransformerSpec* spec, Socket* socket, Accele
    printf("💡 nSlices: %d\n", spec->nSlices);

    assert(sliceIndex >= 1);
-    Transformer transformer(spec, sliceIndex, acc);
+    Transformer transformer(spec, sliceIndex);

    size_t bufferSize = 0;
    // TODO: this is ugly
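
For orientation, loadSlice's handshake above reads the slice index and then the whole TransformerSpec from the root before any weights arrive. The matching root-side send would look roughly like this (a sketch; the helper name is hypothetical and it assumes Socket::write mirrors Socket::read):

```cpp
static void sendSliceHeader(Socket* socket, slice_index_t sliceIndex, TransformerSpec* spec) {
    // Order must mirror the reads in Transformer::loadSlice.
    socket->write((char*)&sliceIndex, sizeof(uint8_t));
    socket->write((char*)spec, sizeof(TransformerSpec));
}
```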