
Commit eaba8ca

update to LLMResponse for streaming
1 parent ddb3a7c commit eaba8ca

6 files changed: +22 -22 lines changed
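
The diffs below lean on the LLMResponse / LLMChoice shape from the livepeer/ai-worker package. The following is only a reconstruction inferred from how the fields are used in this commit; field sets, pointer-ness, and JSON tags are assumptions, not the canonical definitions.

// Reconstructed sketch of the ai-worker types this commit switches to.
// Everything here is inferred from usage in the diffs below.
type LLMMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type LLMChoice struct {
	Index        int         `json:"index"`
	Delta        *LLMMessage `json:"delta,omitempty"`
	FinishReason *string     `json:"finish_reason,omitempty"` // non-empty on the final streamed chunk
}

type LLMTokenUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type LLMResponse struct {
	Choices    []LLMChoice   `json:"choices"`
	Created    int           `json:"created"`
	Model      string        `json:"model"`
	TokensUsed LLMTokenUsage `json:"tokens_used"`
}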


core/ai_test.go

Lines changed: 1 addition & 1 deletion
@@ -653,7 +653,7 @@ func (a *stubAIWorker) SegmentAnything2(ctx context.Context, req worker.GenSegme
 
 func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody) (interface{}, error) {
 	var choices []worker.LLMChoice
-	choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+	choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
 	tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
 	return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
 }
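
The stub above only exercises the non-streaming path. If a test ever needed a streaming stub, a hypothetical helper (not part of this commit) could emit *worker.LLMResponse chunks and mark the last one via FinishReason, matching what the stream consumers below expect.

// Hypothetical test helper: emit a short stream of *worker.LLMResponse chunks,
// setting FinishReason on the last one so stream consumers know when to stop.
// Imports assumed: the ai-worker "worker" package.
func stubLLMStream(parts []string) <-chan *worker.LLMResponse {
	out := make(chan *worker.LLMResponse, len(parts))
	go func() {
		defer close(out)
		for i, p := range parts {
			chunk := &worker.LLMResponse{
				Choices: []worker.LLMChoice{{Index: 0, Delta: &worker.LLMMessage{Role: "assistant", Content: p}}},
			}
			if i == len(parts)-1 {
				stop := "stop"
				chunk.Choices[0].FinishReason = &stop
			}
			out <- chunk
		}
	}()
	return out
}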

server/ai_http.go

Lines changed: 6 additions & 6 deletions
@@ -586,7 +586,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 	}
 
 	// Check if the response is a streaming response
-	if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
+	if streamChan, ok := resp.(<-chan *worker.LLMResponse); ok {
 		glog.Infof("Streaming response for request id=%v", requestID)
 
 		// Set headers for SSE
@@ -610,7 +610,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			fmt.Fprintf(w, "data: %s\n\n", data)
 			flusher.Flush()
 
-			if chunk.Done {
+			if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 				break
 			}
 		}
@@ -683,8 +683,8 @@ func (h *lphttp) AIResults() http.Handler {
 		case "text/event-stream":
 			resultType = "streaming"
 			glog.Infof("Received %s response from remote worker=%s taskId=%d", resultType, r.RemoteAddr, tid)
-			resChan := make(chan worker.LlmStreamChunk, 100)
-			workerResult.Results = (<-chan worker.LlmStreamChunk)(resChan)
+			resChan := make(chan *worker.LLMResponse, 100)
+			workerResult.Results = (<-chan *worker.LLMResponse)(resChan)
 
 			defer r.Body.Close()
 			defer close(resChan)
@@ -703,12 +703,12 @@ func (h *lphttp) AIResults() http.Handler {
 				line := scanner.Text()
 				if strings.HasPrefix(line, "data: ") {
 					data := strings.TrimPrefix(line, "data: ")
-					var chunk worker.LlmStreamChunk
+					var chunk worker.LLMResponse
 					if err := json.Unmarshal([]byte(data), &chunk); err != nil {
 						clog.Errorf(ctx, "Error unmarshaling stream data: %v", err)
 						continue
 					}
-					resChan <- chunk
+					resChan <- &chunk
 				}
 			}
 		}
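
The finish check added here is repeated verbatim in ai_mediaserver.go and ai_worker.go, and it indexes chunk.Choices[0] without a length guard. A small helper along these lines (hypothetical, not in this commit) would centralize the test and tolerate a chunk with no choices.

// Hypothetical helper: true once a streamed chunk carries a non-empty
// finish reason; returns false for nil chunks or empty Choices slices.
// Imports assumed: the ai-worker "worker" package.
func streamDone(chunk *worker.LLMResponse) bool {
	if chunk == nil || len(chunk.Choices) == 0 {
		return false
	}
	fr := chunk.Choices[0].FinishReason
	return fr != nil && *fr != ""
}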

server/ai_mediaserver.go

Lines changed: 2 additions & 2 deletions
@@ -289,7 +289,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
 		took := time.Since(start)
 		clog.V(common.VERBOSE).Infof(ctx, "Processed LLM request model_id=%v took=%v", *req.Model, took)
 
-		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+		if streamChan, ok := resp.(chan *worker.LLMResponse); ok {
 			// Handle streaming response (SSE)
 			w.Header().Set("Content-Type", "text/event-stream")
 			w.Header().Set("Cache-Control", "no-cache")
@@ -299,7 +299,7 @@ func (ls *LivepeerServer) LLM() http.Handler {
 				data, _ := json.Marshal(chunk)
 				fmt.Fprintf(w, "data: %s\n\n", data)
 				w.(http.Flusher).Flush()
-				if chunk.Done {
+				if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 					break
 				}
 			}
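
For callers of the gateway's /llm endpoint, the practical effect is that each SSE "data:" line now carries a full worker.LLMResponse JSON object instead of an LlmStreamChunk. A minimal client-side sketch, assuming the SSE framing written by the handler above (function name and error handling are illustrative only):

// Minimal client-side sketch: read the SSE body written by the handler above,
// print streamed deltas, and stop once a chunk reports a finish reason.
// Imports assumed: bufio, encoding/json, fmt, io, strings, and the ai-worker "worker" package.
func readLLMStream(body io.Reader) error {
	scanner := bufio.NewScanner(body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		var chunk worker.LLMResponse
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			continue // skip anything that isn't a JSON chunk
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		if chunk.Choices[0].Delta != nil {
			fmt.Print(chunk.Choices[0].Delta.Content)
		}
		if fr := chunk.Choices[0].FinishReason; fr != nil && *fr != "" {
			break
		}
	}
	return scanner.Err()
}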

server/ai_process.go

Lines changed: 9 additions & 9 deletions
@@ -1106,7 +1106,7 @@ func processLLM(ctx context.Context, params aiRequestParams, req worker.GenLLMJS
 	}
 
 	if req.Stream != nil && *req.Stream {
-		streamChan, ok := resp.(chan worker.LlmStreamChunk)
+		streamChan, ok := resp.(chan *worker.LLMResponse)
 		if !ok {
 			return nil, errors.New("unexpected response type for streaming request")
 		}
@@ -1166,36 +1166,36 @@ func submitLLM(ctx context.Context, params aiRequestParams, sess *AISession, req
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
-	streamChan := make(chan worker.LlmStreamChunk, 100)
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLLMJSONRequestBody, start time.Time) (chan *worker.LLMResponse, error) {
+	streamChan := make(chan *worker.LLMResponse, 100)
 	go func() {
 		defer close(streamChan)
 		defer body.Close()
 		scanner := bufio.NewScanner(body)
-		var totalTokens int
+		var totalTokens worker.LLMTokenUsage
 		for scanner.Scan() {
 			line := scanner.Text()
 			if strings.HasPrefix(line, "data: ") {
 				data := strings.TrimPrefix(line, "data: ")
 				if data == "[DONE]" {
-					streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+					//streamChan <- worker.LLMResponse{Done: true, TokensUsed: totalTokens}
 					break
 				}
-				var chunk worker.LlmStreamChunk
+				var chunk worker.LLMResponse
 				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
 					clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
 					continue
 				}
-				totalTokens += chunk.TokensUsed
-				streamChan <- chunk
+				totalTokens = chunk.TokensUsed
+				streamChan <- &chunk
 			}
 		}
 		if err := scanner.Err(); err != nil {
 			clog.Errorf(ctx, "Error reading SSE stream: %v", err)
 		}
 
 		took := time.Since(start)
-		sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens)
+		sess.LatencyScore = CalculateLLMLatencyScore(took, totalTokens.TotalTokens)
 
 		if monitor.Enabled {
 			var pricePerAIUnit float64
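
One behavioral detail in handleSSEStream: totalTokens used to be an int summed across chunks, while the new code overwrites a worker.LLMTokenUsage with each chunk's TokensUsed, so the latency score is computed from whatever usage the last received chunk reported. That is correct if the worker sends cumulative usage; if chunks ever carried per-chunk deltas instead, an accumulating variant (hypothetical, not in this commit) would be needed, for example:

// Hypothetical accumulator, only needed if chunks report incremental rather
// than cumulative token usage (this commit assumes cumulative usage).
// Imports assumed: the ai-worker "worker" package.
func addUsage(total *worker.LLMTokenUsage, chunk worker.LLMTokenUsage) {
	total.PromptTokens += chunk.PromptTokens
	total.CompletionTokens += chunk.CompletionTokens
	total.TotalTokens += chunk.TotalTokens
}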

server/ai_worker.go

Lines changed: 3 additions & 3 deletions
@@ -354,7 +354,7 @@ func runAIJob(n *core.LivepeerNode, orchAddr string, httpc *http.Client, notify
 
 	if resp != nil {
 		if resultType == "text/event-stream" {
-			streamChan, ok := resp.(<-chan worker.LlmStreamChunk)
+			streamChan, ok := resp.(<-chan *worker.LLMResponse)
 			if ok {
 				sendStreamingAIResult(ctx, n, orchAddr, notify.AIJobData.Pipeline, httpc, resultType, streamChan)
 				return
@@ -530,7 +530,7 @@ func sendAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pi
 }
 
 func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr string, pipeline string, httpc *http.Client,
-	contentType string, streamChan <-chan worker.LlmStreamChunk,
+	contentType string, streamChan <-chan *worker.LLMResponse,
 ) {
 	clog.Infof(ctx, "sending streaming results back to Orchestrator")
 	taskId := clog.GetVal(ctx, "taskId")
@@ -571,7 +571,7 @@ func sendStreamingAIResult(ctx context.Context, n *core.LivepeerNode, orchAddr s
 		}
 		fmt.Fprintf(pWriter, "data: %s\n\n", data)
 
-		if chunk.Done {
+		if chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
 			pWriter.Close()
 			clog.Infof(ctx, "streaming results finished")
 			return
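
The last hunk above closes pWriter once a chunk reports a finish reason. A self-contained, hypothetical equivalent of that forwarding logic (not the actual sendStreamingAIResult implementation, whose HTTP plumbing is omitted here) looks like:

// Hypothetical standalone version of the forwarding logic above: write each
// chunk as an SSE "data:" frame and stop once a finish reason is seen.
// Imports assumed: encoding/json, fmt, io, and the ai-worker "worker" package.
func forwardLLMStream(w io.WriteCloser, streamChan <-chan *worker.LLMResponse) error {
	defer w.Close()
	for chunk := range streamChan {
		data, err := json.Marshal(chunk)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
			return err
		}
		if len(chunk.Choices) > 0 && chunk.Choices[0].FinishReason != nil && *chunk.Choices[0].FinishReason != "" {
			return nil
		}
	}
	return nil
}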

server/ai_worker_test.go

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ func (a *stubAIWorker) LLM(ctx context.Context, req worker.GenLLMJSONRequestBody
 		return nil, a.Err
 	} else {
 		var choices []worker.LLMChoice
-		choices = append(choices, worker.LLMChoice{Delta: worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
+		choices = append(choices, worker.LLMChoice{Delta: &worker.LLMMessage{Content: "choice1", Role: "assistant"}, Index: 0})
 		tokensUsed := worker.LLMTokenUsage{PromptTokens: 40, CompletionTokens: 10, TotalTokens: 50}
 		return &worker.LLMResponse{Choices: choices, Created: 1, Model: "llm_model", TokensUsed: tokensUsed}, nil
 	}
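
A hypothetical follow-up test (not part of this commit, and assuming the github.com/stretchr/testify/require package and a zero-value stubAIWorker are usable here) could pin down the new pointer Delta on the stub's response:

// Hypothetical test: assert the stub now returns a pointer Delta.
// Imports assumed: context, testing, stretchr/testify/require, and the ai-worker "worker" package.
func TestStubAIWorkerLLMResponseShape(t *testing.T) {
	a := &stubAIWorker{}
	resp, err := a.LLM(context.Background(), worker.GenLLMJSONRequestBody{})
	require.NoError(t, err)

	llmResp, ok := resp.(*worker.LLMResponse)
	require.True(t, ok)
	require.Len(t, llmResp.Choices, 1)
	require.NotNil(t, llmResp.Choices[0].Delta)
	require.Equal(t, "choice1", llmResp.Choices[0].Delta.Content)
}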
