feat: naive cache.
b4rtaz committed May 29, 2024
1 parent 5fee854 commit aaa0094
Showing 3 changed files with 215 additions and 99 deletions.
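
In short, this commit teaches the dllama-api server to remember the chat messages of the previous request together with the KV-cache position at which each one ends (NaiveCacheItem.endPos). When the next request repeats that message list as a prefix, resolveDeltaPrompt drops the prefix and inference resumes from the cached position instead of re-processing the whole prompt; any mismatch simply clears the cache. The minimal, self-contained sketch below illustrates that prefix check under simplified stand-in types — Msg, CacheEntry, and resolveDelta are hypothetical names for illustration, not the identifiers used in the diff.

// Illustrative sketch only; the actual implementation is the NaiveCache class
// added to src/apps/dllama-api/dllama-api.cpp below.
#include <cstdio>
#include <string>
#include <vector>

struct Msg { std::string role; std::string content; };      // stand-in for ChatMessage
struct CacheEntry { unsigned int endPos; Msg message; };     // stand-in for NaiveCacheItem

// If every cached message matches the start of the new request, drop that
// prefix from `messages` and return the KV-cache position to resume from.
// Otherwise the cache is useless for this request and is discarded.
static unsigned int resolveDelta(std::vector<CacheEntry>& cache, std::vector<Msg>& messages) {
    size_t n = cache.size();
    if (n == 0 || messages.size() <= n) { cache.clear(); return 0; }
    for (size_t i = 0; i < n; i++) {
        if (cache[i].message.role != messages[i].role ||
            cache[i].message.content != messages[i].content) {
            cache.clear();
            return 0;
        }
    }
    unsigned int startPos = cache[n - 1].endPos;
    messages.erase(messages.begin(), messages.begin() + n);  // only the new tail is re-tokenized
    return startPos;
}

int main() {
    std::vector<CacheEntry> cache = {
        {5,  {"system", "You are an excellent math teacher."}},
        {12, {"user", "What is 1 + 2?"}},
        {20, {"assistant", "1 + 2 = 3."}},
    };
    // Second turn: the client resends the whole history plus one new message.
    std::vector<Msg> request = {
        {"system", "You are an excellent math teacher."},
        {"user", "What is 1 + 2?"},
        {"assistant", "1 + 2 = 3."},
        {"user", "And 2 + 2?"},
    };
    unsigned int startPos = resolveDelta(cache, request);
    printf("resume at pos=%u, %zu message(s) left to process\n", startPos, request.size());
    return 0;
}
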
7 changes: 4 additions & 3 deletions examples/chat-api-client.js
@@ -37,12 +37,13 @@ async function ask(system, user, maxTokens) {
content: user
}
], maxTokens);
console.log(`${response.choices[0].message.content}`);
console.log(response.usage);
console.log(response.choices[0].message.content);
}

async function main() {
await ask('You are an excellent math teacher.', 'What is 1 + 2?', 64);
await ask('You are a romantic.', 'Where is Europe?', 64);
await ask('You are an excellent math teacher.', 'What is 1 + 2?', 128);
await ask('You are a romantic.', 'Where is Europe?', 128);
}

main();
290 changes: 198 additions & 92 deletions src/apps/dllama-api/dllama-api.cpp
@@ -5,6 +5,7 @@
#include <cassert>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <vector>

#ifdef _WIN32
@@ -41,10 +42,10 @@ class HttpRequest {

std::vector<char> httpRequest = socket.readHttpRequest();
// Parse the HTTP request
std::string request = std::string(httpRequest.begin(), httpRequest.end());
std::string data = std::string(httpRequest.begin(), httpRequest.end());

// Split request into lines
std::istringstream iss(request);
std::istringstream iss(data);
std::string line;
std::getline(iss, line);

@@ -73,7 +74,7 @@ class HttpRequest {
std::getline(iss, req.body, '\0');

if (req.body.size() > 0) {
// printf("body: %s\n", httpRequest.body.c_str());
// printf("body: %s\n", req.body.c_str());
req.parsedJson = json::parse(req.body);
}
return req;
@@ -199,118 +200,222 @@ void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, co
}
}

void handleCompletionsRequest(HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
params.seed = args->seed;
params.stream = false;
params.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
params.max_tokens = spec->seqLen - params.prompt.size();

if (request.parsedJson.contains("stream")) {
params.stream = request.parsedJson["stream"].get<bool>();
class NaiveCacheItem {
public:
pos_t endPos;
ChatMessage message;
NaiveCacheItem(pos_t endPos, ChatMessage message) {
this->endPos = endPos;
this->message = message;
}
if (request.parsedJson.contains("temperature")) {
params.temperature = request.parsedJson["temperature"].template get<float>();
assert(params.temperature >= 0.0f);
sampler->setTemp(params.temperature);
};

class NaiveCache {
private:
std::vector<NaiveCacheItem> cache;
public:
void push(NaiveCacheItem item) {
cache.push_back(item);
}
if (request.parsedJson.contains("seed")) {
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(params.seed);

void clear() {
cache.clear();
}
if (request.parsedJson.contains("max_tokens")) {
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
assert(params.max_tokens <= spec->seqLen); //until rope scaling or similar is implemented

bool resolveDeltaPrompt(std::vector<ChatMessage>& messages, pos_t& startPos) {
size_t cacheSize = cache.size();
if (cacheSize == 0)
return false;
if (messages.size() > cacheSize) {
size_t i = 0;
while (i < cacheSize) {
if (
cache[i].message.role != messages[i].role ||
cache[i].message.content != messages[i].content
) break;
i++;
}
if (i == cacheSize) {
startPos = cache[i - 1].endPos;
messages.erase(messages.begin(), messages.begin() + i);
printf("🐤 Found naive cache for %zu messages, pos=%d\n", i, startPos);
return true;
}
}
cache.clear();
return false;
}
if (request.parsedJson.contains("stop")) {
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
} else {
const std::string defaultStop = "<|eot_id|>";
params.stop = std::vector<std::string>{defaultStop};
};

class ApiServer {
private:
Inference* inference;
Tokenizer* tokenizer;
Sampler* sampler;
AppArgs* args;
TransformerSpec* spec;
NaiveCache naiveCache;

public:
ApiServer(
Inference* inference,
Tokenizer* tokenizer,
Sampler* sampler,
AppArgs* args,
TransformerSpec* spec) {
this->inference = inference;
this->tokenizer = tokenizer;
this->sampler = sampler;
this->args = args;
this->spec = spec;
}

printf("🔸");
fflush(stdout);
void complete(HttpRequest& request) {
InferenceParams params = parseRequest(request);

//Process the chat completion request
std::vector<std::string> generated;
generated.get_allocator().allocate(params.max_tokens);
pos_t startPos = 0;
std::vector<ChatMessage> deltaPrompt = params.messages;
naiveCache.resolveDeltaPrompt(deltaPrompt, startPos);

if (params.stream) {
request.writeStreamStartChunk();
}
printf("🔸");
fflush(stdout);

int promptLength = params.prompt.length();
int nPromptTokens;
int promptTokens[promptLength + 3];
char prompt[promptLength + 1];
prompt[promptLength] = 0;
strcpy(prompt, params.prompt.c_str());
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);

int token = promptTokens[0];
pos_t maxPos = nPromptTokens + params.max_tokens;
if (maxPos > spec->seqLen) maxPos = spec->seqLen;
bool eosEncountered = false;
for (pos_t pos = 0; pos < maxPos; pos++) {
float* logits = inference->infer(token, pos);

if (pos < nPromptTokens - 1) {
token = promptTokens[pos + 1];
} else {
int prevToken = token;
token = sampler->sample(logits);
std::string inputPrompt = buildChatPrompt(tokenizer, deltaPrompt);
int promptLength = inputPrompt.size();
int nPromptTokens;
int promptTokens[promptLength + 3];
char prompt[promptLength + 1];
prompt[promptLength] = 0;
strcpy(prompt, inputPrompt.c_str());
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);
int promptEndPos = startPos + nPromptTokens;

for (size_t j = 0; j < deltaPrompt.size(); j++) {
naiveCache.push(NaiveCacheItem(promptEndPos, deltaPrompt[j]));
}

pos_t maxPos = params.max_tokens > 0 ? (promptEndPos + params.max_tokens) : spec->seqLen;
if (maxPos > spec->seqLen) maxPos = spec->seqLen;

if (params.stream) {
request.writeStreamStartChunk();
}

std::string delta;
std::string buffer;
size_t nStops = params.stop.size();

if (token == tokenizer->eosId) eosEncountered = true;
int token = promptTokens[0];
pos_t pos = startPos;
for (; pos < maxPos; pos++) {
float* logits = inference->infer(token, pos);

char* piece = tokenizer->decode(prevToken, token);
if (pos < promptEndPos - 1) {
token = promptTokens[pos - startPos + 1];
} else {
int prevToken = token;
token = sampler->sample(logits);

bool safePiece = isSafePiece(piece);

if (!params.stop.empty() && safePiece) {
std::string concatenatedTokens;
int startIndex = std::max(0, static_cast<int>(generated.size()) - 7);
for (int i = startIndex; i < generated.size(); ++i) {
concatenatedTokens += generated[i];
if (token == tokenizer->eosId) {
printf("🔴");
break;
}
concatenatedTokens += std::string(piece);

for (const auto& word : params.stop) {
if (concatenatedTokens.find(word) != std::string::npos) {
eosEncountered = true;
char* piece = tokenizer->decode(prevToken, token);

if (isSafePiece(piece)) {
printf("%s", piece);
fflush(stdout);
delta += piece;
}

bool maybeEos = false;
size_t deltaSize = delta.size();
if (nStops > 0 && deltaSize > 0) {
bool eos = false;
for (size_t s = 0; s < nStops; s++) {
size_t stopSize = params.stop[s].size();
if (params.stop[s].compare(0, deltaSize, delta) == 0) {
if (stopSize <= deltaSize) {
eos = true;
break;
} else {
maybeEos = true;
break;
}
}
}
if (eos) {
printf("");
break;
}
}
}

if (eosEncountered) break;
if (!maybeEos) {
if (params.stream)
writeChatCompletionChunk(request, delta, false);
buffer += delta;
delta.clear();
}
}
}

std::string string = std::string(piece);
safePrintf(piece);
fflush(stdout);
ChatMessage chatMessage("assistant", buffer);
if (pos == spec->seqLen) {
naiveCache.clear();
} else {
naiveCache.push(NaiveCacheItem(pos, chatMessage));
}

generated.push_back(string);
if (params.stream) {
writeChatCompletionChunk(request, string, false);
}
if (params.stream) {
writeChatCompletionChunk(request, "", true);
} else {
int nCompletionTokens = pos - promptEndPos;
ChatUsage usage(nPromptTokens, nCompletionTokens, nPromptTokens + nCompletionTokens);
Choice choice(chatMessage);
ChatCompletion completion(choice, usage);
std::string chatJson = ((json)completion).dump();
request.writeJson(chatJson);
}
printf("🔶\n");
fflush(stdout);
}

if (!params.stream) {
ChatMessage chatMessage = ChatMessage("assistant", std::accumulate(generated.begin(), generated.end(), std::string("")));
Choice responseChoice = Choice(chatMessage);
ChatCompletion completion = ChatCompletion(responseChoice);
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());

std::string chatJson = ((json)completion).dump();
request.writeJson(chatJson);
} else {
writeChatCompletionChunk(request, "", true);
private:
InferenceParams parseRequest(HttpRequest& request) {
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
params.seed = args->seed;
params.stream = false;
params.messages = parseChatMessages(request.parsedJson["messages"]);
params.max_tokens = -1;

if (request.parsedJson.contains("stream")) {
params.stream = request.parsedJson["stream"].get<bool>();
}
if (request.parsedJson.contains("temperature")) {
params.temperature = request.parsedJson["temperature"].template get<float>();
}
if (request.parsedJson.contains("seed")) {
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(params.seed);
}
if (request.parsedJson.contains("max_tokens")) {
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
}
if (request.parsedJson.contains("stop")) {
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
} else {
const std::string defaultStop = "<|eot_id|>";
params.stop = std::vector<std::string>{defaultStop};
}
return params;
}
printf("🔶\n");
fflush(stdout);
};

void handleCompletionsRequest(HttpRequest& request, ApiServer* api) {
api->complete(request);
}

void handleModelsRequest(HttpRequest& request) {
@@ -324,12 +429,13 @@ void handleModelsRequest(HttpRequest& request) {
void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
SocketServer* server = new SocketServer(args->port);
printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port);
ApiServer api(inference, tokenizer, sampler, args, spec);

std::vector<Route> routes = {
{
"/v1/chat/completions",
HttpMethod::METHOD_POST,
std::bind(&handleCompletionsRequest, std::placeholders::_1, inference, tokenizer, sampler, args, spec)
std::bind(&handleCompletionsRequest, std::placeholders::_1, &api)
},
{
"/v1/models",

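The streaming loop in the dllama-api.cpp hunk above also changes how stop sequences are handled: decoded pieces accumulate in a delta string, and a chunk is only flushed to the client once delta can no longer be the beginning of any stop string; if delta grows into a full stop string, generation ends without emitting it. The stand-alone sketch below reproduces that hold-back rule with a hypothetical classify() helper — an illustration of the idea, not code from the diff.

// Illustration of the stop-sequence hold-back used by the new streaming loop.
// classify() is a hypothetical helper, not a function from dllama-api.cpp.
#include <cstdio>
#include <string>
#include <vector>

enum class Verdict { FLUSH, HOLD, STOP };

// FLUSH: delta cannot start any stop string, safe to send as a chunk.
// HOLD:  delta is a proper prefix of some stop string, keep buffering.
// STOP:  delta equals a stop string, end generation without emitting it.
static Verdict classify(const std::string& delta, const std::vector<std::string>& stops) {
    for (const std::string& stop : stops) {
        if (stop.compare(0, delta.size(), delta) == 0) {   // is delta a prefix of stop?
            return delta.size() >= stop.size() ? Verdict::STOP : Verdict::HOLD;
        }
    }
    return Verdict::FLUSH;
}

int main() {
    std::vector<std::string> stops = {"<|eot_id|>"};
    const char* samples[] = {"Hello", "<|eot", "<|eot_id|>"};
    for (const char* s : samples) {
        Verdict v = classify(s, stops);
        printf("%-12s -> %s\n", s,
               v == Verdict::FLUSH ? "flush chunk" :
               v == Verdict::HOLD  ? "hold (maybe stop)" : "stop generation");
    }
    return 0;
}
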