From b28cac70830c18ad9c968e04dd5092551001f892 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 23 Dec 2024 21:04:37 +0700 Subject: [PATCH] fix: render performance while generating messages (#4328) --- web/containers/Providers/ModelHandler.tsx | 37 +++++++++++++++---- web/helpers/atoms/ChatMessage.atom.ts | 7 ++++ web/hooks/useSetActiveThread.ts | 21 ++++++++--- .../ThreadCenterPanel/ChatItem/index.tsx | 22 ++++++++++- 4 files changed, 73 insertions(+), 14 deletions(-) diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index d838df3246..8c565bab11 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -18,7 +18,7 @@ import { extractInferenceParams, ModelExtension, } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { ulid } from 'ulidx' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' @@ -32,6 +32,7 @@ import { updateMessageAtom, tokenSpeedAtom, deleteMessageAtom, + subscribedGeneratingMessageAtom, } from '@/helpers/atoms/ChatMessage.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { @@ -40,6 +41,7 @@ import { isGeneratingResponseAtom, updateThreadAtom, getActiveThreadModelParamsAtom, + activeThreadAtom, } from '@/helpers/atoms/Thread.atom' const maxWordForThreadTitle = 10 @@ -54,6 +56,10 @@ export default function ModelHandler() { const activeModel = useAtomValue(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom) const setStateModel = useSetAtom(stateModelAtom) + const [subscribedGeneratingMessage, setSubscribedGeneratingMessage] = useAtom( + subscribedGeneratingMessageAtom + ) + const activeThread = useAtomValue(activeThreadAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const threads = useAtomValue(threadsAtom) @@ -62,11 +68,17 @@ export default function ModelHandler() { const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const updateThread = useSetAtom(updateThreadAtom) const messagesRef = useRef(messages) + const messageGenerationSubscriber = useRef(subscribedGeneratingMessage) const activeModelRef = useRef(activeModel) + const activeThreadRef = useRef(activeThread) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParamsRef = useRef(activeModelParams) const setTokenSpeed = useSetAtom(tokenSpeedAtom) + useEffect(() => { + activeThreadRef.current = activeThread + }, [activeThread]) + useEffect(() => { threadsRef.current = threads }, [threads]) @@ -87,6 +99,10 @@ export default function ModelHandler() { activeModelParamsRef.current = activeModelParams }, [activeModelParams]) + useEffect(() => { + messageGenerationSubscriber.current = subscribedGeneratingMessage + }, [subscribedGeneratingMessage]) + const onNewMessageResponse = useCallback( async (message: ThreadMessage) => { if (message.type === MessageRequestType.Thread) { @@ -179,12 +195,19 @@ export default function ModelHandler() { const updateThreadMessage = useCallback( (message: ThreadMessage) => { - updateMessage( - message.id, - message.thread_id, - message.content, - message.status - ) + if ( + messageGenerationSubscriber.current && + message.thread_id === activeThreadRef.current?.id && + !messageGenerationSubscriber.current!.thread_id + ) { + updateMessage( + message.id, + message.thread_id, + message.content, + message.status + ) + } + if (message.status === MessageStatus.Pending) { if (message.content.length) { setIsGeneratingResponse(false) diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index 7034396652..1847aa4224 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -35,6 +35,13 @@ export const chatMessages = atom( } ) +/** + * Store subscribed generating message thread + */ +export const subscribedGeneratingMessageAtom = atom<{ + thread_id?: string +}>({}) + /** * Stores the status of the messages load for each thread */ diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts index 8c7ed5361b..62baa4a870 100644 --- a/web/hooks/useSetActiveThread.ts +++ b/web/hooks/useSetActiveThread.ts @@ -1,11 +1,15 @@ import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core' -import { useSetAtom } from 'jotai' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { extensionManager } from '@/extension' import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' -import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { + setConvoMessagesAtom, + subscribedGeneratingMessageAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { + getActiveThreadIdAtom, setActiveThreadIdAtom, setThreadModelParamsAtom, } from '@/helpers/atoms/Thread.atom' @@ -13,14 +17,18 @@ import { ModelParams } from '@/types/model' export default function useSetActiveThread() { const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) - const setThreadMessage = useSetAtom(setConvoMessagesAtom) + const activeThreadId = useAtomValue(getActiveThreadIdAtom) + const setThreadMessages = useSetAtom(setConvoMessagesAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const setActiveAssistant = useSetAtom(activeAssistantAtom) + const [messageSubscriber, setMessageSubscriber] = useAtom( + subscribedGeneratingMessageAtom + ) const setActiveThread = async (thread: Thread) => { - if (!thread?.id) return + if (!thread?.id || activeThreadId === thread.id) return - setActiveThreadId(thread?.id) + setActiveThreadId(thread.id) try { const assistantInfo = await getThreadAssistant(thread.id) @@ -32,7 +40,8 @@ export default function useSetActiveThread() { ...assistantInfo?.model?.settings, } setThreadModelParams(thread?.id, modelParams) - setThreadMessage(thread.id, messages) + setThreadMessages(thread.id, messages) + if (messageSubscriber.thread_id !== thread.id) setMessageSubscriber({}) } catch (e) { console.error(e) } diff --git a/web/screens/Thread/ThreadCenterPanel/ChatItem/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatItem/index.tsx index 57876d0448..b1c7386331 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatItem/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatItem/index.tsx @@ -1,4 +1,4 @@ -import React, { forwardRef, useEffect, useState } from 'react' +import React, { forwardRef, useEffect, useRef, useState } from 'react' import { events, @@ -8,10 +8,14 @@ import { ThreadMessage, } from '@janhq/core' +import { useAtom } from 'jotai' + import ErrorMessage from '@/containers/ErrorMessage' import MessageContainer from '../TextMessage' +import { subscribedGeneratingMessageAtom } from '@/helpers/atoms/ChatMessage.atom' + type Ref = HTMLDivElement type Props = { @@ -22,9 +26,13 @@ type Props = { const ChatItem = forwardRef((message, ref) => { const [content, setContent] = useState(message.content) const [status, setStatus] = useState(message.status) + const [subscribedGeneratingMessage, setSubscribedGeneratingMessage] = useAtom( + subscribedGeneratingMessageAtom + ) const [errorMessage, setErrorMessage] = useState( message.isCurrentMessage && !!message?.metadata?.error ? message : undefined ) + const subscribedGeneratingMessageRef = useRef(subscribedGeneratingMessage) function onMessageUpdate(data: ThreadMessage) { if (data.id === message.id) { @@ -32,9 +40,21 @@ const ChatItem = forwardRef((message, ref) => { if (data.status !== status) setStatus(data.status) if (data.status === MessageStatus.Error && message.isCurrentMessage) setErrorMessage(data) + + // Update subscriber if the message is generating + if ( + subscribedGeneratingMessageRef.current?.thread_id !== message.thread_id + ) + setSubscribedGeneratingMessage({ + thread_id: message.thread_id, + }) } } + useEffect(() => { + subscribedGeneratingMessageRef.current = subscribedGeneratingMessage + }, [subscribedGeneratingMessage]) + useEffect(() => { if (!message.isCurrentMessage && errorMessage) setErrorMessage(undefined) }, [message, errorMessage])