From ef1e312b571de746a618d0ac6f2f8fc90904f3dd Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Sun, 26 May 2024 11:36:07 +0200 Subject: [PATCH] fix: chunked stream, close stream without econnreset. (#65) --- src/apps/dllama-api/dllama-api.cpp | 176 ++++++++++++++++------------- src/socket.cpp | 3 +- 2 files changed, 96 insertions(+), 83 deletions(-) diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp index c339a12..624b0b3 100644 --- a/src/apps/dllama-api/dllama-api.cpp +++ b/src/apps/dllama-api/dllama-api.cpp @@ -85,14 +85,10 @@ class HttpParser { // Parse body std::getline(iss, httpRequest.body, '\0'); - // Parse JSON if Content-Type is application/json - httpRequest.parsedJson = json::object(); - if(httpRequest.headers.find("Content-Type") != httpRequest.headers.end()){ - if(httpRequest.headers["Content-Type"] == "application/json"){ - httpRequest.parsedJson = json::parse(httpRequest.body); - } + if (httpRequest.body.size() > 0) { + // printf("body: %s\n", httpRequest.body.c_str()); + httpRequest.parsedJson = json::parse(httpRequest.body); } - return httpRequest; } private: @@ -189,6 +185,14 @@ void to_json(json& j, const Choice& choice) { j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}}; } +std::string createJsonResponse(std::string json) { + std::ostringstream oss; + oss << "HTTP/1.1 200 OK\r\n" + << "Content-Type: application/json; charset=utf-8\r\n" + << "Content-Length: " << json.length() << "\r\n\r\n" << json; + return oss.str(); +} + struct ChatCompletionChunk { std::string id; std::string object; @@ -197,7 +201,7 @@ struct ChatCompletionChunk { std::vector choices; ChatCompletionChunk(ChunkChoice &choice_) - : id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") { + : id("chatcmpl-test"), object("chat.completion"), model("dl") { created = std::time(nullptr); // Set created to current Unix timestamp choices.push_back(choice_); } @@ -231,7 +235,7 @@ struct ChatCompletion { ChatUsage usage; ChatCompletion(Choice &choice_) - : id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") { + : id("chatcmpl-test"), object("chat.completion"), model("dl") { created = std::time(nullptr); // Set created to current Unix timestamp choices.push_back(choice_); } @@ -288,90 +292,93 @@ std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector return oss.str(); } -void outputChatCompletionChunk(Socket &client_socket, const std::string &delta, const std::string &finish_reason){ - ChunkChoice choice; - - if(finish_reason.size() > 0){ - choice.finish_reason = finish_reason; +void writeChunk(Socket& socket, const std::string data, const bool stop) { + std::ostringstream formattedChunk; + formattedChunk << std::hex << data.size() << "\r\n" << data << "\r\n"; + if (stop) { + formattedChunk << "0000\r\n\r\n"; } - else{ + socket.write(formattedChunk.str().c_str(), formattedChunk.str().size()); +} + +void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, const bool stop){ + ChunkChoice choice; + if (stop) { + choice.finish_reason = "stop"; + } else { choice.delta = ChatMessageDelta("assistant", delta); } - ChatCompletionChunk chunk = ChatCompletionChunk(choice); - - std::ostringstream oss; - - oss << "data: " << ((json)chunk).dump() << "\n\n"; - if(finish_reason.size() > 0){ - oss << "data: [DONE]\n\n"; - } - - std::string chunkResponse = oss.str(); - - // Format the chunked response - std::ostringstream formattedChunk; - formattedChunk << std::hex << chunkResponse.length() << "\r\n" << chunkResponse << "\r\n"; + std::ostringstream buffer; + buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n"; + writeChunk(client_socket, buffer.str(), false); - client_socket.write(formattedChunk.str().c_str(), formattedChunk.str().length()); + if (stop) { + writeChunk(client_socket, "data: [DONE]", true); + } } -void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) { - printf("Handling Completion Request\n"); +void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) { // Set inference arguments - InferenceParams inferParams; - inferParams.temperature = args->temperature; - inferParams.top_p = args->topp; - inferParams.seed = args->seed; - inferParams.stream = false; - inferParams.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"])); - inferParams.max_tokens = spec->seqLen - inferParams.prompt.size(); - - if(request.parsedJson.contains("stream")){ - inferParams.stream = request.parsedJson["stream"].template get(); + InferenceParams params; + params.temperature = args->temperature; + params.top_p = args->topp; + params.seed = args->seed; + params.stream = false; + params.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"])); + params.max_tokens = spec->seqLen - params.prompt.size(); + + if (request.parsedJson.contains("stream")) { + params.stream = request.parsedJson["stream"].get(); } - if(request.parsedJson.contains("temperature")){ - inferParams.temperature = request.parsedJson["temperature"].template get(); - assert(inferParams.temperature >= 0.0f); - sampler->setTemp(inferParams.temperature); + if (request.parsedJson.contains("temperature")) { + params.temperature = request.parsedJson["temperature"].template get(); + assert(params.temperature >= 0.0f); + sampler->setTemp(params.temperature); } - if(request.parsedJson.contains("seed")){ - inferParams.seed = request.parsedJson["seed"].template get(); - sampler->setSeed(inferParams.seed); + if (request.parsedJson.contains("seed")) { + params.seed = request.parsedJson["seed"].template get(); + sampler->setSeed(params.seed); } - if(request.parsedJson.contains("max_tokens")){ - inferParams.max_tokens = request.parsedJson["max_tokens"].template get(); - assert(inferParams.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented + if (request.parsedJson.contains("max_tokens")) { + params.max_tokens = request.parsedJson["max_tokens"].template get(); + assert(params.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented } - if(request.parsedJson.contains("stop")){ - inferParams.stop = request.parsedJson["stop"].template get>(); + if (request.parsedJson.contains("stop")) { + params.stop = request.parsedJson["stop"].template get>(); + } else { + const std::string defaultStop = "<|eot_id|>"; + params.stop = std::vector{defaultStop}; } + printf("🔸"); + fflush(stdout); + //Process the chat completion request std::vector generated; - generated.get_allocator().allocate(inferParams.max_tokens); + generated.get_allocator().allocate(params.max_tokens); - if (inferParams.stream) { + if (params.stream) { std::ostringstream oss; oss << "HTTP/1.1 200 OK\r\n" << "Content-Type: text/event-stream; charset=utf-8\r\n" - << "Connection: keep-alive\r\n" + << "Connection: close\r\n" << "Transfer-Encoding: chunked\r\n\r\n"; - client_socket.write(oss.str().c_str(), oss.str().length()); + socket.write(oss.str().c_str(), oss.str().length()); } - int promptLength = inferParams.prompt.length(); + int promptLength = params.prompt.length(); int nPromptTokens; int promptTokens[promptLength + 3]; char prompt[promptLength + 1]; prompt[promptLength] = 0; - strcpy(prompt, inferParams.prompt.c_str()); + strcpy(prompt, params.prompt.c_str()); tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false); int token = promptTokens[0]; - pos_t maxPos = nPromptTokens + inferParams.max_tokens; + pos_t maxPos = nPromptTokens + params.max_tokens; if (maxPos > spec->seqLen) maxPos = spec->seqLen; bool eosEncountered = false; for (pos_t pos = 0; pos < maxPos; pos++) { @@ -390,7 +397,7 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer bool safePiece = isSafePiece(piece); - if (!inferParams.stop.empty() && safePiece) { + if (!params.stop.empty() && safePiece) { std::string concatenatedTokens; int startIndex = std::max(0, static_cast(generated.size()) - 7); for (int i = startIndex; i < generated.size(); ++i) { @@ -398,7 +405,7 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer } concatenatedTokens += std::string(piece); - for (const auto& word : inferParams.stop) { + for (const auto& word : params.stop) { if (concatenatedTokens.find(word) != std::string::npos) { eosEncountered = true; break; @@ -409,49 +416,56 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer if (eosEncountered) break; std::string string = std::string(piece); - - //char string[100]; - //strcpy(string, piece); safePrintf(piece); + fflush(stdout); generated.push_back(string); - if (inferParams.stream) { - outputChatCompletionChunk(client_socket, string, ""); + if (params.stream) { + writeChatCompletionChunk(socket, string, false); } } } - if (!inferParams.stream) { + if (!params.stream) { ChatMessage chatMessage = ChatMessage("assistant", std::accumulate(generated.begin(), generated.end(), std::string(""))); Choice responseChoice = Choice(chatMessage); ChatCompletion completion = ChatCompletion(responseChoice); completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size()); std::string chatJson = ((json)completion).dump(); - - std::ostringstream oss; - - oss << "HTTP/1.1 200 OK\r\n" - << "Content-Type: application/json; charset=utf-8\r\n" - << "Content-Length: " << chatJson.length() << "\r\n\r\n" << chatJson; - - std::string response = oss.str(); - - client_socket.write(response.c_str(), response.length()); + std::string response = createJsonResponse(chatJson); + socket.write(response.c_str(), response.length()); } else { - outputChatCompletionChunk(client_socket, "", "stop"); + writeChatCompletionChunk(socket, "", true); } + printf("🔶\n"); + fflush(stdout); +} + +void handleModelsRequest(Socket& client_socket, HttpRequest& request) { + std::string response = createJsonResponse( + "{ \"object\": \"list\"," + "\"data\": [" + "{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }" + "] }"); + client_socket.write(response.c_str(), response.length()); } void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) { SocketServer* server = new SocketServer(args->port); + printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port); std::vector routes = { { "/v1/chat/completions", HttpMethod::METHOD_POST, std::bind(&handleCompletionsRequest, std::placeholders::_1, std::placeholders::_2, inference, tokenizer, sampler, args, spec) + }, + { + "/v1/models", + HttpMethod::METHOD_GET, + std::bind(&handleModelsRequest, std::placeholders::_1, std::placeholders::_2) } }; @@ -464,7 +478,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, // Parse the HTTP request HttpRequest request = HttpParser::parseRequest(std::string(httpRequest.begin(), httpRequest.end())); // Handle the HTTP request - printf("New Request: %s %s\n", request.getMethod().c_str(), request.path.c_str()); + printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str()); Router::routeRequest(client, request, routes); } catch (ReadSocketException& ex) { printf("Read socket error: %d %s\n", ex.code, ex.message); diff --git a/src/socket.cpp b/src/socket.cpp index 94500e4..4f944c5 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -230,7 +230,6 @@ Socket SocketServer::accept() { throw std::runtime_error("Error accepting connection"); setNoDelay(clientSocket); setQuickAck(clientSocket); - printf("Client connected\n"); return Socket(clientSocket); } @@ -261,7 +260,7 @@ bool Socket::tryRead(void* data, size_t size, unsigned long maxAttempts) { std::vector Socket::readHttpRequest() { std::vector httpRequest; - char buffer[1024]; // Initial buffer size + char buffer[1024 * 1024]; // TODO: this should be refactored asap ssize_t bytesRead; // Peek into the socket buffer to check available data