diff --git a/core/src/browser/extensions/conversational.ts b/core/src/browser/extensions/conversational.ts index ec53fbbbf9..49fedd5448 100644 --- a/core/src/browser/extensions/conversational.ts +++ b/core/src/browser/extensions/conversational.ts @@ -1,4 +1,10 @@ -import { Thread, ThreadInterface, ThreadMessage, MessageInterface } from '../../types' +import { + Thread, + ThreadInterface, + ThreadMessage, + MessageInterface, + ThreadAssistantInfo, +} from '../../types' import { BaseExtension, ExtensionTypeEnum } from '../extension' /** @@ -17,10 +23,21 @@ export abstract class ConversationalExtension return ExtensionTypeEnum.Conversational } - abstract getThreads(): Promise - abstract saveThread(thread: Thread): Promise + abstract listThreads(): Promise + abstract createThread(thread: Partial): Promise + abstract modifyThread(thread: Thread): Promise abstract deleteThread(threadId: string): Promise - abstract addNewMessage(message: ThreadMessage): Promise - abstract writeMessages(threadId: string, messages: ThreadMessage[]): Promise - abstract getAllMessages(threadId: string): Promise + abstract createMessage(message: Partial): Promise + abstract deleteMessage(threadId: string, messageId: string): Promise + abstract listMessages(threadId: string): Promise + abstract getThreadAssistant(threadId: string): Promise + abstract createThreadAssistant( + threadId: string, + assistant: ThreadAssistantInfo + ): Promise + abstract modifyThreadAssistant( + threadId: string, + assistant: ThreadAssistantInfo + ): Promise + abstract modifyMessage(message: ThreadMessage): Promise } diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index d0528b0abf..2d1bdb3c2f 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -2,7 +2,6 @@ import { events } from '../../events' import { BaseExtension } from '../../extension' import { MessageRequest, Model, ModelEvent } from '../../../types' import { EngineManager } from './EngineManager' -import { ModelManager } from '../../models/manager' /** * Base AIEngine diff --git a/core/src/node/api/restful/helper/builder.ts b/core/src/node/api/restful/helper/builder.ts index e081708cfe..230eb64ab0 100644 --- a/core/src/node/api/restful/helper/builder.ts +++ b/core/src/node/api/restful/helper/builder.ts @@ -6,7 +6,6 @@ import { mkdirSync, appendFileSync, createWriteStream, - rmdirSync, } from 'fs' import { JanApiRouteConfiguration, RouteConfiguration } from './configuration' import { join } from 'path' @@ -126,7 +125,7 @@ export const createThread = async (thread: any) => { } } - const threadId = generateThreadId(thread.assistants[0].assistant_id) + const threadId = generateThreadId(thread.assistants[0]?.assistant_id) try { const updatedThread = { ...thread, @@ -280,7 +279,7 @@ export const models = async (request: any, reply: any) => { 'Content-Type': 'application/json', } - const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ""}`, { + const response = await fetch(`${CORTEX_API_URL}/models${request.url.split('/models')[1] ?? ''}`, { method: request.method, headers: headers, body: JSON.stringify(request.body), diff --git a/core/src/types/assistant/assistantEntity.ts b/core/src/types/assistant/assistantEntity.ts index 27592e26b6..42617a4b5e 100644 --- a/core/src/types/assistant/assistantEntity.ts +++ b/core/src/types/assistant/assistantEntity.ts @@ -36,3 +36,10 @@ export type Assistant = { /** Represents the metadata of the object. */ metadata?: Record } + +export interface CodeInterpreterTool { + /** + * The type of tool being defined: `code_interpreter` + */ + type: 'code_interpreter' +} diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index 26bcad1a74..7c2774da6d 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -1,3 +1,4 @@ +import { CodeInterpreterTool } from '../assistant' import { ChatCompletionMessage, ChatCompletionRole } from '../inference' import { ModelInfo } from '../model' import { Thread } from '../thread' @@ -15,6 +16,10 @@ export type ThreadMessage = { thread_id: string /** The assistant id of this thread. **/ assistant_id?: string + /** + * A list of files attached to the message, and the tools they were added to. + */ + attachments?: Array | null /** The role of the author of this message. **/ role: ChatCompletionRole /** The content of this message. **/ @@ -52,6 +57,11 @@ export type MessageRequest = { */ assistantId?: string + /** + * A list of files attached to the message, and the tools they were added to. + */ + attachments: Array | null + /** Messages for constructing a chat completion request **/ messages?: ChatCompletionMessage[] @@ -97,8 +107,7 @@ export enum ErrorCode { */ export enum ContentType { Text = 'text', - Image = 'image', - Pdf = 'pdf', + Image = 'image_url', } /** @@ -108,8 +117,15 @@ export enum ContentType { export type ContentValue = { value: string annotations: string[] - name?: string - size?: number +} + +/** + * The `ImageContentValue` type defines the shape of a content value object of image type + * @data_transfer_object + */ +export type ImageContentValue = { + detail?: string + url?: string } /** @@ -118,5 +134,37 @@ export type ContentValue = { */ export type ThreadContent = { type: ContentType - text: ContentValue + text?: ContentValue + image_url?: ImageContentValue +} + +export interface Attachment { + /** + * The ID of the file to attach to the message. + */ + file_id?: string + + /** + * The tools to add this file to. + */ + tools?: Array +} + +export namespace Attachment { + export interface AssistantToolsFileSearchTypeOnly { + /** + * The type of tool being defined: `file_search` + */ + type: 'file_search' + } +} + +/** + * On an incomplete message, details about why the message is incomplete. + */ +export interface IncompleteDetails { + /** + * The reason the message is incomplete. + */ + reason: 'content_filter' | 'max_tokens' | 'run_cancelled' | 'run_expired' | 'run_failed' } diff --git a/core/src/types/message/messageInterface.ts b/core/src/types/message/messageInterface.ts index f6579da88b..1ea04298a0 100644 --- a/core/src/types/message/messageInterface.ts +++ b/core/src/types/message/messageInterface.ts @@ -11,20 +11,20 @@ export interface MessageInterface { * @param {ThreadMessage} message - The message to be added. * @returns {Promise} A promise that resolves when the message has been added. */ - addNewMessage(message: ThreadMessage): Promise - - /** - * Writes an array of messages to a specific thread. - * @param {string} threadId - The ID of the thread to write the messages to. - * @param {ThreadMessage[]} messages - The array of messages to be written. - * @returns {Promise} A promise that resolves when the messages have been written. - */ - writeMessages(threadId: string, messages: ThreadMessage[]): Promise + createMessage(message: ThreadMessage): Promise /** * Retrieves all messages from a specific thread. * @param {string} threadId - The ID of the thread to retrieve the messages from. * @returns {Promise} A promise that resolves to an array of messages from the thread. */ - getAllMessages(threadId: string): Promise + listMessages(threadId: string): Promise + + /** + * Deletes a specific message from a thread. + * @param {string} threadId - The ID of the thread from which the message will be deleted. + * @param {string} messageId - The ID of the message to be deleted. + * @returns {Promise} A promise that resolves when the message has been successfully deleted. + */ + deleteMessage(threadId: string, messageId: string): Promise } diff --git a/core/src/types/thread/threadInterface.ts b/core/src/types/thread/threadInterface.ts index 792c8c8a5f..4a78812c6a 100644 --- a/core/src/types/thread/threadInterface.ts +++ b/core/src/types/thread/threadInterface.ts @@ -11,15 +11,23 @@ export interface ThreadInterface { * @abstract * @returns {Promise} A promise that resolves to an array of threads. */ - getThreads(): Promise + listThreads(): Promise /** - * Saves a thread. + * Create a thread. * @abstract * @param {Thread} thread - The thread to save. * @returns {Promise} A promise that resolves when the thread is saved. */ - saveThread(thread: Thread): Promise + createThread(thread: Thread): Promise + + /** + * modify a thread. + * @abstract + * @param {Thread} thread - The thread to save. + * @returns {Promise} A promise that resolves when the thread is saved. + */ + modifyThread(thread: Thread): Promise /** * Deletes a thread. diff --git a/electron/tests/config/fixtures.ts b/electron/tests/config/fixtures.ts index bc3f8a7d13..f61eddfaed 100644 --- a/electron/tests/config/fixtures.ts +++ b/electron/tests/config/fixtures.ts @@ -108,7 +108,7 @@ export const test = base.extend< }) test.beforeAll(async () => { - await rmSync(path.join(__dirname, '../../test-data'), { + rmSync(path.join(__dirname, '../../test-data'), { recursive: true, force: true, }) @@ -122,6 +122,5 @@ test.beforeAll(async () => { }) test.afterAll(async () => { - // temporally disabling this due to the config for parallel testing WIP // teardownElectron() }) diff --git a/electron/tests/e2e/navigation.e2e.spec.ts b/electron/tests/e2e/navigation.e2e.spec.ts index b599a951c1..1b463d3813 100644 --- a/electron/tests/e2e/navigation.e2e.spec.ts +++ b/electron/tests/e2e/navigation.e2e.spec.ts @@ -2,11 +2,8 @@ import { expect } from '@playwright/test' import { page, test, TIMEOUT } from '../config/fixtures' test('renders left navigation panel', async () => { - const settingsBtn = await page - .getByTestId('Thread') - .first() - .isEnabled({ timeout: TIMEOUT }) - expect([settingsBtn].filter((e) => !e).length).toBe(0) + const threadBtn = page.getByTestId('Thread').first() + await expect(threadBtn).toBeVisible({ timeout: TIMEOUT }) // Chat section should be there await page.getByTestId('Local API Server').first().click({ timeout: TIMEOUT, diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 6705483d64..0b3a1ec403 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -141,7 +141,7 @@ export default class JanAssistantExtension extends AssistantExtension { top_k: 2, chunk_size: 1024, chunk_overlap: 64, - retrieval_template: `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + retrieval_template: `Use the following pieces of context to answer the question at the end. ---------------- CONTEXT: {CONTEXT} ---------------- diff --git a/extensions/assistant-extension/src/node/index.ts b/extensions/assistant-extension/src/node/index.ts index 83a4a19831..11e8f49c4f 100644 --- a/extensions/assistant-extension/src/node/index.ts +++ b/extensions/assistant-extension/src/node/index.ts @@ -9,13 +9,14 @@ export function toolRetrievalUpdateTextSplitter( retrieval.updateTextSplitter(chunkSize, chunkOverlap) } export async function toolRetrievalIngestNewDocument( + thread: string, file: string, model: string, engine: string, useTimeWeighted: boolean ) { - const filePath = path.join(getJanDataFolderPath(), normalizeFilePath(file)) - const threadPath = path.dirname(filePath.replace('files', '')) + const threadPath = path.join(getJanDataFolderPath(), 'threads', thread) + const filePath = path.join(getJanDataFolderPath(), 'files', file) retrieval.updateEmbeddingEngine(model, engine) return retrieval .ingestAgentKnowledge(filePath, `${threadPath}/memory`, useTimeWeighted) diff --git a/extensions/assistant-extension/src/tools/retrieval.ts b/extensions/assistant-extension/src/tools/retrieval.ts index 7631922871..b1a0c3cba0 100644 --- a/extensions/assistant-extension/src/tools/retrieval.ts +++ b/extensions/assistant-extension/src/tools/retrieval.ts @@ -35,6 +35,7 @@ export class RetrievalTool extends InferenceTool { await executeOnMain( NODE, 'toolRetrievalIngestNewDocument', + data.thread?.id, docFile, data.model?.id, data.model?.engine, diff --git a/extensions/conversational-extension/package.json b/extensions/conversational-extension/package.json index 036fcfab25..ea30064490 100644 --- a/extensions/conversational-extension/package.json +++ b/extensions/conversational-extension/package.json @@ -18,12 +18,14 @@ "devDependencies": { "cpx": "^1.5.0", "rimraf": "^3.0.2", + "ts-loader": "^9.5.0", "webpack": "^5.88.2", - "webpack-cli": "^5.1.4", - "ts-loader": "^9.5.0" + "webpack-cli": "^5.1.4" }, "dependencies": { - "@janhq/core": "file:../../core" + "@janhq/core": "file:../../core", + "ky": "^1.7.2", + "p-queue": "^8.0.1" }, "engines": { "node": ">=18.0.0" diff --git a/extensions/conversational-extension/src/@types/global.d.ts b/extensions/conversational-extension/src/@types/global.d.ts new file mode 100644 index 0000000000..757b5eebf3 --- /dev/null +++ b/extensions/conversational-extension/src/@types/global.d.ts @@ -0,0 +1,14 @@ +export {} +declare global { + declare const API_URL: string + declare const SOCKET_URL: string + + interface Core { + api: APIFunctions + events: EventEmitter + } + interface Window { + core?: Core | undefined + electronAPI?: any | undefined + } +} diff --git a/extensions/conversational-extension/src/Conversational.test.ts b/extensions/conversational-extension/src/Conversational.test.ts deleted file mode 100644 index 3d1d6fc607..0000000000 --- a/extensions/conversational-extension/src/Conversational.test.ts +++ /dev/null @@ -1,408 +0,0 @@ -/** - * @jest-environment jsdom - */ -jest.mock('@janhq/core', () => ({ - ...jest.requireActual('@janhq/core/node'), - fs: { - existsSync: jest.fn(), - mkdir: jest.fn(), - writeFileSync: jest.fn(), - readdirSync: jest.fn(), - readFileSync: jest.fn(), - appendFileSync: jest.fn(), - rm: jest.fn(), - writeBlob: jest.fn(), - joinPath: jest.fn(), - fileStat: jest.fn(), - }, - joinPath: jest.fn(), - ConversationalExtension: jest.fn(), -})) - -import { fs } from '@janhq/core' - -import JSONConversationalExtension from '.' - -describe('JSONConversationalExtension Tests', () => { - let extension: JSONConversationalExtension - - beforeEach(() => { - // @ts-ignore - extension = new JSONConversationalExtension() - }) - - it('should create thread folder on load if it does not exist', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - - await extension.onLoad() - - expect(mkdirSpy).toHaveBeenCalledWith('file://threads') - }) - - it('should log message on unload', () => { - const consoleSpy = jest.spyOn(console, 'debug').mockImplementation() - - extension.onUnload() - - expect(consoleSpy).toHaveBeenCalledWith( - 'JSONConversationalExtension unloaded' - ) - }) - - it('should return sorted threads', async () => { - jest - .spyOn(extension, 'getValidThreadDirs') - .mockResolvedValue(['dir1', 'dir2']) - jest - .spyOn(extension, 'readThread') - .mockResolvedValueOnce({ updated: '2023-01-01' }) - .mockResolvedValueOnce({ updated: '2023-01-02' }) - - const threads = await extension.getThreads() - - expect(threads).toEqual([ - { updated: '2023-01-02' }, - { updated: '2023-01-01' }, - ]) - }) - - it('should ignore broken threads', async () => { - jest - .spyOn(extension, 'getValidThreadDirs') - .mockResolvedValue(['dir1', 'dir2']) - jest - .spyOn(extension, 'readThread') - .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' })) - .mockResolvedValueOnce('this_is_an_invalid_json_content') - - const threads = await extension.getThreads() - - expect(threads).toEqual([{ updated: '2023-01-01' }]) - }) - - it('should save thread', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const writeFileSyncSpy = jest - .spyOn(fs, 'writeFileSync') - .mockResolvedValue({}) - - const thread = { id: '1', updated: '2023-01-01' } as any - await extension.saveThread(thread) - - expect(mkdirSpy).toHaveBeenCalled() - expect(writeFileSyncSpy).toHaveBeenCalled() - }) - - it('should delete thread', async () => { - const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({}) - - await extension.deleteThread('1') - - expect(rmSpy).toHaveBeenCalled() - }) - - it('should add new message', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const appendFileSyncSpy = jest - .spyOn(fs, 'appendFileSync') - .mockResolvedValue({}) - - const message = { - thread_id: '1', - content: [{ type: 'text', text: { annotations: [] } }], - } as any - await extension.addNewMessage(message) - - expect(mkdirSpy).toHaveBeenCalled() - expect(appendFileSyncSpy).toHaveBeenCalled() - }) - - it('should store image', async () => { - const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - await extension.storeImage( - 'data:image/png;base64,abcd', - 'path/to/image.png' - ) - - expect(writeBlobSpy).toHaveBeenCalled() - }) - - it('should store file', async () => { - const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - await extension.storeFile( - 'data:application/pdf;base64,abcd', - 'path/to/file.pdf' - ) - - expect(writeBlobSpy).toHaveBeenCalled() - }) - - it('should write messages', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const writeFileSyncSpy = jest - .spyOn(fs, 'writeFileSync') - .mockResolvedValue({}) - - const messages = [{ id: '1', thread_id: '1', content: [] }] as any - await extension.writeMessages('1', messages) - - expect(mkdirSpy).toHaveBeenCalled() - expect(writeFileSyncSpy).toHaveBeenCalled() - }) - - it('should get all messages on string response', async () => { - jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) - jest.spyOn(fs, 'readFileSync').mockResolvedValue('{"id":"1"}\n{"id":"2"}\n') - - const messages = await extension.getAllMessages('1') - - expect(messages).toEqual([{ id: '1' }, { id: '2' }]) - }) - - it('should get all messages on object response', async () => { - jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) - jest.spyOn(fs, 'readFileSync').mockResolvedValue({ id: 1 }) - - const messages = await extension.getAllMessages('1') - - expect(messages).toEqual([{ id: 1 }]) - }) - - it('get all messages return empty on error', async () => { - jest.spyOn(fs, 'readdirSync').mockRejectedValue(['messages.jsonl']) - - const messages = await extension.getAllMessages('1') - - expect(messages).toEqual([]) - }) - - it('return empty messages on no messages file', async () => { - jest.spyOn(fs, 'readdirSync').mockResolvedValue([]) - - const messages = await extension.getAllMessages('1') - - expect(messages).toEqual([]) - }) - - it('should ignore error message', async () => { - jest.spyOn(fs, 'readdirSync').mockResolvedValue(['messages.jsonl']) - jest - .spyOn(fs, 'readFileSync') - .mockResolvedValue('{"id":"1"}\nyolo\n{"id":"2"}\n') - - const messages = await extension.getAllMessages('1') - - expect(messages).toEqual([{ id: '1' }, { id: '2' }]) - }) - - it('should create thread folder on load if it does not exist', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - - await extension.onLoad() - - expect(mkdirSpy).toHaveBeenCalledWith('file://threads') - }) - - it('should log message on unload', () => { - const consoleSpy = jest.spyOn(console, 'debug').mockImplementation() - - extension.onUnload() - - expect(consoleSpy).toHaveBeenCalledWith( - 'JSONConversationalExtension unloaded' - ) - }) - - it('should return sorted threads', async () => { - jest - .spyOn(extension, 'getValidThreadDirs') - .mockResolvedValue(['dir1', 'dir2']) - jest - .spyOn(extension, 'readThread') - .mockResolvedValueOnce({ updated: '2023-01-01' }) - .mockResolvedValueOnce({ updated: '2023-01-02' }) - - const threads = await extension.getThreads() - - expect(threads).toEqual([ - { updated: '2023-01-02' }, - { updated: '2023-01-01' }, - ]) - }) - - it('should ignore broken threads', async () => { - jest - .spyOn(extension, 'getValidThreadDirs') - .mockResolvedValue(['dir1', 'dir2']) - jest - .spyOn(extension, 'readThread') - .mockResolvedValueOnce(JSON.stringify({ updated: '2023-01-01' })) - .mockResolvedValueOnce('this_is_an_invalid_json_content') - - const threads = await extension.getThreads() - - expect(threads).toEqual([{ updated: '2023-01-01' }]) - }) - - it('should save thread', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const writeFileSyncSpy = jest - .spyOn(fs, 'writeFileSync') - .mockResolvedValue({}) - - const thread = { id: '1', updated: '2023-01-01' } as any - await extension.saveThread(thread) - - expect(mkdirSpy).toHaveBeenCalled() - expect(writeFileSyncSpy).toHaveBeenCalled() - }) - - it('should delete thread', async () => { - const rmSpy = jest.spyOn(fs, 'rm').mockResolvedValue({}) - - await extension.deleteThread('1') - - expect(rmSpy).toHaveBeenCalled() - }) - - it('should add new message', async () => { - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(false) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const appendFileSyncSpy = jest - .spyOn(fs, 'appendFileSync') - .mockResolvedValue({}) - - const message = { - thread_id: '1', - content: [{ type: 'text', text: { annotations: [] } }], - } as any - await extension.addNewMessage(message) - - expect(mkdirSpy).toHaveBeenCalled() - expect(appendFileSyncSpy).toHaveBeenCalled() - }) - - it('should add new image message', async () => { - jest - .spyOn(fs, 'existsSync') - // @ts-ignore - .mockResolvedValueOnce(false) - // @ts-ignore - .mockResolvedValueOnce(false) - // @ts-ignore - .mockResolvedValueOnce(true) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const appendFileSyncSpy = jest - .spyOn(fs, 'appendFileSync') - .mockResolvedValue({}) - jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - const message = { - thread_id: '1', - content: [ - { type: 'image', text: { annotations: ['data:image;base64,hehe'] } }, - ], - } as any - await extension.addNewMessage(message) - - expect(mkdirSpy).toHaveBeenCalled() - expect(appendFileSyncSpy).toHaveBeenCalled() - }) - - it('should add new pdf message', async () => { - jest - .spyOn(fs, 'existsSync') - // @ts-ignore - .mockResolvedValueOnce(false) - // @ts-ignore - .mockResolvedValueOnce(false) - // @ts-ignore - .mockResolvedValueOnce(true) - const mkdirSpy = jest.spyOn(fs, 'mkdir').mockResolvedValue({}) - const appendFileSyncSpy = jest - .spyOn(fs, 'appendFileSync') - .mockResolvedValue({}) - jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - const message = { - thread_id: '1', - content: [ - { type: 'pdf', text: { annotations: ['data:pdf;base64,hehe'] } }, - ], - } as any - await extension.addNewMessage(message) - - expect(mkdirSpy).toHaveBeenCalled() - expect(appendFileSyncSpy).toHaveBeenCalled() - }) - - it('should store image', async () => { - const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - await extension.storeImage( - 'data:image/png;base64,abcd', - 'path/to/image.png' - ) - - expect(writeBlobSpy).toHaveBeenCalled() - }) - - it('should store file', async () => { - const writeBlobSpy = jest.spyOn(fs, 'writeBlob').mockResolvedValue({}) - - await extension.storeFile( - 'data:application/pdf;base64,abcd', - 'path/to/file.pdf' - ) - - expect(writeBlobSpy).toHaveBeenCalled() - }) -}) - -describe('test readThread', () => { - let extension: JSONConversationalExtension - - beforeEach(() => { - // @ts-ignore - extension = new JSONConversationalExtension() - }) - - it('should read thread', async () => { - jest - .spyOn(fs, 'readFileSync') - .mockResolvedValue(JSON.stringify({ id: '1' })) - const thread = await extension.readThread('1') - expect(thread).toEqual(`{"id":"1"}`) - }) - - it('getValidThreadDirs should return valid thread directories', async () => { - jest - .spyOn(fs, 'readdirSync') - .mockResolvedValueOnce(['1', '2', '3']) - .mockResolvedValueOnce(['thread.json']) - .mockResolvedValueOnce(['thread.json']) - .mockResolvedValueOnce([]) - // @ts-ignore - jest.spyOn(fs, 'existsSync').mockResolvedValue(true) - jest.spyOn(fs, 'fileStat').mockResolvedValue({ - isDirectory: true, - } as any) - const validThreadDirs = await extension.getValidThreadDirs() - expect(validThreadDirs).toEqual(['1', '2']) - }) -}) diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index b34f09181d..81d9d2023a 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -1,90 +1,71 @@ import { - fs, - joinPath, ConversationalExtension, Thread, + ThreadAssistantInfo, ThreadMessage, } from '@janhq/core' -import { safelyParseJSON } from './jsonUtil' +import ky from 'ky' +import PQueue from 'p-queue' + +type ThreadList = { + data: Thread[] +} + +type MessageList = { + data: ThreadMessage[] +} /** * JSONConversationalExtension is a ConversationalExtension implementation that provides * functionality for managing threads. */ export default class JSONConversationalExtension extends ConversationalExtension { - private static readonly _threadFolder = 'file://threads' - private static readonly _threadInfoFileName = 'thread.json' - private static readonly _threadMessagesFileName = 'messages.jsonl' + queue = new PQueue({ concurrency: 1 }) /** * Called when the extension is loaded. */ async onLoad() { - if (!(await fs.existsSync(JSONConversationalExtension._threadFolder))) { - await fs.mkdir(JSONConversationalExtension._threadFolder) - } + this.queue.add(() => this.healthz()) } /** * Called when the extension is unloaded. */ - onUnload() { - console.debug('JSONConversationalExtension unloaded') - } + onUnload() {} /** * Returns a Promise that resolves to an array of Conversation objects. */ - async getThreads(): Promise { - try { - const threadDirs = await this.getValidThreadDirs() - - const promises = threadDirs.map((dirName) => this.readThread(dirName)) - const promiseResults = await Promise.allSettled(promises) - const convos = promiseResults - .map((result) => { - if (result.status === 'fulfilled') { - return typeof result.value === 'object' - ? result.value - : safelyParseJSON(result.value) - } - return undefined - }) - .filter((convo) => !!convo) - convos.sort( - (a, b) => new Date(b.updated).getTime() - new Date(a.updated).getTime() - ) - - return convos - } catch (error) { - console.error(error) - return [] - } + async listThreads(): Promise { + return this.queue.add(() => + ky + .get(`${API_URL}/v1/threads`) + .json() + .then((e) => e.data) + ) as Promise } /** * Saves a Thread object to a json file. * @param thread The Thread object to save. */ - async saveThread(thread: Thread): Promise { - try { - const threadDirPath = await joinPath([ - JSONConversationalExtension._threadFolder, - thread.id, - ]) - const threadJsonPath = await joinPath([ - threadDirPath, - JSONConversationalExtension._threadInfoFileName, - ]) - if (!(await fs.existsSync(threadDirPath))) { - await fs.mkdir(threadDirPath) - } + async createThread(thread: Thread): Promise { + return this.queue.add(() => + ky.post(`${API_URL}/v1/threads`, { json: thread }).json() + ) as Promise + } - await fs.writeFileSync(threadJsonPath, JSON.stringify(thread, null, 2)) - } catch (err) { - console.error(err) - Promise.reject(err) - } + /** + * Saves a Thread object to a json file. + * @param thread The Thread object to save. + */ + async modifyThread(thread: Thread): Promise { + return this.queue + .add(() => + ky.post(`${API_URL}/v1/threads/${thread.id}`, { json: thread }) + ) + .then() } /** @@ -92,189 +73,126 @@ export default class JSONConversationalExtension extends ConversationalExtension * @param threadId The ID of the thread to delete. */ async deleteThread(threadId: string): Promise { - const path = await joinPath([ - JSONConversationalExtension._threadFolder, - `${threadId}`, - ]) - try { - await fs.rm(path) - } catch (err) { - console.error(err) - } + return this.queue + .add(() => ky.delete(`${API_URL}/v1/threads/${threadId}`)) + .then() } - async addNewMessage(message: ThreadMessage): Promise { - try { - const threadDirPath = await joinPath([ - JSONConversationalExtension._threadFolder, - message.thread_id, - ]) - const threadMessagePath = await joinPath([ - threadDirPath, - JSONConversationalExtension._threadMessagesFileName, - ]) - if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath) - - if (message.content[0]?.type === 'image') { - const filesPath = await joinPath([threadDirPath, 'files']) - if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath) - - const imagePath = await joinPath([filesPath, `${message.id}.png`]) - const base64 = message.content[0].text.annotations[0] - await this.storeImage(base64, imagePath) - if ((await fs.existsSync(imagePath)) && message.content?.length) { - // Use file path instead of blob - message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.png` - } - } - - if (message.content[0]?.type === 'pdf') { - const filesPath = await joinPath([threadDirPath, 'files']) - if (!(await fs.existsSync(filesPath))) await fs.mkdir(filesPath) - - const filePath = await joinPath([filesPath, `${message.id}.pdf`]) - const blob = message.content[0].text.annotations[0] - await this.storeFile(blob, filePath) - - if ((await fs.existsSync(filePath)) && message.content?.length) { - // Use file path instead of blob - message.content[0].text.annotations[0] = `threads/${message.thread_id}/files/${message.id}.pdf` - } - } - await fs.appendFileSync(threadMessagePath, JSON.stringify(message) + '\n') - Promise.resolve() - } catch (err) { - Promise.reject(err) - } - } - - async storeImage(base64: string, filePath: string): Promise { - const base64Data = base64.replace(/^data:image\/\w+;base64,/, '') - - try { - await fs.writeBlob(filePath, base64Data) - } catch (err) { - console.error(err) - } + /** + * Adds a new message to a specified thread. + * @param message The ThreadMessage object to be added. + * @returns A Promise that resolves when the message has been added. + */ + async createMessage(message: ThreadMessage): Promise { + return this.queue.add(() => + ky + .post(`${API_URL}/v1/threads/${message.thread_id}/messages`, { + json: message, + }) + .json() + ) as Promise } - async storeFile(base64: string, filePath: string): Promise { - const base64Data = base64.replace(/^data:application\/pdf;base64,/, '') - try { - await fs.writeBlob(filePath, base64Data) - } catch (err) { - console.error(err) - } + /** + * Modifies a message in a thread. + * @param message + * @returns + */ + async modifyMessage(message: ThreadMessage): Promise { + return this.queue.add(() => + ky + .post( + `${API_URL}/v1/threads/${message.thread_id}/messages/${message.id}`, + { + json: message, + } + ) + .json() + ) as Promise } - async writeMessages( - threadId: string, - messages: ThreadMessage[] - ): Promise { - try { - const threadDirPath = await joinPath([ - JSONConversationalExtension._threadFolder, - threadId, - ]) - const threadMessagePath = await joinPath([ - threadDirPath, - JSONConversationalExtension._threadMessagesFileName, - ]) - if (!(await fs.existsSync(threadDirPath))) await fs.mkdir(threadDirPath) - await fs.writeFileSync( - threadMessagePath, - messages.map((msg) => JSON.stringify(msg)).join('\n') + - (messages.length ? '\n' : '') + /** + * Deletes a specific message from a thread. + * @param threadId The ID of the thread containing the message. + * @param messageId The ID of the message to be deleted. + * @returns A Promise that resolves when the message has been successfully deleted. + */ + async deleteMessage(threadId: string, messageId: string): Promise { + return this.queue + .add(() => + ky.delete(`${API_URL}/v1/threads/${threadId}/messages/${messageId}`) ) - Promise.resolve() - } catch (err) { - Promise.reject(err) - } + .then() } /** - * A promise builder for reading a thread from a file. - * @param threadDirName the thread dir we are reading from. - * @returns data of the thread + * Retrieves all messages for a specified thread. + * @param threadId The ID of the thread to get messages from. + * @returns A Promise that resolves to an array of ThreadMessage objects. */ - async readThread(threadDirName: string): Promise { - return fs.readFileSync( - await joinPath([ - JSONConversationalExtension._threadFolder, - threadDirName, - JSONConversationalExtension._threadInfoFileName, - ]), - 'utf-8' - ) + async listMessages(threadId: string): Promise { + return this.queue.add(() => + ky + .get(`${API_URL}/v1/threads/${threadId}/messages?order=asc`) + .json() + .then((e) => e.data) + ) as Promise } /** - * Returns a Promise that resolves to an array of thread directories. - * @private + * Retrieves the assistant information for a specified thread. + * @param threadId The ID of the thread for which to retrieve assistant information. + * @returns A Promise that resolves to a ThreadAssistantInfo object containing + * the details of the assistant associated with the specified thread. */ - async getValidThreadDirs(): Promise { - const fileInsideThread: string[] = await fs.readdirSync( - JSONConversationalExtension._threadFolder - ) - - const threadDirs: string[] = [] - for (let i = 0; i < fileInsideThread.length; i++) { - const path = await joinPath([ - JSONConversationalExtension._threadFolder, - fileInsideThread[i], - ]) - if (!(await fs.fileStat(path))?.isDirectory) continue - - const isHavingThreadInfo = (await fs.readdirSync(path)).includes( - JSONConversationalExtension._threadInfoFileName - ) - if (!isHavingThreadInfo) { - console.debug(`Ignore ${path} because it does not have thread info`) - continue - } - - threadDirs.push(fileInsideThread[i]) - } - return threadDirs + async getThreadAssistant(threadId: string): Promise { + return this.queue.add(() => + ky.get(`${API_URL}/v1/assistants/${threadId}`).json() + ) as Promise + } + /** + * Creates a new assistant for the specified thread. + * @param threadId The ID of the thread for which the assistant is being created. + * @param assistant The information about the assistant to be created. + * @returns A Promise that resolves to the newly created ThreadAssistantInfo object. + */ + async createThreadAssistant( + threadId: string, + assistant: ThreadAssistantInfo + ): Promise { + return this.queue.add(() => + ky + .post(`${API_URL}/v1/assistants/${threadId}`, { json: assistant }) + .json() + ) as Promise } - async getAllMessages(threadId: string): Promise { - try { - const threadDirPath = await joinPath([ - JSONConversationalExtension._threadFolder, - threadId, - ]) - - const files: string[] = await fs.readdirSync(threadDirPath) - if ( - !files.includes(JSONConversationalExtension._threadMessagesFileName) - ) { - console.debug(`${threadDirPath} not contains message file`) - return [] - } - - const messageFilePath = await joinPath([ - threadDirPath, - JSONConversationalExtension._threadMessagesFileName, - ]) - - let readResult = await fs.readFileSync(messageFilePath, 'utf-8') - - if (typeof readResult === 'object') { - readResult = JSON.stringify(readResult) - } - - const result = readResult.split('\n').filter((line) => line !== '') + /** + * Modifies an existing assistant for the specified thread. + * @param threadId The ID of the thread for which the assistant is being modified. + * @param assistant The updated information for the assistant. + * @returns A Promise that resolves to the updated ThreadAssistantInfo object. + */ + async modifyThreadAssistant( + threadId: string, + assistant: ThreadAssistantInfo + ): Promise { + return this.queue.add(() => + ky + .patch(`${API_URL}/v1/assistants/${threadId}`, { json: assistant }) + .json() + ) as Promise + } - const messages: ThreadMessage[] = [] - result.forEach((line: string) => { - const message = safelyParseJSON(line) - if (message) messages.push(safelyParseJSON(line)) + /** + * Do health check on cortex.cpp + * @returns + */ + healthz(): Promise { + return ky + .get(`${API_URL}/healthz`, { + retry: { limit: 20, delay: () => 500, methods: ['get'] }, }) - return messages - } catch (err) { - console.error(err) - return [] - } + .then(() => {}) } } diff --git a/extensions/conversational-extension/src/jsonUtil.ts b/extensions/conversational-extension/src/jsonUtil.ts deleted file mode 100644 index 7f83cadce5..0000000000 --- a/extensions/conversational-extension/src/jsonUtil.ts +++ /dev/null @@ -1,14 +0,0 @@ -// Note about performance -// The v8 JavaScript engine used by Node.js cannot optimise functions which contain a try/catch block. -// v8 4.5 and above can optimise try/catch -export function safelyParseJSON(json) { - // This function cannot be optimised, it's best to - // keep it small! - var parsed - try { - parsed = JSON.parse(json) - } catch (e) { - return undefined - } - return parsed // Could be undefined! -} diff --git a/extensions/conversational-extension/webpack.config.js b/extensions/conversational-extension/webpack.config.js index e4a0b2179e..0448af4212 100644 --- a/extensions/conversational-extension/webpack.config.js +++ b/extensions/conversational-extension/webpack.config.js @@ -17,7 +17,12 @@ module.exports = { filename: 'index.js', // Adjust the output file name as needed library: { type: 'module' }, // Specify ESM output format }, - plugins: [new webpack.DefinePlugin({})], + plugins: [ + new webpack.DefinePlugin({ + API_URL: JSON.stringify('http://127.0.0.1:39291'), + SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'), + }), + ], resolve: { extensions: ['.ts', '.js'], }, diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt index a6a3a43c3a..25837a1836 100644 --- a/extensions/inference-cortex-extension/bin/version.txt +++ b/extensions/inference-cortex-extension/bin/version.txt @@ -1 +1 @@ -1.0.4 \ No newline at end of file +1.0.5-rc1 diff --git a/extensions/inference-cortex-extension/download.bat b/extensions/inference-cortex-extension/download.bat index 7d9a9213ae..0e7eef20e5 100644 --- a/extensions/inference-cortex-extension/download.bat +++ b/extensions/inference-cortex-extension/download.bat @@ -2,7 +2,7 @@ set BIN_PATH=./bin set SHARED_PATH=./../../electron/shared set /p CORTEX_VERSION=<./bin/version.txt -set ENGINE_VERSION=0.1.40 +set ENGINE_VERSION=0.1.42 @REM Download cortex.llamacpp binaries set DOWNLOAD_URL=https://github.com/janhq/cortex.llamacpp/releases/download/v%ENGINE_VERSION%/cortex.llamacpp-%ENGINE_VERSION%-windows-amd64 @@ -38,4 +38,4 @@ for %%F in (%SUBFOLDERS%) do ( ) ) -echo DLL files moved successfully. \ No newline at end of file +echo DLL files moved successfully. diff --git a/extensions/inference-cortex-extension/download.sh b/extensions/inference-cortex-extension/download.sh index f62e5961b6..b0f3b36e35 100755 --- a/extensions/inference-cortex-extension/download.sh +++ b/extensions/inference-cortex-extension/download.sh @@ -2,7 +2,7 @@ # Read CORTEX_VERSION CORTEX_VERSION=$(cat ./bin/version.txt) -ENGINE_VERSION=0.1.40 +ENGINE_VERSION=0.1.42 CORTEX_RELEASE_URL="https://github.com/janhq/cortex.cpp/releases/download" ENGINE_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}/cortex.llamacpp-${ENGINE_VERSION}" CUDA_DOWNLOAD_URL="https://github.com/janhq/cortex.llamacpp/releases/download/v${ENGINE_VERSION}" diff --git a/extensions/inference-cortex-extension/rollup.config.ts b/extensions/inference-cortex-extension/rollup.config.ts index 8fa61e91d8..266281a756 100644 --- a/extensions/inference-cortex-extension/rollup.config.ts +++ b/extensions/inference-cortex-extension/rollup.config.ts @@ -120,7 +120,7 @@ export default [ SETTINGS: JSON.stringify(defaultSettingJson), CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'), CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'), - CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.40'), + CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.42'), }), // Allow json resolution json(), diff --git a/web/containers/ErrorMessage/index.tsx b/web/containers/ErrorMessage/index.tsx index 96ced0ac53..95b87fc53b 100644 --- a/web/containers/ErrorMessage/index.tsx +++ b/web/containers/ErrorMessage/index.tsx @@ -18,14 +18,14 @@ import { isLocalEngine } from '@/utils/modelEngine' import { mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' const ErrorMessage = ({ message }: { message: ThreadMessage }) => { const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) const setMainState = useSetAtom(mainViewStateAtom) const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) - const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const defaultDesc = () => { return ( @@ -46,7 +46,7 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => { } const getEngine = () => { - const engineName = activeThread?.assistants?.[0]?.model?.engine + const engineName = activeAssistant?.model?.engine return engineName ? EngineManager.instance().get(engineName) : null } @@ -89,7 +89,9 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => { ) : ( <> - + {message?.content[0]?.text?.value && ( + + )} {defaultDesc()} )} diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index f6adf090bc..6244162bb8 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -46,6 +46,7 @@ import { import { extensionManager } from '@/extension' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { configuredModelsAtom, @@ -75,6 +76,7 @@ const ModelDropdown = ({ const [searchText, setSearchText] = useState('') const [open, setOpen] = useState(false) const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const downloadingModels = useAtomValue(getDownloadingModelAtom) const [toggle, setToggle] = useState(null) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) @@ -151,17 +153,24 @@ const ModelDropdown = ({ useEffect(() => { if (!activeThread) return - const modelId = activeThread?.assistants?.[0]?.model?.id + const modelId = activeAssistant?.model?.id let model = downloadedModels.find((model) => model.id === modelId) if (!model) { model = recommendedModel } setSelectedModel(model) - }, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) + }, [ + recommendedModel, + activeThread, + downloadedModels, + setSelectedModel, + activeAssistant?.model?.id, + ]) const onClickModelItem = useCallback( async (modelId: string) => { + if (!activeAssistant) return const model = downloadedModels.find((m) => m.id === modelId) setSelectedModel(model) setOpen(false) @@ -172,14 +181,14 @@ const ModelDropdown = ({ ...activeThread, assistants: [ { - ...activeThread.assistants[0], + ...activeAssistant, tools: [ { type: 'retrieval', enabled: isModelSupportRagAndTools(model as Model), settings: { - ...(activeThread.assistants[0].tools && - activeThread.assistants[0].tools[0]?.settings), + ...(activeAssistant.tools && + activeAssistant.tools[0]?.settings), }, }, ], @@ -219,13 +228,14 @@ const ModelDropdown = ({ } }, [ + activeAssistant, downloadedModels, - activeThread, setSelectedModel, + activeThread, + updateThreadMetadata, isModelSupportRagAndTools, setThreadModelParams, updateModelParameter, - updateThreadMetadata, ] ) diff --git a/web/containers/Providers/Jotai.tsx b/web/containers/Providers/Jotai.tsx index 8f1433ea0c..5371097f48 100644 --- a/web/containers/Providers/Jotai.tsx +++ b/web/containers/Providers/Jotai.tsx @@ -8,7 +8,7 @@ import { FileInfo } from '@/types/file' export const editPromptAtom = atom('') export const currentPromptAtom = atom('') -export const fileUploadAtom = atom([]) +export const fileUploadAtom = atom() export const searchAtom = atom('') diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index 373c0aebd6..d72db4e8e0 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -31,6 +31,7 @@ import { addNewMessageAtom, updateMessageAtom, tokenSpeedAtom, + deleteMessageAtom, } from '@/helpers/atoms/ChatMessage.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { @@ -49,6 +50,7 @@ export default function ModelHandler() { const addNewMessage = useSetAtom(addNewMessageAtom) const updateMessage = useSetAtom(updateMessageAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) + const deleteMessage = useSetAtom(deleteMessageAtom) const activeModel = useAtomValue(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom) const setStateModel = useSetAtom(stateModelAtom) @@ -86,7 +88,7 @@ export default function ModelHandler() { }, [activeModelParams]) const onNewMessageResponse = useCallback( - (message: ThreadMessage) => { + async (message: ThreadMessage) => { if (message.type === MessageRequestType.Thread) { addNewMessage(message) } @@ -154,12 +156,15 @@ export default function ModelHandler() { ...thread, title: cleanedMessageContent, - metadata: thread.metadata, + metadata: { + ...thread.metadata, + title: cleanedMessageContent, + }, } extensionManager .get(ExtensionTypeEnum.Conversational) - ?.saveThread({ + ?.modifyThread({ ...updatedThread, }) .then(() => { @@ -233,7 +238,9 @@ export default function ModelHandler() { const thread = threadsRef.current?.find((e) => e.id == message.thread_id) if (!thread) return + const messageContent = message.content[0]?.text?.value + const metadata = { ...thread.metadata, ...(messageContent && { lastMessage: messageContent }), @@ -246,15 +253,19 @@ export default function ModelHandler() { extensionManager .get(ExtensionTypeEnum.Conversational) - ?.saveThread({ + ?.modifyThread({ ...thread, metadata, }) - - // If this is not the summary of the Thread, don't need to add it to the Thread - extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.addNewMessage(message) + ;(async () => { + const updatedMessage = await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createMessage(message) + if (updatedMessage) { + deleteMessage(message.id) + addNewMessage(updatedMessage) + } + })() // Attempt to generate the title of the Thread when needed generateThreadTitle(message, thread) @@ -279,7 +290,9 @@ export default function ModelHandler() { const generateThreadTitle = (message: ThreadMessage, thread: Thread) => { // If this is the first ever prompt in the thread - if (thread.title?.trim() !== defaultThreadTitle) { + if ( + (thread.title ?? thread.metadata?.title)?.trim() !== defaultThreadTitle + ) { return } @@ -292,11 +305,14 @@ export default function ModelHandler() { const updatedThread: Thread = { ...thread, title: (thread.metadata?.lastMessage as string) || defaultThreadTitle, - metadata: thread.metadata, + metadata: { + ...thread.metadata, + title: (thread.metadata?.lastMessage as string) || defaultThreadTitle, + }, } return extensionManager .get(ExtensionTypeEnum.Conversational) - ?.saveThread({ + ?.modifyThread({ ...updatedThread, }) .then(() => { @@ -313,7 +329,7 @@ export default function ModelHandler() { if (!threadMessages || threadMessages.length === 0) return - const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0].content[0].text.value}"` + const summarizeFirstPrompt = `Summarize in a ${maxWordForThreadTitle}-word Title. Give the title only. "${threadMessages[0]?.content[0]?.text?.value}"` // Prompt: Given this query from user {query}, return to me the summary in 10 words as the title const msgId = ulid() @@ -330,6 +346,7 @@ export default function ModelHandler() { id: msgId, threadId: message.thread_id, type: MessageRequestType.Summary, + attachments: [], messages, model: { ...activeModelRef.current, diff --git a/web/helpers/atoms/Assistant.atom.ts b/web/helpers/atoms/Assistant.atom.ts index d44703cf41..cb50a0553e 100644 --- a/web/helpers/atoms/Assistant.atom.ts +++ b/web/helpers/atoms/Assistant.atom.ts @@ -1,4 +1,12 @@ -import { Assistant } from '@janhq/core' +import { Assistant, ThreadAssistantInfo } from '@janhq/core' import { atom } from 'jotai' +import { atomWithStorage } from 'jotai/utils' export const assistantsAtom = atom([]) + +/** + * Get the current active assistant + */ +export const activeAssistantAtom = atomWithStorage< + ThreadAssistantInfo | undefined +>('activeAssistant', undefined, undefined, { getOnInit: true }) diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index 1f6099a2e0..b0ec6c4936 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -6,6 +6,8 @@ import { } from '@janhq/core' import { atom } from 'jotai' +import { atomWithStorage } from 'jotai/utils' + import { getActiveThreadIdAtom, updateThreadStateLastMessageAtom, @@ -13,15 +15,23 @@ import { import { TokenSpeed } from '@/types/token' +const CHAT_MESSAGE_NAME = 'chatMessages' /** * Stores all chat messages for all threads */ -export const chatMessages = atom>({}) +export const chatMessages = atomWithStorage>( + CHAT_MESSAGE_NAME, + {}, + undefined, + { getOnInit: true } +) /** * Stores the status of the messages load for each thread */ -export const readyThreadsMessagesAtom = atom>({}) +export const readyThreadsMessagesAtom = atomWithStorage< + Record +>('currentThreadMessages', {}, undefined, { getOnInit: true }) /** * Store the token speed for current message @@ -34,6 +44,7 @@ export const getCurrentChatMessagesAtom = atom((get) => { const activeThreadId = get(getActiveThreadIdAtom) if (!activeThreadId) return [] const messages = get(chatMessages)[activeThreadId] + if (!Array.isArray(messages)) return [] return messages ?? [] }) diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts index 923f24df47..b4eb87e7a8 100644 --- a/web/helpers/atoms/Model.atom.test.ts +++ b/web/helpers/atoms/Model.atom.test.ts @@ -58,7 +58,9 @@ describe('Model.atom.ts', () => { setAtom.current({ id: '1' } as any) }) expect(getAtom.current).toEqual([{ id: '1' }]) - reset.current([]) + act(() => { + reset.current([]) + }) }) }) @@ -83,7 +85,9 @@ describe('Model.atom.ts', () => { removeAtom.current('1') }) expect(getAtom.current).toEqual([]) - reset.current([]) + act(() => { + reset.current([]) + }) }) }) @@ -113,7 +117,9 @@ describe('Model.atom.ts', () => { removeAtom.current('1') }) expect(getAtom.current).toEqual([]) - reset.current([]) + act(() => { + reset.current([]) + }) }) }) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 63513bee28..e436d116e0 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -8,6 +8,7 @@ import { toaster } from '@/containers/Toast' import { LAST_USED_MODEL_ID } from './useRecommendedModel' import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' @@ -34,6 +35,7 @@ export function useActiveModel() { const setLoadModelError = useSetAtom(loadModelErrorAtom) const pendingModelLoad = useRef(false) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const downloadedModelsRef = useRef([]) @@ -79,12 +81,12 @@ export function useActiveModel() { } /// Apply thread model settings - if (activeThread?.assistants[0]?.model.id === modelId) { + if (activeAssistant?.model.id === modelId) { model = { ...model, settings: { ...model.settings, - ...activeThread.assistants[0].model.settings, + ...activeAssistant?.model.settings, }, } } diff --git a/web/hooks/useCreateNewThread.test.ts b/web/hooks/useCreateNewThread.test.ts index 25589c0988..d98983830d 100644 --- a/web/hooks/useCreateNewThread.test.ts +++ b/web/hooks/useCreateNewThread.test.ts @@ -67,7 +67,7 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) expect(extensionManager.get).toHaveBeenCalled() }) @@ -104,7 +104,7 @@ describe('useCreateNewThread', () => { await result.current.requestCreateNewThread({ id: 'assistant1', name: 'Assistant 1', - instructions: "Hello Jan Assistant", + instructions: 'Hello Jan Assistant', model: { id: 'model1', parameters: [], @@ -113,16 +113,8 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() - expect(mockSetAtom).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ - assistants: expect.arrayContaining([ - expect.objectContaining({ instructions: 'Hello Jan Assistant' }), - ]), - }) - ) }) it('should create a new thread with previous instructions', async () => { @@ -166,16 +158,8 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() - expect(mockSetAtom).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ - assistants: expect.arrayContaining([ - expect.objectContaining({ instructions: 'Hello Jan' }), - ]), - }) - ) }) it('should show a warning toast if trying to create an empty thread', async () => { @@ -212,13 +196,12 @@ describe('useCreateNewThread', () => { const { result } = renderHook(() => useCreateNewThread()) - const mockThread = { id: 'thread1', title: 'Test Thread' } + const mockThread = { id: 'thread1', title: 'Test Thread', assistants: [{}] } await act(async () => { await result.current.updateThreadMetadata(mockThread as any) }) expect(mockUpdateThread).toHaveBeenCalledWith(mockThread) - expect(extensionManager.get).toHaveBeenCalled() }) }) diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 63de2d3abb..02f3cd371b 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -1,7 +1,6 @@ import { useCallback } from 'react' import { - Assistant, ConversationalExtension, ExtensionTypeEnum, Thread, @@ -9,8 +8,11 @@ import { ThreadState, AssistantTool, Model, + Assistant, } from '@janhq/core' -import { atom, useAtomValue, useSetAtom } from 'jotai' +import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' + +import { useDebouncedCallback } from 'use-debounce' import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction' import { fileUploadAtom } from '@/containers/Providers/Jotai' @@ -18,7 +20,6 @@ import { fileUploadAtom } from '@/containers/Providers/Jotai' import { toaster } from '@/containers/Toast' import { isLocalEngine } from '@/utils/modelEngine' -import { generateThreadId } from '@/utils/thread' import { useActiveModel } from './useActiveModel' import useRecommendedModel from './useRecommendedModel' @@ -28,6 +29,7 @@ import useSetActiveThread from './useSetActiveThread' import { extensionManager } from '@/extension' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, @@ -35,7 +37,6 @@ import { updateThreadAtom, setThreadModelParamsAtom, isGeneratingResponseAtom, - activeThreadAtom, } from '@/helpers/atoms/Thread.atom' const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { @@ -65,7 +66,7 @@ export const useCreateNewThread = () => { const copyOverInstructionEnabled = useAtomValue( copyOverInstructionEnabledAtom ) - const activeThread = useAtomValue(activeThreadAtom) + const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) @@ -76,7 +77,7 @@ export const useCreateNewThread = () => { const { stopInference } = useActiveModel() const requestCreateNewThread = async ( - assistant: Assistant, + assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant, model?: Model | undefined ) => { // Stop generating if any @@ -127,7 +128,7 @@ export const useCreateNewThread = () => { const createdAt = Date.now() let instructions: string | undefined = assistant.instructions if (copyOverInstructionEnabled) { - instructions = activeThread?.assistants[0]?.instructions ?? undefined + instructions = activeAssistant?.instructions ?? undefined } const assistantInfo: ThreadAssistantInfo = { assistant_id: assistant.id, @@ -142,46 +143,95 @@ export const useCreateNewThread = () => { instructions, } - const threadId = generateThreadId(assistant.id) - const thread: Thread = { - id: threadId, + const thread: Partial = { object: 'thread', title: 'New Thread', assistants: [assistantInfo], created: createdAt, updated: createdAt, + metadata: { + title: 'New Thread', + }, } // add the new thread on top of the thread list to the state //TODO: Why do we have thread list then thread states? Should combine them - createNewThread(thread) - - setSelectedModel(defaultModel) - setThreadModelParams(thread.id, { - ...defaultModel?.settings, - ...defaultModel?.parameters, - ...overriddenSettings, - }) + try { + const createdThread = await persistNewThread(thread, assistantInfo) + if (!createdThread) throw 'Thread created failed.' + createNewThread(createdThread) + + setSelectedModel(defaultModel) + setThreadModelParams(createdThread.id, { + ...defaultModel?.settings, + ...defaultModel?.parameters, + ...overriddenSettings, + }) + + // Delete the file upload state + setFileUpload(undefined) + setActiveThread(createdThread) + } catch (ex) { + return toaster({ + title: 'Thread created failed.', + description: `To avoid piling up empty threads, please reuse previous one before creating new.`, + type: 'error', + }) + } + } - // Delete the file upload state - setFileUpload([]) - // Update thread metadata - await updateThreadMetadata(thread) + const updateThreadExtension = (thread: Thread) => { + return extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThread(thread) + } - setActiveThread(thread) + const updateAssistantExtension = ( + threadId: string, + assistant: ThreadAssistantInfo + ) => { + return extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThreadAssistant(threadId, assistant) } + const updateThreadCallback = useDebouncedCallback(updateThreadExtension, 300) + const updateAssistantCallback = useDebouncedCallback( + updateAssistantExtension, + 300 + ) + const updateThreadMetadata = useCallback( async (thread: Thread) => { updateThread(thread) - await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.saveThread(thread) + setActiveAssistant(thread.assistants[0]) + updateThreadCallback(thread) + updateAssistantCallback(thread.id, thread.assistants[0]) }, - [updateThread] + [ + updateThread, + setActiveAssistant, + updateThreadCallback, + updateAssistantCallback, + ] ) + const persistNewThread = async ( + thread: Partial, + assistantInfo: ThreadAssistantInfo + ): Promise => { + return await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createThread(thread) + .then(async (thread) => { + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createThreadAssistant(thread.id, assistantInfo) + return thread + }) + } + return { requestCreateNewThread, updateThreadMetadata, diff --git a/web/hooks/useDeleteThread.test.ts b/web/hooks/useDeleteThread.test.ts index d3a6138d07..50b0c7511b 100644 --- a/web/hooks/useDeleteThread.test.ts +++ b/web/hooks/useDeleteThread.test.ts @@ -2,8 +2,7 @@ import { renderHook, act } from '@testing-library/react' import { useAtom, useAtomValue, useSetAtom } from 'jotai' import useDeleteThread from './useDeleteThread' import { extensionManager } from '@/extension/ExtensionManager' -import { toaster } from '@/containers/Toast' - +import { useCreateNewThread } from './useCreateNewThread' // Mock the necessary dependencies // Mock dependencies jest.mock('jotai', () => ({ @@ -12,6 +11,7 @@ jest.mock('jotai', () => ({ useAtom: jest.fn(), atom: jest.fn(), })) +jest.mock('./useCreateNewThread') jest.mock('@/extension/ExtensionManager') jest.mock('@/containers/Toast') @@ -27,8 +27,13 @@ describe('useDeleteThread', () => { ] const mockSetThreads = jest.fn() ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) + ;(useSetAtom as jest.Mock).mockReturnValue(() => {}) + ;(useCreateNewThread as jest.Mock).mockReturnValue({}) + + const mockDeleteThread = jest.fn().mockImplementation(() => ({ + catch: () => jest.fn, + })) - const mockDeleteThread = jest.fn() extensionManager.get = jest.fn().mockReturnValue({ deleteThread: mockDeleteThread, }) @@ -50,12 +55,17 @@ describe('useDeleteThread', () => { const mockCleanMessages = jest.fn() ;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages) ;(useAtomValue as jest.Mock).mockReturnValue(['thread 1']) + const mockCreateNewThread = jest.fn() + ;(useCreateNewThread as jest.Mock).mockReturnValue({ + requestCreateNewThread: mockCreateNewThread, + }) - const mockWriteMessages = jest.fn() const mockSaveThread = jest.fn() + const mockDeleteThread = jest.fn().mockResolvedValue({}) extensionManager.get = jest.fn().mockReturnValue({ - writeMessages: mockWriteMessages, saveThread: mockSaveThread, + getThreadAssistant: jest.fn().mockResolvedValue({}), + deleteThread: mockDeleteThread, }) const { result } = renderHook(() => useDeleteThread()) @@ -64,20 +74,18 @@ describe('useDeleteThread', () => { await result.current.cleanThread('thread1') }) - expect(mockWriteMessages).toHaveBeenCalled() - expect(mockSaveThread).toHaveBeenCalledWith( - expect.objectContaining({ - id: 'thread1', - title: 'New Thread', - metadata: expect.objectContaining({ lastMessage: undefined }), - }) - ) + expect(mockDeleteThread).toHaveBeenCalled() + expect(mockCreateNewThread).toHaveBeenCalled() }) it('should handle errors when deleting a thread', async () => { const mockThreads = [{ id: 'thread1', title: 'Thread 1' }] const mockSetThreads = jest.fn() ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) + const mockCreateNewThread = jest.fn() + ;(useCreateNewThread as jest.Mock).mockReturnValue({ + requestCreateNewThread: mockCreateNewThread, + }) const mockDeleteThread = jest .fn() @@ -98,8 +106,6 @@ describe('useDeleteThread', () => { expect(mockDeleteThread).toHaveBeenCalledWith('thread1') expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error)) - expect(mockSetThreads).not.toHaveBeenCalled() - expect(toaster).not.toHaveBeenCalled() consoleErrorSpy.mockRestore() }) diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index 69e51228f1..7b98a4ea5c 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -1,13 +1,6 @@ import { useCallback } from 'react' -import { - ChatCompletionRole, - ExtensionTypeEnum, - ConversationalExtension, - fs, - joinPath, - Thread, -} from '@janhq/core' +import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core' import { useAtom, useAtomValue, useSetAtom } from 'jotai' @@ -15,89 +8,63 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai' import { toaster } from '@/containers/Toast' +import { useCreateNewThread } from './useCreateNewThread' + import { extensionManager } from '@/extension/ExtensionManager' -import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' -import { - chatMessages, - cleanChatMessageAtom as cleanChatMessagesAtom, - deleteChatMessageAtom as deleteChatMessagesAtom, -} from '@/helpers/atoms/ChatMessage.atom' +import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' +import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' +import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, setActiveThreadIdAtom, deleteThreadStateAtom, - updateThreadStateLastMessageAtom, - updateThreadAtom, } from '@/helpers/atoms/Thread.atom' export default function useDeleteThread() { const [threads, setThreads] = useAtom(threadsAtom) - const messages = useAtomValue(chatMessages) - const janDataFolderPath = useAtomValue(janDataFolderPathAtom) + const { requestCreateNewThread } = useCreateNewThread() + const assistants = useAtomValue(assistantsAtom) + const models = useAtomValue(downloadedModelsAtom) const setCurrentPrompt = useSetAtom(currentPromptAtom) const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) const deleteMessages = useSetAtom(deleteChatMessagesAtom) - const cleanMessages = useSetAtom(cleanChatMessagesAtom) const deleteThreadState = useSetAtom(deleteThreadStateAtom) - const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom) - const updateThread = useSetAtom(updateThreadAtom) const cleanThread = useCallback( async (threadId: string) => { - cleanMessages(threadId) const thread = threads.find((c) => c.id === threadId) if (!thread) return - - const updatedMessages = (messages[threadId] ?? []).filter( - (msg) => msg.role === ChatCompletionRole.System - ) - - // remove files - try { - const threadFolderPath = await joinPath([ - janDataFolderPath, - 'threads', - threadId, - ]) - const threadFilesPath = await joinPath([threadFolderPath, 'files']) - const threadMemoryPath = await joinPath([threadFolderPath, 'memory']) - await fs.rm(threadFilesPath) - await fs.rm(threadMemoryPath) - } catch (err) { - console.warn('Error deleting thread files', err) - } - - await extensionManager + const assistantInfo = await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.writeMessages(threadId, updatedMessages) - - thread.metadata = { - ...thread.metadata, - } - - const updatedThread: Thread = { - ...thread, - title: 'New Thread', - metadata: { ...thread.metadata, lastMessage: undefined }, - } - + ?.getThreadAssistant(thread.id) + + if (!assistantInfo) return + const model = models.find((c) => c.id === assistantInfo?.model?.id) + + requestCreateNewThread( + { + ...assistantInfo, + id: assistants[0].id, + name: assistants[0].name, + }, + model + ? { + ...model, + parameters: assistantInfo?.model?.parameters ?? {}, + settings: assistantInfo?.model?.settings ?? {}, + } + : undefined + ) + // Delete this thread await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.saveThread(updatedThread) - updateThreadLastMessage(threadId, undefined) - updateThread(updatedThread) + ?.deleteThread(threadId) + .catch(console.error) }, - [ - cleanMessages, - threads, - messages, - updateThreadLastMessage, - updateThread, - janDataFolderPath, - ] + [assistants, models, requestCreateNewThread, threads] ) const deleteThread = async (threadId: string) => { @@ -105,30 +72,27 @@ export default function useDeleteThread() { alert('No active thread') return } - try { - await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.deleteThread(threadId) - const availableThreads = threads.filter((c) => c.id !== threadId) - setThreads(availableThreads) - - // delete the thread state - deleteThreadState(threadId) - - deleteMessages(threadId) - setCurrentPrompt('') - toaster({ - title: 'Thread successfully deleted.', - description: `Thread ${threadId} has been successfully deleted.`, - type: 'success', - }) - if (availableThreads.length > 0) { - setActiveThreadId(availableThreads[0].id) - } else { - setActiveThreadId(undefined) - } - } catch (err) { - console.error(err) + await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.deleteThread(threadId) + .catch(console.error) + const availableThreads = threads.filter((c) => c.id !== threadId) + setThreads(availableThreads) + + // delete the thread state + deleteThreadState(threadId) + + deleteMessages(threadId) + setCurrentPrompt('') + toaster({ + title: 'Thread successfully deleted.', + description: `Thread ${threadId} has been successfully deleted.`, + type: 'success', + }) + if (availableThreads.length > 0) { + setActiveThreadId(availableThreads[0].id) + } else { + setActiveThreadId(undefined) } } diff --git a/web/hooks/useDropModelBinaries.test.ts b/web/hooks/useDropModelBinaries.test.ts index dad8c6178f..7ca5a479ec 100644 --- a/web/hooks/useDropModelBinaries.test.ts +++ b/web/hooks/useDropModelBinaries.test.ts @@ -1,3 +1,6 @@ +/** + * @jest-environment jsdom + */ // useDropModelBinaries.test.ts import { renderHook, act } from '@testing-library/react' @@ -18,6 +21,7 @@ jest.mock('jotai', () => ({ jest.mock('uuid') jest.mock('@/utils/file') jest.mock('@/containers/Toast') +jest.mock("@uppy/core") describe('useDropModelBinaries', () => { const mockSetImportingModels = jest.fn() diff --git a/web/hooks/usePath.ts b/web/hooks/usePath.ts index b732926a69..afdafe11ff 100644 --- a/web/hooks/usePath.ts +++ b/web/hooks/usePath.ts @@ -2,6 +2,7 @@ import { openFileExplorer, joinPath, baseName } from '@janhq/core' import { useAtomValue } from 'jotai' import { janDataFolderPathAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' @@ -9,13 +10,14 @@ export const usePath = () => { const janDataFolderPath = useAtomValue(janDataFolderPathAtom) const activeThread = useAtomValue(activeThreadAtom) const selectedModel = useAtomValue(selectedModelAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const onRevealInFinder = async (type: string) => { // TODO: this logic should be refactored. if (type !== 'Model' && !activeThread) return let filePath = undefined - const assistantId = activeThread?.assistants[0]?.assistant_id + const assistantId = activeAssistant?.assistant_id switch (type) { case 'Engine': case 'Thread': diff --git a/web/hooks/useRecommendedModel.ts b/web/hooks/useRecommendedModel.ts index d5bf0aba73..e1702701bb 100644 --- a/web/hooks/useRecommendedModel.ts +++ b/web/hooks/useRecommendedModel.ts @@ -6,6 +6,7 @@ import { atom, useAtomValue } from 'jotai' import { activeModelAtom } from './useActiveModel' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' @@ -28,6 +29,7 @@ export default function useRecommendedModel() { const [recommendedModel, setRecommendedModel] = useState() const activeThread = useAtomValue(activeThreadAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const getAndSortDownloadedModels = useCallback(async (): Promise => { const models = downloadedModels.sort((a, b) => @@ -45,8 +47,8 @@ export default function useRecommendedModel() { > => { const models = await getAndSortDownloadedModels() - if (!activeThread) return - const modelId = activeThread.assistants[0]?.model.id + if (!activeThread || !activeAssistant) return + const modelId = activeAssistant.model.id const model = models.find((model) => model.id === modelId) if (model) { diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index dc9a52f1be..bbe5e3cd71 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -10,6 +10,7 @@ import { ConversationalExtension, EngineManager, ToolManager, + ThreadAssistantInfo, } from '@janhq/core' import { extractInferenceParams, extractModelLoadParams } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' @@ -28,6 +29,7 @@ import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' import { useActiveModel } from './useActiveModel' import { extensionManager } from '@/extension/ExtensionManager' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { addNewMessageAtom, deleteMessageAtom, @@ -48,6 +50,7 @@ export const reloadModelAtom = atom(false) export default function useSendChatMessage() { const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const addNewMessage = useSetAtom(addNewMessageAtom) const updateThread = useSetAtom(updateThreadAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) @@ -68,6 +71,7 @@ export default function useSendChatMessage() { const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const activeThreadRef = useRef() + const activeAssistantRef = useRef() const setTokenSpeed = useSetAtom(tokenSpeedAtom) const selectedModelRef = useRef() @@ -84,36 +88,37 @@ export default function useSendChatMessage() { selectedModelRef.current = selectedModel }, [selectedModel]) - const resendChatMessage = async (currentMessage: ThreadMessage) => { + useEffect(() => { + activeAssistantRef.current = activeAssistant + }, [activeAssistant]) + + const resendChatMessage = async () => { // Delete last response before regenerating - const newConvoData = currentMessages - let toSendMessage = currentMessage - - do { - deleteMessage(currentMessage.id) - const msg = newConvoData.pop() - if (!msg) break - toSendMessage = msg - deleteMessage(toSendMessage.id ?? '') - } while (toSendMessage.role !== ChatCompletionRole.User) + const newConvoData = Array.from(currentMessages) + let toSendMessage = newConvoData.pop() - if (activeThreadRef.current) { + while (toSendMessage && toSendMessage?.role !== ChatCompletionRole.User) { await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.writeMessages(activeThreadRef.current.id, newConvoData) + ?.deleteMessage(toSendMessage.thread_id, toSendMessage.id) + .catch(console.error) + deleteMessage(toSendMessage.id ?? '') + toSendMessage = newConvoData.pop() } - sendChatMessage(toSendMessage.content[0]?.text.value) + if (toSendMessage?.content[0]?.text?.value) + sendChatMessage(toSendMessage.content[0].text.value, true) } const sendChatMessage = async ( message: string, + isResend: boolean = false, messages?: ThreadMessage[] ) => { if (!message || message.trim().length === 0) return - if (!activeThreadRef.current) { - console.error('No active thread') + if (!activeThreadRef.current || !activeAssistantRef.current) { + console.error('No active thread or assistant') return } @@ -129,21 +134,19 @@ export default function useSendChatMessage() { setCurrentPrompt('') setEditPrompt('') - let base64Blob = fileUpload[0] - ? await getBase64(fileUpload[0].file) - : undefined + let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined - if (base64Blob && fileUpload[0]?.type === 'image') { + if (base64Blob && fileUpload?.type === 'image') { // Compress image base64Blob = await compressImage(base64Blob, 512) } const modelRequest = - selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model + selectedModelRef?.current ?? activeAssistantRef.current?.model // Fallback support for previous broken threads - if (activeThreadRef.current?.assistants[0]?.model?.id === '*') { - activeThreadRef.current.assistants[0].model = { + if (activeAssistantRef.current?.model?.id === '*') { + activeAssistantRef.current.model = { id: modelRequest.id, settings: modelRequest.settings, parameters: modelRequest.parameters, @@ -163,46 +166,49 @@ export default function useSendChatMessage() { }, activeThreadRef.current, messages ?? currentMessages - ).addSystemMessage(activeThreadRef.current.assistants[0].instructions) - - requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type) - - // Build Thread Message to persist - const threadMessageBuilder = new ThreadMessageBuilder( - requestBuilder - ).pushMessage(prompt, base64Blob, fileUpload) + ).addSystemMessage(activeAssistantRef.current?.instructions) + + if (!isResend) { + requestBuilder.pushMessage(prompt, base64Blob, fileUpload) + + // Build Thread Message to persist + const threadMessageBuilder = new ThreadMessageBuilder( + requestBuilder + ).pushMessage(prompt, base64Blob, fileUpload) + + const newMessage = threadMessageBuilder.build() + + // Update thread state + const updatedThread: Thread = { + ...activeThreadRef.current, + updated: newMessage.created, + metadata: { + ...activeThreadRef.current.metadata, + lastMessage: prompt, + }, + } + updateThread(updatedThread) - const newMessage = threadMessageBuilder.build() + // Add message + const createdMessage = await extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.createMessage(newMessage) - // Push to states - addNewMessage(newMessage) + if (!createdMessage) return - // Update thread state - const updatedThread: Thread = { - ...activeThreadRef.current, - updated: newMessage.created, - metadata: { - ...activeThreadRef.current.metadata, - lastMessage: prompt, - }, + // Push to states + addNewMessage(createdMessage) } - updateThread(updatedThread) - - // Add message - await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.addNewMessage(newMessage) // Start Model if not started const modelId = - selectedModelRef.current?.id ?? - activeThreadRef.current.assistants[0].model.id + selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id if (base64Blob) { - setFileUpload([]) + setFileUpload(undefined) } - if (modelRef.current?.id !== modelId) { + if (modelRef.current?.id !== modelId && modelId) { const error = await startModel(modelId).catch((error: Error) => error) if (error) { updateThreadWaiting(activeThreadRef.current.id, false) @@ -214,9 +220,7 @@ export default function useSendChatMessage() { // Process message request with Assistants tools const request = await ToolManager.instance().process( requestBuilder.build(), - activeThreadRef.current.assistants?.flatMap( - (assistant) => assistant.tools ?? [] - ) ?? [] + activeAssistantRef?.current.tools ?? [] ) // Request for inference diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts index 6b306224db..8c7ed5361b 100644 --- a/web/hooks/useSetActiveThread.ts +++ b/web/hooks/useSetActiveThread.ts @@ -1,12 +1,10 @@ import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { useSetAtom } from 'jotai' import { extensionManager } from '@/extension' -import { - readyThreadsMessagesAtom, - setConvoMessagesAtom, -} from '@/helpers/atoms/ChatMessage.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' +import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { setActiveThreadIdAtom, setThreadModelParamsAtom, @@ -17,21 +15,27 @@ export default function useSetActiveThread() { const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) const setThreadMessage = useSetAtom(setConvoMessagesAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) - const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom) + const setActiveAssistant = useSetAtom(activeAssistantAtom) const setActiveThread = async (thread: Thread) => { - // Load local messages only if there are no messages in the state - if (!readyMessageThreads[thread?.id]) { - const messages = await getLocalThreadMessage(thread?.id) - setThreadMessage(thread?.id, messages) - } + if (!thread?.id) return setActiveThreadId(thread?.id) - const modelParams: ModelParams = { - ...thread?.assistants[0]?.model?.parameters, - ...thread?.assistants[0]?.model?.settings, + + try { + const assistantInfo = await getThreadAssistant(thread.id) + setActiveAssistant(assistantInfo) + // Load local messages only if there are no messages in the state + const messages = await getLocalThreadMessage(thread.id).catch(() => []) + const modelParams: ModelParams = { + ...assistantInfo?.model?.parameters, + ...assistantInfo?.model?.settings, + } + setThreadModelParams(thread?.id, modelParams) + setThreadMessage(thread.id, messages) + } catch (e) { + console.error(e) } - setThreadModelParams(thread?.id, modelParams) } return { setActiveThread } @@ -40,4 +44,9 @@ export default function useSetActiveThread() { const getLocalThreadMessage = async (threadId: string) => extensionManager .get(ExtensionTypeEnum.Conversational) - ?.getAllMessages(threadId) ?? [] + ?.listMessages(threadId) ?? [] + +const getThreadAssistant = async (threadId: string) => + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.getThreadAssistant(threadId) diff --git a/web/hooks/useThread.test.ts b/web/hooks/useThread.test.ts index a40c709be6..4db7f87aca 100644 --- a/web/hooks/useThread.test.ts +++ b/web/hooks/useThread.test.ts @@ -78,7 +78,7 @@ describe('useThreads', () => { // Mock extensionManager const mockGetThreads = jest.fn().mockResolvedValue(mockThreads) ;(extensionManager.get as jest.Mock).mockReturnValue({ - getThreads: mockGetThreads, + listThreads: mockGetThreads, }) const { result } = renderHook(() => useThreads()) @@ -119,7 +119,7 @@ describe('useThreads', () => { it('should handle empty threads', async () => { // Mock empty threads ;(extensionManager.get as jest.Mock).mockReturnValue({ - getThreads: jest.fn().mockResolvedValue([]), + listThreads: jest.fn().mockResolvedValue([]), }) const mockSetThreadStates = jest.fn() diff --git a/web/hooks/useThreads.ts b/web/hooks/useThreads.ts index 9366101c3a..1e3b428a9f 100644 --- a/web/hooks/useThreads.ts +++ b/web/hooks/useThreads.ts @@ -68,6 +68,6 @@ const useThreads = () => { const getLocalThreads = async (): Promise => (await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.getThreads()) ?? [] + ?.listThreads()) ?? [] export default useThreads diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts index bc60aa631c..6c7ceb8b03 100644 --- a/web/hooks/useUpdateModelParameters.test.ts +++ b/web/hooks/useUpdateModelParameters.test.ts @@ -1,7 +1,12 @@ import { renderHook, act } from '@testing-library/react' +import { useAtom } from 'jotai' // Mock dependencies jest.mock('ulidx') jest.mock('@/extension') +jest.mock('jotai', () => ({ + ...jest.requireActual('jotai'), + useAtom: jest.fn(), +})) import useUpdateModelParameters from './useUpdateModelParameters' import { extensionManager } from '@/extension' @@ -13,7 +18,8 @@ let model: any = { } let extension: any = { - saveThread: jest.fn(), + modifyThread: jest.fn(), + modifyThreadAssistant: jest.fn(), } const mockThread: any = { @@ -35,6 +41,7 @@ const mockThread: any = { describe('useUpdateModelParameters', () => { beforeAll(() => { jest.clearAllMocks() + jest.useFakeTimers() jest.mock('./useRecommendedModel', () => ({ useRecommendedModel: () => ({ recommendedModel: model, @@ -45,6 +52,12 @@ describe('useUpdateModelParameters', () => { }) it('should update model parameters and save thread when params are valid', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockValidParameters: any = { params: { // Inference @@ -76,7 +89,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) @@ -84,44 +98,46 @@ describe('useUpdateModelParameters', () => { await result.current.updateModelParameter(mockThread, mockValidParameters) }) + jest.runAllTimers() + // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: ['', ''], - temperature: 0.5, - token_limit: 1000, - top_k: 0.7, - top_p: 0.1, - stream: true, - max_tokens: 1000, - frequency_penalty: 0.3, - presence_penalty: 0.2, - }, - settings: { - ctx_len: 1024, - ngl: 12, - embedding: true, - n_parallel: 2, - cpu_threads: 4, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + settings: { + ctx_len: 1024, + ngl: 12, + embedding: true, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + id: 'model-1', + engine: 'nitro', + }, }) }) it('should not update invalid model parameters', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockInvalidParameters: any = { params: { // Inference @@ -153,7 +169,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) @@ -164,36 +181,38 @@ describe('useUpdateModelParameters', () => { ) }) + jest.runAllTimers() + // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - max_tokens: 1000, - token_limit: 1000, - }, - settings: { - cpu_threads: 4, - ctx_len: 1024, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - n_parallel: 2, - ngl: 12, - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + engine: 'nitro', + id: 'model-1', + parameters: { + token_limit: 1000, + max_tokens: 1000, + }, + settings: { + cpu_threads: 4, + ctx_len: 1024, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + n_parallel: 2, + ngl: 12, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + }, }) }) it('should update valid model parameters only', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockInvalidParameters: any = { params: { // Inference @@ -225,8 +244,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) - + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) await act(async () => { @@ -235,80 +254,33 @@ describe('useUpdateModelParameters', () => { mockInvalidParameters ) }) + jest.runAllTimers() // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: [''], - top_k: 0.7, - top_p: 0.1, - stream: true, - token_limit: 100, - max_tokens: 1000, - presence_penalty: 0.2, - }, - settings: { - ctx_len: 1024, - ngl: 0, - n_parallel: 2, - cpu_threads: 4, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + engine: 'nitro', + id: 'model-1', + parameters: { + stop: [''], + top_k: 0.7, + top_p: 0.1, + stream: true, + token_limit: 100, + max_tokens: 1000, + presence_penalty: 0.2, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, - }) - }) - - it('should handle missing modelId and engine gracefully', async () => { - const mockParametersWithoutModelIdAndEngine: any = { - params: { - stop: ['', ''], - temperature: 0.5, - }, - } - - // Spy functions - jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) - - const { result } = renderHook(() => useUpdateModelParameters()) - - await act(async () => { - await result.current.updateModelParameter( - mockThread, - mockParametersWithoutModelIdAndEngine - ) - }) - - // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: ['', ''], - temperature: 0.5, - }, - settings: {}, - }, + settings: { + ctx_len: 1024, + ngl: 0, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + }, }) }) }) diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 6eb7c3c5a9..dab2f6e284 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -12,7 +12,10 @@ import { import { useAtom, useAtomValue, useSetAtom } from 'jotai' +import { useDebouncedCallback } from 'use-debounce' + import { extensionManager } from '@/extension' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { getActiveThreadModelParamsAtom, @@ -29,11 +32,28 @@ export type UpdateModelParameter = { export default function useUpdateModelParameters() { const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) + const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom) const [selectedModel] = useAtom(selectedModelAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + const updateAssistantExtension = ( + threadId: string, + assistant: ThreadAssistantInfo + ) => { + return extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThreadAssistant(threadId, assistant) + } + + const updateAssistantCallback = useDebouncedCallback( + updateAssistantExtension, + 300 + ) + const updateModelParameter = useCallback( async (thread: Thread, settings: UpdateModelParameter) => { + if (!activeAssistant) return + const toUpdateSettings = processStopWords(settings.params ?? {}) const updatedModelParams = settings.modelId ? toUpdateSettings @@ -48,30 +68,34 @@ export default function useUpdateModelParameters() { setThreadModelParams(thread.id, updatedModelParams) const runtimeParams = extractInferenceParams(updatedModelParams) const settingParams = extractModelLoadParams(updatedModelParams) - - const assistants = thread.assistants.map( - (assistant: ThreadAssistantInfo) => { - assistant.model.parameters = runtimeParams - assistant.model.settings = settingParams - if (selectedModel) { - assistant.model.id = settings.modelId ?? selectedModel?.id - assistant.model.engine = settings.engine ?? selectedModel?.engine - } - return assistant - } - ) - - // update thread - const updatedThread: Thread = { - ...thread, - assistants, + const assistantInfo = { + ...activeAssistant, + model: { + ...activeAssistant?.model, + parameters: runtimeParams, + settings: settingParams, + id: settings.modelId ?? selectedModel?.id ?? activeAssistant.model.id, + engine: + settings.engine ?? + selectedModel?.engine ?? + activeAssistant.model.engine, + }, } + setActiveAssistant(assistantInfo) - await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.saveThread(updatedThread) + updateAssistantCallback(thread.id, assistantInfo) }, - [activeModelParams, selectedModel, setThreadModelParams] + [ + activeAssistant, + selectedModel?.parameters, + selectedModel?.settings, + selectedModel?.id, + selectedModel?.engine, + activeModelParams, + setThreadModelParams, + setActiveAssistant, + updateAssistantCallback, + ] ) const processStopWords = (params: ModelParams): ModelParams => { diff --git a/web/jest.config.js b/web/jest.config.js index f780075323..27e8d0bda3 100644 --- a/web/jest.config.js +++ b/web/jest.config.js @@ -37,5 +37,5 @@ const config = { // module.exports = createJestConfig(config) module.exports = async () => ({ ...(await createJestConfig(config)()), - transformIgnorePatterns: ['/node_modules/(?!(layerr)/)'], + transformIgnorePatterns: ['/node_modules/(?!(layerr|nanoid|@uppy|preact)/)'], }) diff --git a/web/next.config.js b/web/next.config.js index 8c57dd2268..b6da1780c0 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -35,7 +35,7 @@ const nextConfig = { POSTHOG_HOST: JSON.stringify(process.env.POSTHOG_HOST), ANALYTICS_HOST: JSON.stringify(process.env.ANALYTICS_HOST), API_BASE_URL: JSON.stringify( - process.env.API_BASE_URL ?? 'http://localhost:1337' + process.env.API_BASE_URL ?? 'http://127.0.0.1:39291' ), isMac: process.platform === 'darwin', isWindows: process.platform === 'win32', diff --git a/web/package.json b/web/package.json index 3518c7678b..db57facb51 100644 --- a/web/package.json +++ b/web/package.json @@ -17,6 +17,9 @@ "@janhq/core": "link:./core", "@janhq/joi": "link:./joi", "@tanstack/react-virtual": "^3.10.9", + "@uppy/core": "^4.3.0", + "@uppy/react": "^4.0.4", + "@uppy/xhr-upload": "^4.2.3", "autoprefixer": "10.4.16", "class-variance-authority": "^0.7.0", "framer-motion": "^10.16.4", diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx index 96ff6f559a..9b4e67ffbf 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx @@ -7,6 +7,8 @@ import { useAtomValue, useSetAtom } from 'jotai' import { useActiveModel } from '@/hooks/useActiveModel' import { useCreateNewThread } from '@/hooks/useCreateNewThread' import AssistantSetting from './index' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' jest.mock('jotai', () => { const originalModule = jest.requireActual('jotai') @@ -68,6 +70,7 @@ describe('AssistantSetting Component', () => { beforeEach(() => { jest.clearAllMocks() + jest.useFakeTimers() }) test('renders AssistantSetting component with proper data', async () => { @@ -75,7 +78,14 @@ describe('AssistantSetting Component', () => { ;(useSetAtom as jest.Mock).mockImplementationOnce( () => setEngineParamsUpdate ) - ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) + ;(useAtomValue as jest.Mock).mockImplementation((atom) => { + switch (atom) { + case activeThreadAtom: + return mockActiveThread + case activeAssistantAtom: + return {} + } + }) const updateThreadMetadata = jest.fn() ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() }) ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({ @@ -98,7 +108,14 @@ describe('AssistantSetting Component', () => { const setEngineParamsUpdate = jest.fn() const updateThreadMetadata = jest.fn() const stopModel = jest.fn() - ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) + ;(useAtomValue as jest.Mock).mockImplementation((atom) => { + switch (atom) { + case activeThreadAtom: + return mockActiveThread + case activeAssistantAtom: + return {} + } + }) ;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate) ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel }) ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({ diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx index 95c905dde3..19ec3328a5 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx @@ -8,6 +8,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread' import SettingComponentBuilder from '../../../../containers/ModelSetting/SettingComponent' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { activeThreadAtom, engineParamsUpdateAtom, @@ -19,13 +20,14 @@ type Props = { const AssistantSetting: React.FC = ({ componentData }) => { const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const { updateThreadMetadata } = useCreateNewThread() const { stopModel } = useActiveModel() const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) const onValueChanged = useCallback( (key: string, value: string | number | boolean | string[]) => { - if (!activeThread) return + if (!activeThread || !activeAssistant) return const shouldReloadModel = componentData.find((x) => x.key === key)?.requireModelReload ?? false if (shouldReloadModel) { @@ -34,40 +36,40 @@ const AssistantSetting: React.FC = ({ componentData }) => { } if ( - activeThread.assistants[0].tools && + activeAssistant?.tools && (key === 'chunk_overlap' || key === 'chunk_size') ) { if ( - activeThread.assistants[0].tools[0]?.settings?.chunk_size < - activeThread.assistants[0].tools[0]?.settings?.chunk_overlap + activeAssistant.tools[0]?.settings?.chunk_size < + activeAssistant.tools[0]?.settings?.chunk_overlap ) { - activeThread.assistants[0].tools[0].settings.chunk_overlap = - activeThread.assistants[0].tools[0].settings.chunk_size + activeAssistant.tools[0].settings.chunk_overlap = + activeAssistant.tools[0].settings.chunk_size } if ( key === 'chunk_size' && - value < activeThread.assistants[0].tools[0].settings?.chunk_overlap + value < activeAssistant.tools[0].settings?.chunk_overlap ) { - activeThread.assistants[0].tools[0].settings.chunk_overlap = value + activeAssistant.tools[0].settings.chunk_overlap = value } else if ( key === 'chunk_overlap' && - value > activeThread.assistants[0].tools[0].settings?.chunk_size + value > activeAssistant.tools[0].settings?.chunk_size ) { - activeThread.assistants[0].tools[0].settings.chunk_size = value + activeAssistant.tools[0].settings.chunk_size = value } } updateThreadMetadata({ ...activeThread, assistants: [ { - ...activeThread.assistants[0], + ...activeAssistant, tools: [ { type: 'retrieval', enabled: true, settings: { - ...(activeThread.assistants[0].tools && - activeThread.assistants[0].tools[0]?.settings), + ...(activeAssistant.tools && + activeAssistant.tools[0]?.settings), [key]: value, }, }, @@ -77,6 +79,7 @@ const AssistantSetting: React.FC = ({ componentData }) => { }) }, [ + activeAssistant, activeThread, componentData, setEngineParamsUpdate, diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx index fbca6d2904..b3246a26b7 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx @@ -24,6 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel' import useSendChatMessage from '@/hooks/useSendChatMessage' +import { uploader } from '@/utils/file' import { isLocalEngine } from '@/utils/modelEngine' import FileUploadPreview from '../FileUploadPreview' @@ -33,6 +34,7 @@ import RichTextEditor from './RichTextEditor' import { showRightPanelAtom } from '@/helpers/atoms/App.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { spellCheckAtom } from '@/helpers/atoms/Setting.atom' @@ -67,8 +69,10 @@ const ChatInput = () => { const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) const threadStates = useAtomValue(threadStatesAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const { stopInference } = useActiveModel() + const upload = uploader() const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom( activeTabThreadRightPanelAtom ) @@ -102,18 +106,26 @@ const ChatInput = () => { const handleFileChange = (event: React.ChangeEvent) => { const file = event.target.files?.[0] if (!file) return - setFileUpload([{ file: file, type: 'pdf' }]) + upload.addFile(file) + upload.upload().then((data) => { + setFileUpload({ + file: file, + type: 'pdf', + id: data?.successful?.[0]?.response?.body?.id, + name: data?.successful?.[0]?.response?.body?.filename, + }) + }) } const handleImageChange = (event: React.ChangeEvent) => { const file = event.target.files?.[0] if (!file) return - setFileUpload([{ file: file, type: 'image' }]) + setFileUpload({ file: file, type: 'image' }) } const renderPreview = (fileUpload: any) => { - if (fileUpload.length > 0) { - if (fileUpload[0].type === 'image') { + if (fileUpload) { + if (fileUpload.type === 'image') { return } else { return @@ -130,7 +142,7 @@ const ChatInput = () => { 'relative mb-1 max-h-[400px] resize-none rounded-lg border border-[hsla(var(--app-border))] p-3 pr-20', 'focus-within:outline-none focus-visible:outline-0 focus-visible:ring-1 focus-visible:ring-[hsla(var(--primary-bg))] focus-visible:ring-offset-0', 'overflow-y-auto', - fileUpload.length && 'rounded-t-none', + fileUpload && 'rounded-t-none', experimentalFeature && 'pl-10', activeSettingInputBox && 'pb-14 pr-16' )} @@ -152,10 +164,10 @@ const ChatInput = () => { className="absolute left-3 top-2.5" onClick={(e) => { if ( - fileUpload.length > 0 || - (activeThread?.assistants[0].tools && - !activeThread?.assistants[0].tools[0]?.enabled && - !activeThread?.assistants[0].model.settings?.vision_model) + !!fileUpload || + (activeAssistant?.tools && + !activeAssistant?.tools[0]?.enabled && + !activeAssistant?.model.settings?.vision_model) ) { e.stopPropagation() } else { @@ -171,26 +183,24 @@ const ChatInput = () => { } disabled={ isModelSupportRagAndTools && - activeThread?.assistants[0].tools && - activeThread?.assistants[0].tools[0]?.enabled + activeAssistant?.tools && + activeAssistant?.tools[0]?.enabled } content={ <> - {fileUpload.length > 0 || - (activeThread?.assistants[0].tools && - !activeThread?.assistants[0].tools[0]?.enabled && - !activeThread?.assistants[0].model.settings - ?.vision_model && ( + {!!fileUpload || + (activeAssistant?.tools && + !activeAssistant?.tools[0]?.enabled && + !activeAssistant?.model.settings?.vision_model && ( <> - {fileUpload.length !== 0 && ( + {!!fileUpload && ( Currently, we only support 1 attachment at the same time. )} - {activeThread?.assistants[0].tools && - activeThread?.assistants[0].tools[0]?.enabled === - false && + {activeAssistant?.tools && + activeAssistant?.tools[0]?.enabled === false && isModelSupportRagAndTools && ( Turn on Retrieval in Tools settings to use this @@ -221,14 +231,12 @@ const ChatInput = () => {
  • { - if ( - activeThread?.assistants[0].model.settings?.vision_model - ) { + if (activeAssistant?.model.settings?.vision_model) { imageInputRef.current?.click() setShowAttacmentMenus(false) } @@ -239,9 +247,7 @@ const ChatInput = () => {
  • } content="This feature only supports multimodal models." - disabled={ - activeThread?.assistants[0].model.settings?.vision_model - } + disabled={activeAssistant?.model.settings?.vision_model} /> { } content={ - (!activeThread?.assistants[0].tools || - !activeThread?.assistants[0].tools[0]?.enabled) && ( + (!activeAssistant?.tools || + !activeAssistant?.tools[0]?.enabled) && ( Turn on Retrieval in Assistant Settings to use this feature. diff --git a/web/screens/Thread/ThreadCenterPanel/EditChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/EditChatInput/index.tsx index ea22e3a584..9b81ea6518 100644 --- a/web/screens/Thread/ThreadCenterPanel/EditChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/EditChatInput/index.tsx @@ -72,7 +72,8 @@ const EditChatInput: React.FC = ({ message }) => { }, [editPrompt]) useEffect(() => { - setEditPrompt(message.content[0]?.text?.value) + if (message.content?.[0]?.text?.value) + setEditPrompt(message.content[0].text.value) // eslint-disable-next-line react-hooks/exhaustive-deps }, []) @@ -80,19 +81,17 @@ const EditChatInput: React.FC = ({ message }) => { setEditMessage('') const messageIdx = messages.findIndex((msg) => msg.id === message.id) const newMessages = messages.slice(0, messageIdx) - if (activeThread) { - setMessages(activeThread.id, newMessages) - await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.writeMessages( - activeThread.id, - // Remove all of the messages below this - newMessages - ) - .then(() => { - sendChatMessage(editPrompt, newMessages) - }) - } + const toDeleteMessages = messages.slice(messageIdx) + const threadId = messages[0].thread_id + await Promise.all( + toDeleteMessages.map(async (message) => + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.deleteMessage(message.thread_id, message.id) + ) + ) + setMessages(threadId, newMessages) + sendChatMessage(editPrompt, false, newMessages) } const onKeyDown = async (e: React.KeyboardEvent) => { diff --git a/web/screens/Thread/ThreadCenterPanel/FileUploadPreview/index.tsx b/web/screens/Thread/ThreadCenterPanel/FileUploadPreview/index.tsx index 348e915e61..0e4872e101 100644 --- a/web/screens/Thread/ThreadCenterPanel/FileUploadPreview/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/FileUploadPreview/index.tsx @@ -15,31 +15,33 @@ const FileUploadPreview = () => { const setCurrentPrompt = useSetAtom(currentPromptAtom) const onDeleteClick = () => { - setFileUpload([]) + setFileUpload(undefined) setCurrentPrompt('') } return (
    -
    - - -
    -
    - {fileUpload[0].file.name.replaceAll(/[-._]/g, ' ')} -
    -

    - {toGibibytes(fileUpload[0].file.size)} -

    + {!!fileUpload && ( +
    + + +
    +
    + {fileUpload?.file.name.replaceAll(/[-._]/g, ' ')} +
    +

    + {toGibibytes(fileUpload?.file.size)} +

    +
    + +
    + +
    - -
    - -
    -
    + )}
    ) } diff --git a/web/screens/Thread/ThreadCenterPanel/ImageUploadPreview/index.tsx b/web/screens/Thread/ThreadCenterPanel/ImageUploadPreview/index.tsx index b43b808309..7fa9e417a3 100644 --- a/web/screens/Thread/ThreadCenterPanel/ImageUploadPreview/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ImageUploadPreview/index.tsx @@ -29,7 +29,7 @@ const ImageUploadPreview: React.FC = ({ file }) => { } const onDeleteClick = () => { - setFileUpload([]) + setFileUpload(undefined) setCurrentPrompt('') } diff --git a/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx b/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx index d6fed48043..204ec40fb9 100644 --- a/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx @@ -10,15 +10,15 @@ import { MainViewState } from '@/constants/screens' import { loadModelErrorAtom } from '@/hooks/useActiveModel' import { mainViewStateAtom } from '@/helpers/atoms/App.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' -import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' const LoadModelError = () => { const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) const loadModelError = useAtomValue(loadModelErrorAtom) const setMainState = useSetAtom(mainViewStateAtom) const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) - const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const ErrorMessage = () => { if ( @@ -33,9 +33,9 @@ const LoadModelError = () => { className="cursor-pointer font-medium text-[hsla(var(--app-link))]" onClick={() => { setMainState(MainViewState.Settings) - if (activeThread?.assistants[0]?.model.engine) { + if (activeAssistant?.model.engine) { const engine = EngineManager.instance().get( - activeThread.assistants[0].model.engine + activeAssistant.model.engine ) engine?.name && setSelectedSettingScreen(engine.name) } diff --git a/web/screens/Thread/ThreadCenterPanel/MessageToolbar/index.tsx b/web/screens/Thread/ThreadCenterPanel/MessageToolbar/index.tsx index c4a97a6b9f..a15f0ec583 100644 --- a/web/screens/Thread/ThreadCenterPanel/MessageToolbar/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/MessageToolbar/index.tsx @@ -55,15 +55,11 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { .slice(-1)[0] if (thread) { - // Should also delete error messages to clear out the error state + // TODO: Should also delete error messages to clear out the error state await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.writeMessages( - thread.id, - messages.filter( - (msg) => msg.id !== message.id && msg.status !== MessageStatus.Error - ) - ) + ?.deleteMessage(thread.id, message.id) + .catch(console.error) const updatedThread: Thread = { ...thread, @@ -74,7 +70,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { )[ messages.filter((msg) => msg.role === ChatCompletionRole.Assistant) .length - 1 - ]?.content[0]?.text.value, + ]?.content[0]?.text?.value, }, } @@ -89,10 +85,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { setEditMessage(message.id ?? '') } - const onRegenerateClick = async () => { - resendChatMessage(message) - } - if (message.status === MessageStatus.Pending) return null return ( @@ -118,11 +110,10 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { {message.id === messages[messages.length - 1]?.id && messages[messages.length - 1].status !== MessageStatus.Error && - messages[messages.length - 1].content[0]?.type !== - ContentType.Pdf && ( + !messages[messages.length - 1].attachments?.length && (
    { +const DocMessage = ({ id, name }: { id: string; name?: string }) => { const { onViewFile, onViewFileContainer } = usePath() return ( @@ -44,9 +36,9 @@ const DocMessage = ({
    {name?.replaceAll(/[-._]/g, ' ')}
    -

    + {/*

    {toGibibytes(Number(size))} -

    +

    */}
    ) diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/ImageMessage.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/ImageMessage.tsx index 117f259c0b..e83d35fbbb 100644 --- a/web/screens/Thread/ThreadCenterPanel/TextMessage/ImageMessage.tsx +++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/ImageMessage.tsx @@ -1,6 +1,5 @@ -import { memo, useMemo } from 'react' +import { memo } from 'react' -import { ThreadContent } from '@janhq/core' import { Tooltip } from '@janhq/joi' import { FolderOpenIcon } from 'lucide-react' @@ -11,21 +10,13 @@ import { openFileTitle } from '@/utils/titleUtils' import { RelativeImage } from '../TextMessage/RelativeImage' -const ImageMessage = ({ content }: { content: ThreadContent }) => { +const ImageMessage = ({ image }: { image: string }) => { const { onViewFile, onViewFileContainer } = usePath() - const annotation = useMemo( - () => content?.text?.annotations[0] ?? '', - [content] - ) - return (
    - onViewFile(annotation)} - /> + onViewFile(image)} />
    props.content[0]?.text?.value ?? '', + () => + props.content.find((e) => e.type === ContentType.Text)?.text?.value ?? '', [props.content] ) - const messageType = useMemo( - () => props.content[0]?.type ?? '', + + const image = useMemo( + () => + props.content.find((e) => e.type === ContentType.Image)?.image_url?.url, [props.content] ) + const attachedFile = useMemo(() => 'attachments' in props, [props]) + return (
    {isUser ? props.role - : (activeThread?.assistants[0].assistant_name ?? props.role)} + : (activeAssistant?.assistant_name ?? props.role)}

    - {displayDate(props.created)} + {props.created && displayDate(props.created ?? new Date())}

    @@ -111,16 +116,8 @@ const MessageContainer: React.FC< )} > <> - {messageType === ContentType.Image && ( - - )} - {messageType === ContentType.Pdf && ( - - )} + {image && } + {attachedFile && } {editMessage === props.id ? (
    diff --git a/web/screens/Thread/ThreadCenterPanel/index.tsx b/web/screens/Thread/ThreadCenterPanel/index.tsx index 01ba0aaeb5..ca04f9e595 100644 --- a/web/screens/Thread/ThreadCenterPanel/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/index.tsx @@ -22,11 +22,14 @@ import { reloadModelAtom } from '@/hooks/useSendChatMessage' import ChatBody from '@/screens/Thread/ThreadCenterPanel/ChatBody' +import { uploader } from '@/utils/file' + import ChatInput from './ChatInput' import RequestDownloadModel from './RequestDownloadModel' import { showSystemMonitorPanelAtom } from '@/helpers/atoms/App.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { @@ -55,9 +58,9 @@ const ThreadCenterPanel = () => { const setFileUpload = useSetAtom(fileUploadAtom) const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const activeThread = useAtomValue(activeThreadAtom) - - const acceptedFormat: Accept = activeThread?.assistants[0].model.settings - ?.vision_model + const activeAssistant = useAtomValue(activeAssistantAtom) + const upload = uploader() + const acceptedFormat: Accept = activeAssistant?.model.settings?.vision_model ? { 'application/pdf': ['.pdf'], 'image/jpeg': ['.jpeg'], @@ -78,14 +81,13 @@ const ThreadCenterPanel = () => { if (!experimentalFeature) return if ( e.dataTransfer.items.length === 1 && - ((activeThread?.assistants[0].tools && - activeThread?.assistants[0].tools[0]?.enabled) || - activeThread?.assistants[0].model.settings?.vision_model) + ((activeAssistant?.tools && activeAssistant?.tools[0]?.enabled) || + activeAssistant?.model.settings?.vision_model) ) { setDragOver(true) } else if ( - activeThread?.assistants[0].tools && - !activeThread?.assistants[0].tools[0]?.enabled + activeAssistant?.tools && + !activeAssistant?.tools[0]?.enabled ) { setDragRejected({ code: 'retrieval-off' }) } else { @@ -93,27 +95,36 @@ const ThreadCenterPanel = () => { } }, onDragLeave: () => setDragOver(false), - onDrop: (files, rejectFiles) => { + onDrop: async (files, rejectFiles) => { // Retrieval file drag and drop is experimental feature if (!experimentalFeature) return if ( !files || files.length !== 1 || rejectFiles.length !== 0 || - (activeThread?.assistants[0].tools && - !activeThread?.assistants[0].tools[0]?.enabled && - !activeThread?.assistants[0].model.settings?.vision_model) + (activeAssistant?.tools && + !activeAssistant?.tools[0]?.enabled && + !activeAssistant?.model.settings?.vision_model) ) return const imageType = files[0]?.type.includes('image') - setFileUpload([{ file: files[0], type: imageType ? 'image' : 'pdf' }]) + if (imageType) { + setFileUpload({ file: files[0], type: 'image' }) + } else { + upload.addFile(files[0]) + upload.upload().then((data) => { + setFileUpload({ + file: files[0], + type: imageType ? 'image' : 'pdf', + id: data?.successful?.[0]?.response?.body?.id, + name: data?.successful?.[0]?.response?.body?.filename, + }) + }) + } setDragOver(false) }, onDropRejected: (e) => { - if ( - activeThread?.assistants[0].tools && - !activeThread?.assistants[0].tools[0]?.enabled - ) { + if (activeAssistant?.tools && !activeAssistant?.tools[0]?.enabled) { setDragRejected({ code: 'retrieval-off' }) } else { setDragRejected({ code: e[0].errors[0].code }) @@ -186,8 +197,7 @@ const ThreadCenterPanel = () => {
    {isDragReject ? `Currently, we only support 1 attachment at the same time with ${ - activeThread?.assistants[0].model.settings - ?.vision_model + activeAssistant?.model.settings?.vision_model ? 'PDF, JPEG, JPG, PNG' : 'PDF' } format` @@ -195,7 +205,7 @@ const ThreadCenterPanel = () => {
    {!isDragReject && (

    - {activeThread?.assistants[0].model.settings?.vision_model + {activeAssistant?.model.settings?.vision_model ? 'PDF, JPEG, JPG, PNG' : 'PDF'}

    diff --git a/web/screens/Thread/ThreadLeftPanel/ModalEditTitleThread/index.tsx b/web/screens/Thread/ThreadLeftPanel/ModalEditTitleThread/index.tsx index ddeaedf407..21b415f49f 100644 --- a/web/screens/Thread/ThreadLeftPanel/ModalEditTitleThread/index.tsx +++ b/web/screens/Thread/ThreadLeftPanel/ModalEditTitleThread/index.tsx @@ -15,13 +15,15 @@ const ModalEditTitleThread = () => { const [modalActionThread, setModalActionThread] = useAtom( modalActionThreadAtom ) - const [title, setTitle] = useState(modalActionThread.thread?.title as string) + const [title, setTitle] = useState( + modalActionThread.thread?.metadata?.title as string + ) useLayoutEffect(() => { - if (modalActionThread.thread?.title) { - setTitle(modalActionThread.thread?.title) + if (modalActionThread.thread?.metadata?.title) { + setTitle(modalActionThread.thread?.metadata?.title as string) } - }, [modalActionThread.thread?.title]) + }, [modalActionThread.thread?.metadata]) const onUpdateTitle = useCallback( (e: React.MouseEvent) => { @@ -30,6 +32,10 @@ const ModalEditTitleThread = () => { updateThreadMetadata({ ...modalActionThread?.thread, title: title || 'New Thread', + metadata: { + ...modalActionThread?.thread.metadata, + title: title || 'New Thread', + }, }) }, [modalActionThread?.thread, title, updateThreadMetadata] diff --git a/web/screens/Thread/ThreadLeftPanel/index.tsx b/web/screens/Thread/ThreadLeftPanel/index.tsx index 61c6672fcf..46763e555c 100644 --- a/web/screens/Thread/ThreadLeftPanel/index.tsx +++ b/web/screens/Thread/ThreadLeftPanel/index.tsx @@ -20,7 +20,10 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useRecommendedModel from '@/hooks/useRecommendedModel' import useSetActiveThread from '@/hooks/useSetActiveThread' -import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' +import { + activeAssistantAtom, + assistantsAtom, +} from '@/helpers/atoms/Assistant.atom' import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom' import { @@ -34,6 +37,7 @@ import { const ThreadLeftPanel = () => { const threads = useAtomValue(threadsAtom) const activeThreadId = useAtomValue(getActiveThreadIdAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const { setActiveThread } = useSetActiveThread() const assistants = useAtomValue(assistantsAtom) const threadDataReady = useAtomValue(threadDataReadyAtom) @@ -67,6 +71,7 @@ const ThreadLeftPanel = () => { useEffect(() => { if ( threadDataReady && + activeAssistant && assistants.length > 0 && threads.length === 0 && downloadedModels.length > 0 @@ -75,7 +80,10 @@ const ThreadLeftPanel = () => { (model) => model.engine === InferenceEngine.cortex_llamacpp ) const selectedModel = model[0] || recommendedModel - requestCreateNewThread(assistants[0], selectedModel) + requestCreateNewThread( + { ...assistants[0], ...activeAssistant }, + selectedModel + ) } else if (threadDataReady && !activeThreadId) { setActiveThread(threads[0]) } @@ -88,6 +96,7 @@ const ThreadLeftPanel = () => { setActiveThread, recommendedModel, downloadedModels, + activeAssistant, ]) const onContextMenu = (event: React.MouseEvent, thread: Thread) => { @@ -138,7 +147,7 @@ const ThreadLeftPanel = () => { activeThreadId && 'font-medium' )} > - {thread.title} + {thread.title ?? thread.metadata?.title}
    { const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const { updateThreadMetadata } = useCreateNewThread() const { recommendedModel, downloadedModels } = useRecommendedModel() const componentDataAssistantSetting = getConfigurationsData( - (activeThread?.assistants[0]?.tools && - activeThread?.assistants[0]?.tools[0]?.settings) ?? - {} + (activeAssistant?.tools && activeAssistant?.tools[0]?.settings) ?? {} ) useEffect(() => { if (!activeThread) return let model = downloadedModels.find( - (model) => model.id === activeThread.assistants[0].model.id + (model) => model.id === activeAssistant?.model.id ) if (!model) { model = recommendedModel } setSelectedModel(model) - }, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) + }, [ + recommendedModel, + activeThread, + downloadedModels, + setSelectedModel, + activeAssistant?.model.id, + ]) const onRetrievalSwitchUpdate = useCallback( (enabled: boolean) => { - if (!activeThread) return + if (!activeThread || !activeAssistant) return updateThreadMetadata({ ...activeThread, assistants: [ { - ...activeThread.assistants[0], + ...activeAssistant, tools: [ { type: 'retrieval', enabled: enabled, settings: - (activeThread.assistants[0].tools && - activeThread.assistants[0].tools[0]?.settings) ?? + (activeAssistant.tools && + activeAssistant.tools[0]?.settings) ?? {}, }, ], @@ -63,25 +69,25 @@ const Tools = () => { ], }) }, - [activeThread, updateThreadMetadata] + [activeAssistant, activeThread, updateThreadMetadata] ) const onTimeWeightedRetrieverSwitchUpdate = useCallback( (enabled: boolean) => { - if (!activeThread) return + if (!activeThread || !activeAssistant) return updateThreadMetadata({ ...activeThread, assistants: [ { - ...activeThread.assistants[0], + ...activeAssistant, tools: [ { type: 'retrieval', enabled: true, useTimeWeightedRetriever: enabled, settings: - (activeThread.assistants[0].tools && - activeThread.assistants[0].tools[0]?.settings) ?? + (activeAssistant.tools && + activeAssistant.tools[0]?.settings) ?? {}, }, ], @@ -89,23 +95,54 @@ const Tools = () => { ], }) }, - [activeThread, updateThreadMetadata] + [activeAssistant, activeThread, updateThreadMetadata] ) if (!experimentalFeature) return null return ( - {activeThread?.assistants[0]?.tools && - componentDataAssistantSetting.length > 0 && ( -
    -
    + {activeAssistant?.tools && componentDataAssistantSetting.length > 0 && ( +
    +
    +
    +
    -
    +
    +
    + {activeAssistant?.tools[0].enabled && ( +
    +
    +
    + { className="ml-2 flex-shrink-0 text-[hsl(var(--text-secondary))]" /> } - content="Retrieval helps the assistant use information from - files you send to it. Once you share a file, the - assistant automatically fetches the relevant content - based on your request." - /> - -
    - onRetrievalSwitchUpdate(e.target.checked)} + content="Embedding model is crucial for understanding and + processing the input text effectively by + converting text to numerical representations. + Align the model choice with your task, evaluate + its performance, and consider factors like + resource availability. Experiment to find the best + fit for your specific use case." />
    +
    + +
    -
    - {activeThread?.assistants[0]?.tools[0].enabled && ( -
    -
    -
    - +
    +
    +
    -
    - -
    -
    -
    -
    - -
    + /> + +
    -
    - -
    +
    +
    -
    -
    - - - } - content="Time-Weighted Retriever looks at how similar +
    +
    +
    + + + } + content="Time-Weighted Retriever looks at how similar they are and how new they are. It compares documents based on their meaning like usual, but also considers when they were added to give newer ones more importance." + /> +
    + + onTimeWeightedRetrieverSwitchUpdate(e.target.checked) + } /> -
    - - onTimeWeightedRetrieverSwitchUpdate(e.target.checked) - } - /> -
    -
    - )} -
    - )} + +
    + )} +
    + )} ) } diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 952ba8eb32..939bb3fe72 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -38,6 +38,7 @@ import PromptTemplateSetting from './PromptTemplateSetting' import Tools from './Tools' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { activeThreadAtom, @@ -53,6 +54,7 @@ const ENGINE_SETTINGS = 'Engine Settings' const ThreadRightPanel = () => { const activeThread = useAtomValue(activeThreadAtom) + const activeAssistant = useAtomValue(activeAssistantAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const selectedModel = useAtomValue(selectedModelAtom) const [activeTabThreadRightPanel, setActiveTabThreadRightPanel] = useAtom( @@ -154,18 +156,18 @@ const ThreadRightPanel = () => { const onAssistantInstructionChanged = useCallback( (e: React.ChangeEvent) => { - if (activeThread) + if (activeThread && activeAssistant) updateThreadMetadata({ ...activeThread, assistants: [ { - ...activeThread.assistants[0], + ...activeAssistant, instructions: e.target.value || '', }, ], }) }, - [activeThread, updateThreadMetadata] + [activeAssistant, activeThread, updateThreadMetadata] ) const resetModel = useDebouncedCallback(() => { @@ -174,9 +176,7 @@ const ThreadRightPanel = () => { const onValueChanged = useCallback( (key: string, value: string | number | boolean | string[]) => { - if (!activeThread) { - return - } + if (!activeThread || !activeAssistant) return setEngineParamsUpdate(true) resetModel() @@ -186,32 +186,38 @@ const ThreadRightPanel = () => { }) if ( - activeThread.assistants[0].model.parameters?.max_tokens && - activeThread.assistants[0].model.settings?.ctx_len + activeAssistant.model.parameters?.max_tokens && + activeAssistant.model.settings?.ctx_len ) { if ( key === 'max_tokens' && - Number(value) > activeThread.assistants[0].model.settings.ctx_len + Number(value) > activeAssistant.model.settings.ctx_len ) { updateModelParameter(activeThread, { params: { - max_tokens: activeThread.assistants[0].model.settings.ctx_len, + max_tokens: activeAssistant.model.settings.ctx_len, }, }) } if ( key === 'ctx_len' && - Number(value) < activeThread.assistants[0].model.parameters.max_tokens + Number(value) < activeAssistant.model.parameters.max_tokens ) { updateModelParameter(activeThread, { params: { - max_tokens: activeThread.assistants[0].model.settings.ctx_len, + max_tokens: activeAssistant.model.settings.ctx_len, }, }) } } }, - [activeThread, resetModel, setEngineParamsUpdate, updateModelParameter] + [ + activeAssistant, + activeThread, + resetModel, + setEngineParamsUpdate, + updateModelParameter, + ] ) if (!activeThread) { @@ -250,7 +256,7 @@ const ThreadRightPanel = () => {