Skip to content

Commit

Permalink
accelerator manager.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Jun 16, 2024
1 parent 48d4b79 commit b3fca58
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
34 changes: 21 additions & 13 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
return args;
}

AcceleratorManager::AcceleratorManager(AppArgs* args) {
if (args->acceleratorNominator > 0 && args->acceleratorDenominator > 0) {
printf("🚀 acceleratorRatio: %d/%d\n", args->acceleratorNominator, args->acceleratorDenominator);
#ifdef DLLAMA_VULKAN
accelerator = new AcceleratorVulkan();
#endif
} else {
accelerator = NULL;
}
context = new AcceleratorContext(args->acceleratorNominator, args->acceleratorDenominator, accelerator);
}

AcceleratorManager::~AcceleratorManager() {
if (accelerator != NULL)
delete accelerator;
delete context;
}

TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
if (spec->archType == LLAMA) return buildLlamaArch(spec);
if (spec->archType == GROK1) return buildGrok1Arch(spec);
Expand Down Expand Up @@ -139,25 +157,15 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
args->steps = spec.seqLen;
}

Accelerator* accelerator = NULL;
if (args->acceleratorNominator > 0 && args->acceleratorDenominator > 0) {
printf("🚀 acceleratorRatio: %d/%d\n", args->acceleratorNominator, args->acceleratorDenominator);
#ifdef DLLAMA_VULKAN
accelerator = new AcceleratorVulkan();
#endif
}

AcceleratorContext acc(args->acceleratorNominator, args->acceleratorDenominator, accelerator);
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool, &acc);
AcceleratorManager acceleratorManager(args);
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool, acceleratorManager.context);
socketPool->setTurbo(true);

Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed);

program(&inference, socketPool, &tokenizer, &sampler, args, &spec, &acc);
program(&inference, socketPool, &tokenizer, &sampler, args, &spec, acceleratorManager.context);

delete socketPool;
if (accelerator != NULL)
delete accelerator;
}
8 changes: 8 additions & 0 deletions src/app.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ class AppArgs {
static AppArgs parse(int argc, char** argv, bool hasMode);
};

class AcceleratorManager {
public:
Accelerator* accelerator;
AcceleratorContext* context;
AcceleratorManager(AppArgs* args);
~AcceleratorManager();
};

class TransformerArchFactory {
public:
static TransformerArch create(TransformerSpec* spec);
Expand Down
4 changes: 2 additions & 2 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ void worker(AppArgs* args) {
SocketServer server(args->port);
Socket socket = server.accept();
TransformerSpec spec;
AcceleratorContext acc(0, 1, NULL);
Transformer transformer = Transformer::loadSlice(&spec, &socket, &acc);
AcceleratorManager acceleratorManager(args);
Transformer transformer = Transformer::loadSlice(&spec, &socket, acceleratorManager.context);
TransformerArch arch = TransformerArchFactory::create(&spec);

Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
Expand Down

0 comments on commit b3fca58

Please sign in to comment.