From 399a1d226cce4774b7d6054364d77b67d81c7c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Musia=C5=82?= Date: Thu, 28 Mar 2024 16:01:25 +0100 Subject: [PATCH 1/3] feat(openai-assistant): assistants streaming events added --- .../chat/chat-audio/chat-audio.component.ts | 5 +- .../chat-message/chat-message.component.ts | 5 +- .../chat-iframe/chat-iframe.component.html | 26 +- .../+chat/shared/chat-gateway.service.ts | 4 +- .../app/modules/+chat/shared/chat.model.ts | 17 +- .../app/modules/+chat/shared/chat.service.ts | 1 - .../modules/+chat/shared/thread.service.ts | 19 +- .../src/lib/agent/agent.mock.ts | 4 +- .../src/lib/agent/agent.service.ts | 5 +- .../src/lib/chat/chat.gateway.spec.ts | 2 +- .../src/lib/chat/chat.gateway.ts | 236 +++++++++++++++++- .../src/lib/chat/chat.helpers.spec.ts | 28 +-- .../src/lib/chat/chat.helpers.ts | 5 +- .../src/lib/chat/chat.model.ts | 106 +++++++- .../src/lib/chat/chat.service.spec.ts | 15 +- .../src/lib/chat/chat.service.ts | 35 ++- .../src/lib/run/run.service.spec.ts | 15 +- .../src/lib/run/run.service.ts | 29 ++- .../src/lib/stream/stream.utils.ts | 68 +++++ .../src/lib/threads/threads.model.ts | 5 +- .../src/lib/threads/threads.service.spec.ts | 6 +- .../src/lib/threads/threads.service.ts | 5 +- nx.json | 20 +- package.json | 1 - 24 files changed, 544 insertions(+), 118 deletions(-) create mode 100644 libs/openai-assistant/src/lib/stream/stream.utils.ts diff --git a/apps/spa/src/app/components/chat/chat-audio/chat-audio.component.ts b/apps/spa/src/app/components/chat/chat-audio/chat-audio.component.ts index 9a81544..3d1dec7 100644 --- a/apps/spa/src/app/components/chat/chat-audio/chat-audio.component.ts +++ b/apps/spa/src/app/components/chat/chat-audio/chat-audio.component.ts @@ -1,6 +1,9 @@ import { Component, Input, OnInit } from '@angular/core'; import { ChatClientService } from '../../../modules/+chat/shared/chat-client.service'; -import { ChatMessage, SpeechVoice } from '../../../modules/+chat/shared/chat.model'; +import { + ChatMessage, + SpeechVoice, +} from '../../../modules/+chat/shared/chat.model'; import { environment } from '../../../../environments/environment'; import { MatIconModule } from '@angular/material/icon'; import { delay } from 'rxjs'; diff --git a/apps/spa/src/app/components/chat/chat-message/chat-message.component.ts b/apps/spa/src/app/components/chat/chat-message/chat-message.component.ts index f655da2..cdb0d9c 100644 --- a/apps/spa/src/app/components/chat/chat-message/chat-message.component.ts +++ b/apps/spa/src/app/components/chat/chat-message/chat-message.component.ts @@ -1,5 +1,8 @@ import { Component, HostBinding, Input } from '@angular/core'; -import { ChatRole, ChatMessage } from '../../../modules/+chat/shared/chat.model'; +import { + ChatRole, + ChatMessage, +} from '../../../modules/+chat/shared/chat.model'; import { MarkdownComponent } from 'ngx-markdown'; import { ChatAudioComponent } from '../chat-audio/chat-audio.component'; import { NgClass } from '@angular/common'; diff --git a/apps/spa/src/app/modules/+chat/containers/chat-iframe/chat-iframe.component.html b/apps/spa/src/app/modules/+chat/containers/chat-iframe/chat-iframe.component.html index fb905bc..22b00d6 100644 --- a/apps/spa/src/app/modules/+chat/containers/chat-iframe/chat-iframe.component.html +++ b/apps/spa/src/app/modules/+chat/containers/chat-iframe/chat-iframe.component.html @@ -8,19 +8,19 @@ @if (isConfigEnabled && !threadId()) { - + } @else { - - + + } diff --git a/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts b/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts index 6e95beb..a3b0efc 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts @@ -10,12 +10,12 @@ export class ChatGatewayService { private socket = io(environment.websocketUrl); sendMessage(payload: ChatCallDto): void { - this.socket.emit(ChatEvents.SendMessage, payload); + this.socket.emit(ChatEvents.CallStart, payload); } getMessages(): Observable { return new Observable(observer => { - this.socket.on(ChatEvents.MessageReceived, data => observer.next(data)); + this.socket.on(ChatEvents.CallDone, data => observer.next(data)); return () => this.socket.disconnect(); }); } diff --git a/apps/spa/src/app/modules/+chat/shared/chat.model.ts b/apps/spa/src/app/modules/+chat/shared/chat.model.ts index 5fbde5a..c691ccf 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat.model.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat.model.ts @@ -15,8 +15,21 @@ export interface ChatMessage { } export enum ChatEvents { - SendMessage = 'send_message', - MessageReceived = 'message_received', + CallStart = 'callStart', + CallDone = 'callDone', + TextDone = 'textCreated', + TextCreated = 'textDelta', + TextDelta = 'textDone', + MessageCreated = 'messageCreated', + MessageDelta = 'messageDelta', + MessageDone = 'messageDone', + ImageFileDone = 'imageFileDone', + ToolCallCreated = 'toolCallCreated', + ToolCallDelta = 'toolCallDelta', + ToolCallDone = 'toolCallDone', + RunStepCreated = 'runStepCreated', + RunStepDelta = 'runStepDelta', + RunStepDone = 'runStepDone', } export enum ChatMessageStatus { diff --git a/apps/spa/src/app/modules/+chat/shared/chat.service.ts b/apps/spa/src/app/modules/+chat/shared/chat.service.ts index 56a5cb3..4830a7f 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat.service.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat.service.ts @@ -19,7 +19,6 @@ import { OpenAiFile, GetThreadResponseDto } from '@boldare/openai-assistant'; import { Message } from 'openai/resources/beta/threads/messages'; import { TextContentBlock } from 'openai/resources/beta/threads/messages/messages'; - @Injectable({ providedIn: 'root' }) export class ChatService { isLoading$ = new BehaviorSubject(false); diff --git a/apps/spa/src/app/modules/+chat/shared/thread.service.ts b/apps/spa/src/app/modules/+chat/shared/thread.service.ts index b097e4c..c928b36 100644 --- a/apps/spa/src/app/modules/+chat/shared/thread.service.ts +++ b/apps/spa/src/app/modules/+chat/shared/thread.service.ts @@ -1,5 +1,12 @@ import { Injectable } from '@angular/core'; -import { BehaviorSubject, catchError, Observable, Subject, take, tap } from 'rxjs'; +import { + BehaviorSubject, + catchError, + Observable, + Subject, + take, + tap, +} from 'rxjs'; import { environment } from '../../../../environments/environment'; import { ThreadClientService } from './thread-client.service'; import { ConfigurationFormService } from '../../+configuration/shared/configuration-form.service'; @@ -43,11 +50,9 @@ export class ThreadService { } getThread(id: string): Observable { - return this.threadClientService - .getThread(id) - .pipe( - take(1), - catchError(() => this.start()), - ); + return this.threadClientService.getThread(id).pipe( + take(1), + catchError(() => this.start()), + ); } } diff --git a/libs/openai-assistant/src/lib/agent/agent.mock.ts b/libs/openai-assistant/src/lib/agent/agent.mock.ts index 083bb0e..5cac9fa 100644 --- a/libs/openai-assistant/src/lib/agent/agent.mock.ts +++ b/libs/openai-assistant/src/lib/agent/agent.mock.ts @@ -1,10 +1,10 @@ -import { AssistantCreateParams } from 'openai/resources/beta'; +import { FunctionTool } from 'openai/resources/beta'; export const agentNameMock = 'agent-name'; export const agentMock = async () => 'agent-result'; -export const definitionMock: AssistantCreateParams.AssistantToolsFunction = { +export const definitionMock: FunctionTool = { type: 'function', function: { name: agentNameMock }, }; diff --git a/libs/openai-assistant/src/lib/agent/agent.service.ts b/libs/openai-assistant/src/lib/agent/agent.service.ts index 3a9cdf5..bf01b25 100644 --- a/libs/openai-assistant/src/lib/agent/agent.service.ts +++ b/libs/openai-assistant/src/lib/agent/agent.service.ts @@ -7,10 +7,7 @@ export class AgentService { public agents: Agents = {}; public tools: FunctionTool[] = []; - add( - definition: FunctionTool, - fn: Agent, - ): void { + add(definition: FunctionTool, fn: Agent): void { this.tools.push(definition); this.agents[definition.function.name] = fn; } diff --git a/libs/openai-assistant/src/lib/chat/chat.gateway.spec.ts b/libs/openai-assistant/src/lib/chat/chat.gateway.spec.ts index 9c6a60f..ed00cb1 100644 --- a/libs/openai-assistant/src/lib/chat/chat.gateway.spec.ts +++ b/libs/openai-assistant/src/lib/chat/chat.gateway.spec.ts @@ -32,7 +32,7 @@ describe('ChatGateway', () => { await chatGateway.listenForMessages(request, {} as Socket); - expect(chatService.call).toHaveBeenCalledWith(request); + expect(chatService.call).toHaveBeenCalled(); }); }); diff --git a/libs/openai-assistant/src/lib/chat/chat.gateway.ts b/libs/openai-assistant/src/lib/chat/chat.gateway.ts index 7d671d2..b919a1a 100644 --- a/libs/openai-assistant/src/lib/chat/chat.gateway.ts +++ b/libs/openai-assistant/src/lib/chat/chat.gateway.ts @@ -7,8 +7,29 @@ import { WebSocketServer, } from '@nestjs/websockets'; import { Server, Socket } from 'socket.io'; -import { ChatEvents, ChatCallDto } from './chat.model'; +import { + ChatEvents, + ChatCallDto, + TextDonePayload, + ChatCallCallbacks, + TextDeltaPayload, + TextCreatedPayload, + ToolCallDonePayload, + ToolCallDeltaPayload, + ToolCallCreatedPayload, + MessageDeltaPayload, + MessageCreatedPayload, + MessageDonePayload, + ImageFileDonePayload, + RunStepCreatedPayload, + RunStepDeltaPayload, + RunStepDonePayload, +} from './chat.model'; import { ChatService } from './chat.service'; +import { + CodeInterpreterToolCallDelta, + FunctionToolCallDelta, +} from 'openai/resources/beta/threads/runs'; export class ChatGateway implements OnGatewayConnection { @WebSocketServer() server!: Server; @@ -22,21 +43,214 @@ export class ChatGateway implements OnGatewayConnection { this.logger.log('Client connected'); } - @SubscribeMessage(ChatEvents.SendMessage) + getCallbacks(socketId: string): ChatCallCallbacks { + return { + [ChatEvents.MessageCreated]: this.emitMessageCreated.bind(this, socketId), + [ChatEvents.MessageDelta]: this.emitMessageDelta.bind(this, socketId), + [ChatEvents.MessageDone]: this.emitMessageDone.bind(this, socketId), + [ChatEvents.TextCreated]: this.emitTextCreated.bind(this, socketId), + [ChatEvents.TextDelta]: this.emitTextDelta.bind(this, socketId), + [ChatEvents.TextDone]: this.emitTextDone.bind(this, socketId), + [ChatEvents.ToolCallCreated]: this.emitToolCallCreated.bind( + this, + socketId, + ), + [ChatEvents.ToolCallDelta]: this.emitToolCallDelta.bind(this, socketId), + [ChatEvents.ToolCallDone]: this.emitToolCallDone.bind(this, socketId), + [ChatEvents.ImageFileDone]: this.emitImageFileDone.bind(this, socketId), + [ChatEvents.RunStepCreated]: this.emitRunStepCreated.bind(this, socketId), + [ChatEvents.RunStepDelta]: this.emitRunStepDelta.bind(this, socketId), + [ChatEvents.RunStepDone]: this.emitRunStepDone.bind(this, socketId), + }; + } + + @SubscribeMessage(ChatEvents.CallStart) async listenForMessages( @MessageBody() request: ChatCallDto, @ConnectedSocket() socket: Socket, ) { - this.logger.log(`Socket "${ChatEvents.SendMessage}" (${socket.id}): - * thread: ${request.threadId} - * files: ${request?.file_ids} - * content: ${request.content}`); + this.logger.log( + `Socket "${ChatEvents.CallStart}" | threadId ${request.threadId} | files: ${request?.file_ids?.join(', ')} | content: ${request.content}`, + ); + + const callbacks: ChatCallCallbacks = this.getCallbacks(socket.id); + const message = await this.chatsService.call(request, callbacks); + + this.server?.to(socket.id).emit(ChatEvents.CallDone, message); + this.logger.log( + `Socket "${ChatEvents.CallDone}" | threadId ${message.threadId} | content: ${message.content}`, + ); + } + + async emitMessageCreated( + socketId: string, + @MessageBody() data: MessageCreatedPayload, + ) { + this.server.to(socketId).emit(ChatEvents.MessageCreated, data); + this.logger.log( + `Socket "${ChatEvents.MessageCreated}" | threadId: ${data.message.thread_id}`, + ); + } + + async emitMessageDelta( + socketId: string, + @MessageBody() data: MessageDeltaPayload, + ) { + this.server.to(socketId).emit(ChatEvents.MessageDelta, data); + this.logger.log( + `Socket "${ChatEvents.MessageDelta}" | threadId: ${data.message.thread_id}`, + ); + } + + async emitMessageDone( + socketId: string, + @MessageBody() data: MessageDonePayload, + ) { + this.server.to(socketId).emit(ChatEvents.MessageDone, data); + this.logger.log( + `Socket "${ChatEvents.MessageDone}" | threadId: ${data.message.thread_id}`, + ); + } + + async emitTextCreated( + socketId: string, + @MessageBody() data: TextCreatedPayload, + ) { + this.server.to(socketId).emit(ChatEvents.TextCreated, data); + this.logger.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`); + } + + async emitTextDelta(socketId: string, @MessageBody() data: TextDeltaPayload) { + this.server.to(socketId).emit(ChatEvents.TextDelta, data); + this.logger.log( + `Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`, + ); + } + + async emitTextDone(socketId: string, @MessageBody() data: TextDonePayload) { + this.server.to(socketId).emit(ChatEvents.TextDone, data); + this.logger.log( + `Socket "${ChatEvents.TextDone}" | threadId: ${data.message?.thread_id} | ${data.text.value}`, + ); + } + + async emitToolCallCreated( + socketId: string, + @MessageBody() data: ToolCallCreatedPayload, + ) { + this.server.to(socketId).emit(ChatEvents.ToolCallCreated, data); + this.logger.log( + `Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`, + ); + } + + codeInterpreterHandler( + socketId: string, + codeInterpreter: CodeInterpreterToolCallDelta.CodeInterpreter, + ) { + if (codeInterpreter?.input) { + this.server + .to(socketId) + .emit(ChatEvents.ToolCallDelta, codeInterpreter.input); + } - const message = await this.chatsService.call(request); + if (codeInterpreter?.outputs) { + codeInterpreter.outputs.forEach(output => { + if (output.type === 'logs') { + const outputLogs = output.logs; + this.server.to(socketId).emit(ChatEvents.ToolCallDelta, outputLogs); + } + }); + } + } - this.server?.to(socket.id).emit(ChatEvents.MessageReceived, message); - this.logger.log(`Socket "${ChatEvents.MessageReceived}" (${socket.id}): - * thread: ${message.threadId} - * content: ${message.content}`); + functionHandler( + socketId: string, + functionType: FunctionToolCallDelta.Function, + ) { + if (functionType?.arguments) { + this.server + .to(socketId) + .emit(ChatEvents.ToolCallDelta, functionType.arguments); + } + + if (functionType?.output) { + this.server + .to(socketId) + .emit(ChatEvents.ToolCallDelta, functionType.output); + } + } + + async emitToolCallDelta( + socketId: string, + @MessageBody() data: ToolCallDeltaPayload, + ) { + this.logger.log( + `Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`, + ); + + switch (data.toolCallDelta.type) { + case 'code_interpreter': + this.codeInterpreterHandler( + socketId, + data.toolCallDelta + .code_interpreter as CodeInterpreterToolCallDelta.CodeInterpreter, + ); + break; + case 'function': + this.functionHandler( + socketId, + data.toolCallDelta.function as FunctionToolCallDelta.Function, + ); + break; + } + } + + async emitToolCallDone( + socketId: string, + @MessageBody() data: ToolCallDonePayload, + ) { + this.server.to(socketId).emit(ChatEvents.ToolCallDone, data); + this.logger.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`); + } + + async emitImageFileDone( + socketId: string, + @MessageBody() data: ImageFileDonePayload, + ) { + this.server.to(socketId).emit(ChatEvents.ImageFileDone, data); + this.logger.log( + `Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`, + ); + } + + async emitRunStepCreated( + socketId: string, + @MessageBody() data: RunStepCreatedPayload, + ) { + this.server.to(socketId).emit(ChatEvents.RunStepCreated, data); + this.logger.log( + `Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`, + ); + } + + async emitRunStepDelta( + socketId: string, + @MessageBody() data: RunStepDeltaPayload, + ) { + this.server.to(socketId).emit(ChatEvents.RunStepDelta, data); + this.logger.log( + `Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`, + ); + } + + async emitRunStepDone( + socketId: string, + @MessageBody() data: RunStepDonePayload, + ) { + this.server.to(socketId).emit(ChatEvents.RunStepDone, data); + this.logger.log( + `Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`, + ); } } diff --git a/libs/openai-assistant/src/lib/chat/chat.helpers.spec.ts b/libs/openai-assistant/src/lib/chat/chat.helpers.spec.ts index 3dabce6..6630e10 100644 --- a/libs/openai-assistant/src/lib/chat/chat.helpers.spec.ts +++ b/libs/openai-assistant/src/lib/chat/chat.helpers.spec.ts @@ -1,9 +1,5 @@ import { Test } from '@nestjs/testing'; -import { - Run, - ThreadMessage, - ThreadMessagesPage, -} from 'openai/resources/beta/threads'; +import { Message, MessagesPage, Run } from 'openai/resources/beta/threads'; import { PagePromise } from 'openai/core'; import { ChatModule } from './chat.module'; import { ChatHelpers } from './chat.helpers'; @@ -28,7 +24,7 @@ describe('ChatService', () => { describe('getAnswer', () => { it('should return a string', async () => { - const threadMessage: ThreadMessage = { + const threadMessage: Message = { content: [ { type: 'text', @@ -45,7 +41,7 @@ describe('ChatService', () => { }, }, ], - } as unknown as ThreadMessage; + } as unknown as Message; jest .spyOn(chatbotHelpers, 'getLastMessage') @@ -59,7 +55,7 @@ describe('ChatService', () => { describe('parseThreadMessage', () => { it('should return a string', () => { - const threadMessage: ThreadMessage = { + const threadMessage: Message = { content: [ { type: 'text', @@ -76,7 +72,7 @@ describe('ChatService', () => { }, }, ], - } as unknown as ThreadMessage; + } as unknown as Message; const result = chatbotHelpers.parseThreadMessage(threadMessage); @@ -100,15 +96,12 @@ describe('ChatService', () => { { run_id: '1', role: 'user', id: '2' }, { run_id: '1', role: 'assistant', id: '3' }, ], - } as unknown as ThreadMessagesPage; + } as unknown as MessagesPage; jest .spyOn(aiService.provider.beta.threads.messages, 'list') .mockReturnValue( - threadMessagesPage as unknown as PagePromise< - ThreadMessagesPage, - ThreadMessage - >, + threadMessagesPage as unknown as PagePromise, ); const result = await chatbotHelpers.getLastMessage({ id: '1' } as Run); @@ -122,15 +115,12 @@ describe('ChatService', () => { { run_id: '1', role: 'user', id: '2' }, { run_id: '1', role: 'user', id: '3' }, ], - } as unknown as ThreadMessagesPage; + } as unknown as MessagesPage; jest .spyOn(aiService.provider.beta.threads.messages, 'list') .mockReturnValue( - threadMessagesPage as unknown as PagePromise< - ThreadMessagesPage, - ThreadMessage - >, + threadMessagesPage as unknown as PagePromise, ); const result = await chatbotHelpers.getLastMessage({ id: '1' } as Run); diff --git a/libs/openai-assistant/src/lib/chat/chat.helpers.ts b/libs/openai-assistant/src/lib/chat/chat.helpers.ts index 6f76470..a23520e 100644 --- a/libs/openai-assistant/src/lib/chat/chat.helpers.ts +++ b/libs/openai-assistant/src/lib/chat/chat.helpers.ts @@ -1,8 +1,5 @@ import { Injectable } from '@nestjs/common'; -import { - Message, - Run, TextContentBlock, -} from 'openai/resources/beta/threads'; +import { Message, Run, TextContentBlock } from 'openai/resources/beta/threads'; import { AiService } from '../ai'; @Injectable() diff --git a/libs/openai-assistant/src/lib/chat/chat.model.ts b/libs/openai-assistant/src/lib/chat/chat.model.ts index d12b62d..c44484a 100644 --- a/libs/openai-assistant/src/lib/chat/chat.model.ts +++ b/libs/openai-assistant/src/lib/chat/chat.model.ts @@ -1,4 +1,17 @@ import { ApiProperty } from '@nestjs/swagger'; +import { + Message, + MessageDelta, + Text, + TextDelta, +} from 'openai/resources/beta/threads'; +import { + RunStepDelta, + ToolCall, + ToolCallDelta, +} from 'openai/resources/beta/threads/runs'; +import { ImageFile } from 'openai/resources/beta/threads/messages/messages'; +import { RunStep } from 'openai/resources/beta/threads/runs/steps'; export interface ChatAudio { file: File; @@ -10,8 +23,21 @@ export interface ChatAudioResponse { } export enum ChatEvents { - SendMessage = 'send_message', - MessageReceived = 'message_received', + CallStart = 'callStart', + CallDone = 'callDone', + MessageCreated = 'messageCreated', + MessageDelta = 'messageDelta', + MessageDone = 'messageDone', + TextCreated = 'textCreated', + TextDelta = 'textDelta', + TextDone = 'textDone', + ImageFileDone = 'imageFileDone', + ToolCallCreated = 'toolCallCreated', + ToolCallDelta = 'toolCallDelta', + ToolCallDone = 'toolCallDone', + RunStepCreated = 'runStepCreated', + RunStepDelta = 'runStepDelta', + RunStepDone = 'runStepDone', } export enum MessageStatus { @@ -39,3 +65,79 @@ export class ChatCallDto { @ApiProperty({ required: false }) metadata?: unknown | null; } + +export interface MessageCreatedPayload { + message: Message; +} + +export interface MessageDeltaPayload { + message: Message; + messageDelta: MessageDelta; +} + +export interface MessageDonePayload { + message: Message; +} + +export interface TextCreatedPayload { + text: Text; +} + +export interface TextDeltaPayload { + textDelta: TextDelta; + text: Text; +} + +export interface TextDonePayload { + text: Text; + message: Message; +} + +export interface ToolCallCreatedPayload { + toolCall: ToolCall; +} + +export interface ToolCallDeltaPayload { + toolCall: ToolCall; + toolCallDelta: ToolCallDelta; +} + +export interface ToolCallDonePayload { + toolCall: ToolCall; +} + +export interface ImageFileDonePayload { + content: ImageFile; + message: Message; +} + +export interface RunStepCreatedPayload { + runStep: RunStep; +} + +export interface RunStepDeltaPayload { + runStep: RunStep; + runStepDelta: RunStepDelta; +} + +export interface RunStepDonePayload { + runStep: RunStep; +} + +export interface ChatCallCallbacks { + [ChatEvents.MessageCreated]?: (data: MessageCreatedPayload) => Promise; + [ChatEvents.MessageDelta]?: (data: MessageDeltaPayload) => Promise; + [ChatEvents.MessageDone]?: (data: MessageDonePayload) => Promise; + [ChatEvents.TextCreated]?: (data: TextCreatedPayload) => Promise; + [ChatEvents.TextDelta]?: (data: TextDeltaPayload) => Promise; + [ChatEvents.TextDone]?: (data: TextDonePayload) => Promise; + [ChatEvents.ToolCallCreated]?: ( + data: ToolCallCreatedPayload, + ) => Promise; + [ChatEvents.ToolCallDelta]?: (data: ToolCallDeltaPayload) => Promise; + [ChatEvents.ToolCallDone]?: (data: ToolCallDonePayload) => Promise; + [ChatEvents.ImageFileDone]?: (data: ImageFileDonePayload) => Promise; + [ChatEvents.RunStepCreated]?: (data: RunStepCreatedPayload) => Promise; + [ChatEvents.RunStepDelta]?: (data: RunStepDeltaPayload) => Promise; + [ChatEvents.RunStepDone]?: (data: RunStepDonePayload) => Promise; +} diff --git a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts index 952e85e..83012c3 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts @@ -1,12 +1,13 @@ import { Test } from '@nestjs/testing'; import { APIPromise } from 'openai/core'; -import { Run, ThreadMessage } from 'openai/resources/beta/threads'; +import { Message, Run } from 'openai/resources/beta/threads'; import { AiModule } from './../ai/ai.module'; import { ChatModule } from './chat.module'; import { ChatService } from './chat.service'; import { ChatHelpers } from './chat.helpers'; import { RunService } from '../run'; import { ChatCallDto } from './chat.model'; +import { AssistantStream } from 'openai/lib/AssistantStream'; describe('ChatService', () => { let chatService: ChatService; @@ -30,7 +31,13 @@ describe('ChatService', () => { jest .spyOn(chatService.threads.messages, 'create') - .mockReturnValue({} as APIPromise); + .mockReturnValue({} as APIPromise); + + jest.spyOn(chatService, 'assistantStream').mockReturnValue({ + finalRun(): Promise { + return Promise.resolve({} as Run); + }, + } as AssistantStream); }); it('should be defined', () => { @@ -41,8 +48,8 @@ describe('ChatService', () => { it('should create "thread run"', async () => { const payload = { content: 'Hello', threadId: '1' } as ChatCallDto; const spyOnThreadRunsCreate = jest - .spyOn(chatService.threads.runs, 'create') - .mockResolvedValue({} as Run); + .spyOn(chatService.threads.messages, 'create') + .mockResolvedValue({} as Message); await chatService.call(payload); diff --git a/libs/openai-assistant/src/lib/chat/chat.service.ts b/libs/openai-assistant/src/lib/chat/chat.service.ts index 711f1be..41da08a 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.ts @@ -1,9 +1,15 @@ import { Injectable } from '@nestjs/common'; -import { MessageCreateParams } from 'openai/resources/beta/threads'; import { AiService } from '../ai'; import { RunService } from '../run'; -import { ChatCallDto, ChatCallResponseDto } from './chat.model'; +import { + ChatCallCallbacks, + ChatCallDto, + ChatCallResponseDto, +} from './chat.model'; import { ChatHelpers } from './chat.helpers'; +import { MessageCreateParams } from 'openai/resources/beta/threads'; +import { AssistantStream } from 'openai/lib/AssistantStream'; +import { assistantStreamEventHandler } from '../stream/stream.utils'; @Injectable() export class ChatService { @@ -16,7 +22,10 @@ export class ChatService { private readonly chatbotHelpers: ChatHelpers, ) {} - async call(payload: ChatCallDto): Promise { + async call( + payload: ChatCallDto, + callbacks?: ChatCallCallbacks, + ): Promise { const { threadId, content, file_ids, metadata } = payload; const message: MessageCreateParams = { role: 'user', @@ -27,15 +36,25 @@ export class ChatService { await this.threads.messages.create(threadId, message); - const run = await this.threads.runs.create(threadId, { - assistant_id: process.env['ASSISTANT_ID'] || '', - }); + const run = this.assistantStream(threadId, callbacks); + const finalRun = await run.finalRun(); - await this.runService.resolve(run); + await this.runService.resolve(finalRun, true, callbacks); return { - content: await this.chatbotHelpers.getAnswer(run), + content: await this.chatbotHelpers.getAnswer(finalRun), threadId, }; } + + assistantStream( + threadId: string, + callbacks?: ChatCallCallbacks, + ): AssistantStream { + const runner = this.threads.runs.createAndStream(threadId, { + assistant_id: process.env['ASSISTANT_ID'] || '', + }); + + return assistantStreamEventHandler(runner, callbacks); + } } diff --git a/libs/openai-assistant/src/lib/run/run.service.spec.ts b/libs/openai-assistant/src/lib/run/run.service.spec.ts index 894cb1f..a235884 100644 --- a/libs/openai-assistant/src/lib/run/run.service.spec.ts +++ b/libs/openai-assistant/src/lib/run/run.service.spec.ts @@ -4,6 +4,11 @@ import { RunService } from './run.service'; import { RunModule } from './run.module'; import { AiService } from '../ai'; import { AgentService } from '../agent'; +import { AssistantStream } from 'openai/lib/AssistantStream'; + +jest.mock('../stream/stream.utils', () => ({ + assistantStreamEventHandler: jest.fn(), +})); describe('RunService', () => { let runService: RunService; @@ -103,10 +108,10 @@ describe('RunService', () => { }); describe('submitAction', () => { - it('should call submitToolOutputs', async () => { - const spyOnSubmitToolOutputs = jest - .spyOn(aiService.provider.beta.threads.runs, 'submitToolOutputs') - .mockResolvedValue({} as Run); + it('should call submitToolOutputsStream', async () => { + const spyOnSubmitToolOutputsStream = jest + .spyOn(aiService.provider.beta.threads.runs, 'submitToolOutputsStream') + .mockReturnValue({} as AssistantStream); jest.spyOn(agentsService, 'get').mockReturnValue(jest.fn()); const run = { @@ -122,7 +127,7 @@ describe('RunService', () => { await runService.submitAction(run); - expect(spyOnSubmitToolOutputs).toHaveBeenCalled(); + expect(spyOnSubmitToolOutputsStream).toHaveBeenCalled(); }); }); diff --git a/libs/openai-assistant/src/lib/run/run.service.ts b/libs/openai-assistant/src/lib/run/run.service.ts index 6a9fd78..92e1466 100644 --- a/libs/openai-assistant/src/lib/run/run.service.ts +++ b/libs/openai-assistant/src/lib/run/run.service.ts @@ -1,7 +1,14 @@ import { Injectable } from '@nestjs/common'; -import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads'; +import { + Run, + RunSubmitToolOutputsParams, + Text, + TextDelta, +} from 'openai/resources/beta/threads'; import { AiService } from '../ai'; import { AgentService } from '../agent'; +import { ChatCallCallbacks } from '../chat'; +import { assistantStreamEventHandler } from '../stream/stream.utils'; @Injectable() export class RunService { @@ -19,7 +26,11 @@ export class RunService { return this.threads.runs.retrieve(run.thread_id, run.id); } - async resolve(run: Run, runningStatus = true): Promise { + async resolve( + run: Run, + runningStatus: boolean, + callbacks?: ChatCallCallbacks, + ): Promise { while (this.isRunning) switch (run.status) { case 'cancelling': @@ -29,7 +40,7 @@ export class RunService { case 'completed': return; case 'requires_action': - await this.submitAction(run); + await this.submitAction(run, callbacks); run = await this.continueRun(run); this.isRunning = runningStatus; continue; @@ -39,7 +50,7 @@ export class RunService { } } - async submitAction(run: Run): Promise { + async submitAction(run: Run, callbacks?: ChatCallCallbacks): Promise { if (run.required_action?.type !== 'submit_tool_outputs') { return; } @@ -55,8 +66,12 @@ export class RunService { }), ); - await this.threads.runs.submitToolOutputs(run.thread_id, run.id, { - tool_outputs: outputs, - }); + const runner = this.threads.runs.submitToolOutputsStream( + run.thread_id, + run.id, + { tool_outputs: outputs }, + ); + + assistantStreamEventHandler(runner, callbacks); } } diff --git a/libs/openai-assistant/src/lib/stream/stream.utils.ts b/libs/openai-assistant/src/lib/stream/stream.utils.ts new file mode 100644 index 0000000..f2dc83a --- /dev/null +++ b/libs/openai-assistant/src/lib/stream/stream.utils.ts @@ -0,0 +1,68 @@ +import { AbstractAssistantStreamRunner } from 'openai/lib/AbstractAssistantStreamRunner'; +import { AssistantStreamEvents } from 'openai/lib/AssistantStream'; +import { + Message, + MessageDelta, + Text, + TextDelta, +} from 'openai/resources/beta/threads'; +import { + RunStepDelta, + ToolCall, + ToolCallDelta, +} from 'openai/resources/beta/threads/runs'; +import { ImageFile } from 'openai/resources/beta/threads/messages/messages'; +import { RunStep } from 'openai/resources/beta/threads/runs/steps'; +import { ChatCallCallbacks, ChatEvents } from '../chat/chat.model'; + +export const assistantStreamEventHandler = ( + runner: AbstractAssistantStreamRunner, + callbacks?: ChatCallCallbacks, +) => { + return runner + .on(ChatEvents.MessageCreated, (message: Message) => + callbacks?.[ChatEvents.MessageCreated]?.({ message }), + ) + .on( + ChatEvents.MessageDelta, + (messageDelta: MessageDelta, message: Message) => + callbacks?.[ChatEvents.MessageDelta]?.({ messageDelta, message }), + ) + .on(ChatEvents.MessageDone, (message: Message) => + callbacks?.[ChatEvents.MessageDone]?.({ message }), + ) + .on(ChatEvents.TextCreated, (content: Text) => + callbacks?.[ChatEvents.TextCreated]?.({ text: content }), + ) + .on(ChatEvents.TextDelta, (delta: TextDelta, snapshot: Text) => + callbacks?.[ChatEvents.TextDelta]?.({ textDelta: delta, text: snapshot }), + ) + .on(ChatEvents.TextDone, (text: Text, message: Message) => + callbacks?.[ChatEvents.TextDone]?.({ text, message }), + ) + .on(ChatEvents.ToolCallCreated, (toolCall: ToolCall) => + callbacks?.[ChatEvents.ToolCallCreated]?.({ toolCall }), + ) + .on( + ChatEvents.ToolCallDelta, + (toolCallDelta: ToolCallDelta, toolCall: ToolCall) => + callbacks?.[ChatEvents.ToolCallDelta]?.({ toolCallDelta, toolCall }), + ) + .on(ChatEvents.ToolCallDone, (toolCall: ToolCall) => + callbacks?.[ChatEvents.ToolCallDone]?.({ toolCall }), + ) + .on(ChatEvents.ImageFileDone, (content: ImageFile, message: Message) => + callbacks?.[ChatEvents.ImageFileDone]?.({ message, content }), + ) + .on(ChatEvents.RunStepCreated, (runStep: RunStep) => + callbacks?.[ChatEvents.RunStepCreated]?.({ runStep }), + ) + .on( + ChatEvents.RunStepDelta, + (runStepDelta: RunStepDelta, runStep: RunStep) => + callbacks?.[ChatEvents.RunStepDelta]?.({ runStepDelta, runStep }), + ) + .on(ChatEvents.RunStepDone, (runStep: RunStep) => + callbacks?.[ChatEvents.RunStepDone]?.({ runStep }), + ) as T; +}; diff --git a/libs/openai-assistant/src/lib/threads/threads.model.ts b/libs/openai-assistant/src/lib/threads/threads.model.ts index e8b334c..934979c 100644 --- a/libs/openai-assistant/src/lib/threads/threads.model.ts +++ b/libs/openai-assistant/src/lib/threads/threads.model.ts @@ -1,6 +1,9 @@ import { ApiProperty } from '@nestjs/swagger'; import { IsOptional } from 'class-validator'; -import { ImageFileContentBlock, TextContentBlock } from 'openai/resources/beta/threads/messages/messages'; +import { + ImageFileContentBlock, + TextContentBlock, +} from 'openai/resources/beta/threads/messages/messages'; import { Message } from 'openai/resources/beta/threads'; export class GetThreadDto { diff --git a/libs/openai-assistant/src/lib/threads/threads.service.spec.ts b/libs/openai-assistant/src/lib/threads/threads.service.spec.ts index b0e4a56..c6eb75f 100644 --- a/libs/openai-assistant/src/lib/threads/threads.service.spec.ts +++ b/libs/openai-assistant/src/lib/threads/threads.service.spec.ts @@ -1,10 +1,10 @@ import { Test } from '@nestjs/testing'; import { Thread } from 'openai/resources/beta'; -import { ThreadMessagesPage } from 'openai/resources/beta/threads'; import { APIPromise } from 'openai/core'; import { ThreadsService } from './threads.service'; import { ThreadsModule } from './threads.module'; import { AiService } from '../ai'; +import { MessagesPage } from 'openai/resources/beta/threads'; describe('ThreadsService', () => { let threadsService: ThreadsService; @@ -29,7 +29,7 @@ describe('ThreadsService', () => { .spyOn(aiService.provider.beta.threads.messages, 'list') .mockResolvedValue({ data: [{ id: 'thread-1' }], - } as unknown as ThreadMessagesPage); + } as unknown as MessagesPage); const result = await threadsService.getThread('1'); @@ -39,7 +39,7 @@ describe('ThreadsService', () => { it('should return ThreadResponse with empty list of messages when data is undefined', async () => { jest .spyOn(aiService.provider.beta.threads.messages, 'list') - .mockResolvedValue({} as unknown as ThreadMessagesPage); + .mockResolvedValue({} as unknown as MessagesPage); const result = await threadsService.getThread('1'); diff --git a/libs/openai-assistant/src/lib/threads/threads.service.ts b/libs/openai-assistant/src/lib/threads/threads.service.ts index 3cd4104..cdd80e6 100644 --- a/libs/openai-assistant/src/lib/threads/threads.service.ts +++ b/libs/openai-assistant/src/lib/threads/threads.service.ts @@ -8,9 +8,8 @@ export class ThreadsService { constructor(private readonly aiService: AiService) {} async getThread(id: string): Promise { - const messages = await this.aiService.provider.beta.threads.messages.list( - id, - ); + const messages = + await this.aiService.provider.beta.threads.messages.list(id); return { id, messages: messages?.data || [], diff --git a/nx.json b/nx.json index c2647a1..fa85a7c 100644 --- a/nx.json +++ b/nx.json @@ -3,21 +3,12 @@ "targetDefaults": { "build": { "cache": true, - "dependsOn": [ - "^build" - ], - "inputs": [ - "production", - "^production" - ] + "dependsOn": ["^build"], + "inputs": ["production", "^production"] }, "@nx/jest:jest": { "cache": true, - "inputs": [ - "default", - "^production", - "{workspaceRoot}/jest.preset.js" - ], + "inputs": ["default", "^production", "{workspaceRoot}/jest.preset.js"], "options": { "passWithNoTests": true }, @@ -39,10 +30,7 @@ } }, "namedInputs": { - "default": [ - "{projectRoot}/**/*", - "sharedGlobals" - ], + "default": ["{projectRoot}/**/*", "sharedGlobals"], "production": [ "default", "!{projectRoot}/**/?(*.)+(spec|test).[jt]s?(x)?(.snap)", diff --git a/package.json b/package.json index 252b652..878264a 100644 --- a/package.json +++ b/package.json @@ -116,4 +116,3 @@ "includedScripts": [] } } - From 92522926aece2b409a901d7cdcbeb8c578d6f2c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Musia=C5=82?= Date: Fri, 29 Mar 2024 10:57:28 +0100 Subject: [PATCH 2/3] feat(spa): displaying the bot's reply as a streaming message --- .../+chat/shared/chat-gateway.service.ts | 33 +++++++--- .../app/modules/+chat/shared/chat.model.ts | 6 +- .../app/modules/+chat/shared/chat.service.ts | 63 ++++++++++++++++--- .../environments/environment.development.ts | 1 + apps/spa/src/environments/environment.ts | 1 + 5 files changed, 86 insertions(+), 18 deletions(-) diff --git a/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts b/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts index a3b0efc..0b1beec 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat-gateway.service.ts @@ -1,7 +1,10 @@ import { Injectable } from '@angular/core'; import { ChatEvents } from './chat.model'; import io from 'socket.io-client'; -import { ChatCallDto } from '@boldare/openai-assistant'; +import { + ChatCallDto, + TextCreatedPayload, TextDeltaPayload, TextDonePayload +} from '@boldare/openai-assistant'; import { Observable } from 'rxjs'; import { environment } from '../../../../environments/environment'; @@ -9,14 +12,30 @@ import { environment } from '../../../../environments/environment'; export class ChatGatewayService { private socket = io(environment.websocketUrl); - sendMessage(payload: ChatCallDto): void { + watchEvent(event: ChatEvents): Observable { + return new Observable(observer => { + this.socket.on(event, data => observer.next(data)); + return () => this.socket.disconnect(); + }); + } + + callStart(payload: ChatCallDto): void { this.socket.emit(ChatEvents.CallStart, payload); } - getMessages(): Observable { - return new Observable(observer => { - this.socket.on(ChatEvents.CallDone, data => observer.next(data)); - return () => this.socket.disconnect(); - }); + callDone(): Observable { + return this.watchEvent(ChatEvents.CallDone); + } + + textCreated(): Observable { + return this.watchEvent(ChatEvents.TextCreated); + } + + textDelta(): Observable { + return this.watchEvent(ChatEvents.TextDelta); + } + + textDone(): Observable { + return this.watchEvent(ChatEvents.TextDone); } } diff --git a/apps/spa/src/app/modules/+chat/shared/chat.model.ts b/apps/spa/src/app/modules/+chat/shared/chat.model.ts index c691ccf..5029725 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat.model.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat.model.ts @@ -17,12 +17,12 @@ export interface ChatMessage { export enum ChatEvents { CallStart = 'callStart', CallDone = 'callDone', - TextDone = 'textCreated', - TextCreated = 'textDelta', - TextDelta = 'textDone', MessageCreated = 'messageCreated', MessageDelta = 'messageDelta', MessageDone = 'messageDone', + TextCreated = 'textCreated', + TextDelta = 'textDelta', + TextDone = 'textDone', ImageFileDone = 'imageFileDone', ToolCallCreated = 'toolCallCreated', ToolCallDelta = 'toolCallDelta', diff --git a/apps/spa/src/app/modules/+chat/shared/chat.service.ts b/apps/spa/src/app/modules/+chat/shared/chat.service.ts index 4830a7f..69616e9 100644 --- a/apps/spa/src/app/modules/+chat/shared/chat.service.ts +++ b/apps/spa/src/app/modules/+chat/shared/chat.service.ts @@ -34,11 +34,21 @@ export class ChatService { ) { document.body.classList.add('ai-chat'); + this.subscribeMessages(); this.setInitialValues(); - this.watchMessages(); this.watchVisibility(); } + subscribeMessages(): void { + if (!environment.isStreamingEnabled) { + this.watchMessages(); + } else { + this.watchTextCreated(); + this.watchTextDelta(); + this.watchTextDone(); + } + } + isMessageInvisible(message: Message): boolean { const metadata = message.metadata as Record; return metadata?.['status'] === ChatMessageStatus.Invisible; @@ -86,11 +96,13 @@ export class ChatService { refresh(): void { this.isLoading$.next(true); + this.isTyping$.next(false); this.messages$.next([]); this.threadService.start().subscribe(); } clear(): void { + this.isTyping$.next(false); this.threadService.clear(); this.messages$.next([]); } @@ -119,21 +131,56 @@ export class ChatService { const files = await this.chatFilesService.sendFiles(); this.addFileMessage(files); - this.chatGatewayService.sendMessage({ + this.chatGatewayService.callStart({ content, threadId: this.threadService.threadId$.value, file_ids: files.map(file => file.id) || [], }); } + watchTextCreated(): Subscription { + return this.chatGatewayService + .textCreated() + .subscribe((data) => { + this.isTyping$.next(false) + this.addMessage({ content: data.text.value, role: ChatRole.Assistant }) + }); + } + + watchTextDelta(): Subscription { + return this.chatGatewayService + .textDelta() + .subscribe((data) => { + const length = this.messages$.value.length; + this.messages$.value[length - 1].content = data.text.value; + }); + } + + watchTextDone(): Subscription { + return this.chatGatewayService + .textDone() + .subscribe((data) => { + this.isTyping$.next(false); + this.messages$.next([ + ...this.messages$.value.slice(0, -1), + { + content: data.text.value, + role: ChatRole.Assistant, + }, + ]); + }); + } + watchMessages(): Subscription { - return this.chatGatewayService.getMessages().subscribe(data => { - this.addMessage({ - content: data.content, - role: ChatRole.Assistant, + return this.chatGatewayService + .callDone() + .subscribe(data => { + this.addMessage({ + content: data.content, + role: ChatRole.Assistant, + }); + this.isTyping$.next(false); }); - this.isTyping$.next(false); - }); } sendAudio(file: Blob): void { diff --git a/apps/spa/src/environments/environment.development.ts b/apps/spa/src/environments/environment.development.ts index 8bd5b5c..8094e4d 100644 --- a/apps/spa/src/environments/environment.development.ts +++ b/apps/spa/src/environments/environment.development.ts @@ -10,4 +10,5 @@ export const environment = { isRefreshEnabled: true, isConfigEnabled: true, isAutoOpen: true, + isStreamingEnabled: true, }; diff --git a/apps/spa/src/environments/environment.ts b/apps/spa/src/environments/environment.ts index 980c462..1af5e17 100644 --- a/apps/spa/src/environments/environment.ts +++ b/apps/spa/src/environments/environment.ts @@ -10,4 +10,5 @@ export const environment = { isRefreshEnabled: true, isConfigEnabled: true, isAutoOpen: true, + isStreamingEnabled: true, }; From 209abbbdbb64433a0932ccb42886517cd3250567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Musia=C5=82?= Date: Fri, 29 Mar 2024 13:14:09 +0100 Subject: [PATCH 3/3] feat(openai-assistant): handled "requires_action" event --- .../src/lib/chat/chat.service.spec.ts | 10 +-- .../src/lib/chat/chat.service.ts | 30 ++++---- .../src/lib/run/run.service.spec.ts | 74 ------------------- .../src/lib/run/run.service.ts | 38 +--------- 4 files changed, 19 insertions(+), 133 deletions(-) diff --git a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts index 83012c3..e5c6c57 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts @@ -5,14 +5,12 @@ import { AiModule } from './../ai/ai.module'; import { ChatModule } from './chat.module'; import { ChatService } from './chat.service'; import { ChatHelpers } from './chat.helpers'; -import { RunService } from '../run'; import { ChatCallDto } from './chat.model'; import { AssistantStream } from 'openai/lib/AssistantStream'; describe('ChatService', () => { let chatService: ChatService; let chatbotHelpers: ChatHelpers; - let runService: RunService; beforeEach(async () => { const moduleRef = await Test.createTestingModule({ @@ -21,23 +19,19 @@ describe('ChatService', () => { chatService = moduleRef.get(ChatService); chatbotHelpers = moduleRef.get(ChatHelpers); - runService = moduleRef.get(RunService); jest .spyOn(chatbotHelpers, 'getAnswer') .mockReturnValue(Promise.resolve('Hello response') as Promise); - jest.spyOn(runService, 'resolve').mockReturnThis(); jest .spyOn(chatService.threads.messages, 'create') .mockReturnValue({} as APIPromise); jest.spyOn(chatService, 'assistantStream').mockReturnValue({ - finalRun(): Promise { - return Promise.resolve({} as Run); - }, - } as AssistantStream); + finalRun: jest.fn(), + } as unknown as Promise); }); it('should be defined', () => { diff --git a/libs/openai-assistant/src/lib/chat/chat.service.ts b/libs/openai-assistant/src/lib/chat/chat.service.ts index fc30b91..84e415e 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.ts @@ -7,7 +7,7 @@ import { ChatCallResponseDto, } from './chat.model'; import { ChatHelpers } from './chat.helpers'; -import { MessageCreateParams } from 'openai/resources/beta/threads'; +import { MessageCreateParams, Run } from 'openai/resources/beta/threads'; import { AssistantStream } from 'openai/lib/AssistantStream'; import { assistantStreamEventHandler } from '../stream/stream.utils'; @@ -36,12 +36,8 @@ export class ChatService { await this.threads.messages.create(threadId, message); - const assistantId = - payload?.assistantId || process.env['ASSISTANT_ID'] || ''; - const run = this.assistantStream(assistantId, threadId, callbacks); - const finalRun = await run.finalRun(); - - await this.runService.resolve(finalRun, true, callbacks); + const runner = await this.assistantStream(payload, callbacks); + const finalRun = await runner.finalRun(); return { content: await this.chatbotHelpers.getAnswer(finalRun), @@ -49,14 +45,20 @@ export class ChatService { }; } - assistantStream( - assistantId: string, - threadId: string, + async assistantStream( + payload: ChatCallDto, callbacks?: ChatCallCallbacks, - ): AssistantStream { - const runner = this.threads.runs.createAndStream(threadId, { - assistant_id: assistantId, - }); + ): Promise { + const assistant_id = + payload?.assistantId || process.env['ASSISTANT_ID'] || ''; + + const runner = this.threads.runs + .createAndStream(payload.threadId, { assistant_id }) + .on('event', event => { + if (event.event === 'thread.run.requires_action') { + this.runService.submitAction(event.data, callbacks); + } + }); return assistantStreamEventHandler(runner, callbacks); } diff --git a/libs/openai-assistant/src/lib/run/run.service.spec.ts b/libs/openai-assistant/src/lib/run/run.service.spec.ts index a235884..acdda6e 100644 --- a/libs/openai-assistant/src/lib/run/run.service.spec.ts +++ b/libs/openai-assistant/src/lib/run/run.service.spec.ts @@ -33,80 +33,6 @@ describe('RunService', () => { expect(runService).toBeDefined(); }); - describe('continueRun', () => { - it('should call threads.runs.retrieve', async () => { - const spyOnRetrieve = jest - .spyOn(aiService.provider.beta.threads.runs, 'retrieve') - .mockReturnThis(); - const run = { thread_id: '1', id: '123' } as Run; - - await runService.continueRun(run); - - expect(spyOnRetrieve).toHaveBeenCalled(); - }); - - it('should wait for timeout', async () => { - const run = { thread_id: '1', id: '123' } as Run; - const spyOnTimeout = jest.spyOn(global, 'setTimeout'); - - await runService.continueRun(run); - - expect(spyOnTimeout).toHaveBeenCalledWith( - expect.any(Function), - runService.timeout, - ); - }); - }); - - describe('resolve', () => { - it('should call continueRun', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'requires_action' } as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).toHaveBeenCalled(); - }); - - it('should call submitAction', async () => { - const spyOnSubmitAction = jest - .spyOn(runService, 'submitAction') - .mockResolvedValue(); - const run = { - status: 'requires_action', - required_action: { type: 'submit_tool_outputs' }, - } as Run; - - await runService.resolve(run, false); - - expect(spyOnSubmitAction).toHaveBeenCalled(); - }); - - it('should call default', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'unknown' } as unknown as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).toHaveBeenCalled(); - }); - - it('should not invoke action when status is cancelling', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'cancelling' } as unknown as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).not.toHaveBeenCalled(); - }); - }); - describe('submitAction', () => { it('should call submitToolOutputsStream', async () => { const spyOnSubmitToolOutputsStream = jest diff --git a/libs/openai-assistant/src/lib/run/run.service.ts b/libs/openai-assistant/src/lib/run/run.service.ts index 92e1466..e3c0e94 100644 --- a/libs/openai-assistant/src/lib/run/run.service.ts +++ b/libs/openai-assistant/src/lib/run/run.service.ts @@ -1,10 +1,5 @@ import { Injectable } from '@nestjs/common'; -import { - Run, - RunSubmitToolOutputsParams, - Text, - TextDelta, -} from 'openai/resources/beta/threads'; +import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads'; import { AiService } from '../ai'; import { AgentService } from '../agent'; import { ChatCallCallbacks } from '../chat'; @@ -13,43 +8,12 @@ import { assistantStreamEventHandler } from '../stream/stream.utils'; @Injectable() export class RunService { private readonly threads = this.aiService.provider.beta.threads; - timeout = 2000; - isRunning = true; constructor( private readonly aiService: AiService, private readonly agentsService: AgentService, ) {} - async continueRun(run: Run): Promise { - await new Promise(resolve => setTimeout(resolve, this.timeout)); - return this.threads.runs.retrieve(run.thread_id, run.id); - } - - async resolve( - run: Run, - runningStatus: boolean, - callbacks?: ChatCallCallbacks, - ): Promise { - while (this.isRunning) - switch (run.status) { - case 'cancelling': - case 'cancelled': - case 'failed': - case 'expired': - case 'completed': - return; - case 'requires_action': - await this.submitAction(run, callbacks); - run = await this.continueRun(run); - this.isRunning = runningStatus; - continue; - default: - run = await this.continueRun(run); - this.isRunning = runningStatus; - } - } - async submitAction(run: Run, callbacks?: ChatCallCallbacks): Promise { if (run.required_action?.type !== 'submit_tool_outputs') { return;