Skip to content

Commit

Permalink
chore: Refactor thread state management (#4350)
Browse files Browse the repository at this point in the history
* chore: Refactor thread state management

• Replace isGeneratingResponseAtom with isBlockingSendAtom
• Update dependencies in ChatBody, ChatInput, and MessageToolbar components
• Remove unused code and variables

* chore: clean states
  • Loading branch information
louis-jan authored Dec 29, 2024
1 parent c0f3fb5 commit 3af34c0
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 52 deletions.
2 changes: 1 addition & 1 deletion web/containers/Providers/ModelHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export default function ModelHandler() {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const [subscribedGeneratingMessage, setSubscribedGeneratingMessage] = useAtom(
const subscribedGeneratingMessage = useAtomValue(
subscribedGeneratingMessageAtom
)
const activeThread = useAtomValue(activeThreadAtom)
Expand Down
18 changes: 17 additions & 1 deletion web/helpers/atoms/Thread.atom.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Thread, ThreadContent, ThreadState } from '@janhq/core'

import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
import { atomWithStorage, selectAtom } from 'jotai/utils'

import { ModelParams } from '@/types/model'

Expand Down Expand Up @@ -34,6 +34,22 @@ export const threadStatesAtom = atomWithStorage<Record<string, ThreadState>>(
{}
)

/**
* Returns whether there is a thread waiting for response or not
*/
const isWaitingForResponseAtom = selectAtom(threadStatesAtom, (threads) =>
Object.values(threads).some((t) => t.waitingForResponse)
)

/**
* Combine 2 states to reduce rerender
* 1. isWaitingForResponse
* 2. isGenerating
*/
export const isBlockingSendAtom = atom(
(get) => get(isWaitingForResponseAtom) || get(isGeneratingResponseAtom)
)

/**
* Stores all threads for the current user
*/
Expand Down
44 changes: 7 additions & 37 deletions web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import EmptyThread from './EmptyThread'
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import {
activeThreadAtom,
isGeneratingResponseAtom,
threadStatesAtom,
isBlockingSendAtom,
} from '@/helpers/atoms/Thread.atom'

const ChatConfigurator = memo(() => {
Expand Down Expand Up @@ -65,12 +64,7 @@ const ChatBody = memo(
const prevScrollTop = useRef(0)
const isUserManuallyScrollingUp = useRef(false)
const currentThread = useAtomValue(activeThreadAtom)
const threadStates = useAtomValue(threadStatesAtom)
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)

const isStreamingResponse = Object.values(threadStates).some(
(threadState) => threadState.waitingForResponse
)
const isBlockingSend = useAtomValue(isBlockingSendAtom)

const count = useMemo(
() => (messages?.length ?? 0) + (loadModelError ? 1 : 0),
Expand All @@ -87,35 +81,11 @@ const ChatBody = memo(

useEffect(() => {
isUserManuallyScrollingUp.current = false
if (parentRef.current) {
parentRef.current.scrollTo({ top: parentRef.current.scrollHeight })
virtualizer.scrollToIndex(count - 1)
}
}, [count, virtualizer])

useEffect(() => {
isUserManuallyScrollingUp.current = false
if (parentRef.current && isGeneratingResponse) {
parentRef.current.scrollTo({ top: parentRef.current.scrollHeight })
virtualizer.scrollToIndex(count - 1)
}
}, [count, virtualizer, isGeneratingResponse])

useEffect(() => {
isUserManuallyScrollingUp.current = false
if (parentRef.current && isGeneratingResponse) {
parentRef.current.scrollTo({ top: parentRef.current.scrollHeight })
virtualizer.scrollToIndex(count - 1)
}
}, [count, virtualizer, isGeneratingResponse, currentThread?.id])

useEffect(() => {
isUserManuallyScrollingUp.current = false
if (parentRef.current) {
if (parentRef.current && isBlockingSend) {
parentRef.current.scrollTo({ top: parentRef.current.scrollHeight })
virtualizer.scrollToIndex(count - 1)
}
}, [count, currentThread?.id, virtualizer])
}, [count, virtualizer, isBlockingSend, currentThread?.id])

const items = virtualizer.getVirtualItems()

Expand All @@ -124,7 +94,7 @@ const ChatBody = memo(
_,
instance
) => {
if (isUserManuallyScrollingUp.current === true && isStreamingResponse)
if (isUserManuallyScrollingUp.current === true && isBlockingSend)
return false
return (
// item.start < (instance.scrollOffset ?? 0) &&
Expand All @@ -136,7 +106,7 @@ const ChatBody = memo(
(event: React.UIEvent<HTMLElement>) => {
const currentScrollTop = event.currentTarget.scrollTop

if (prevScrollTop.current > currentScrollTop && isStreamingResponse) {
if (prevScrollTop.current > currentScrollTop && isBlockingSend) {
isUserManuallyScrollingUp.current = true
} else {
const currentScrollTop = event.currentTarget.scrollTop
Expand All @@ -154,7 +124,7 @@ const ChatBody = memo(
}
prevScrollTop.current = currentScrollTop
},
[isStreamingResponse]
[isBlockingSend]
)

return (
Expand Down
14 changes: 3 additions & 11 deletions web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,19 @@ import RichTextEditor from './RichTextEditor'
import { showRightPanelAtom } from '@/helpers/atoms/App.atom'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { spellCheckAtom } from '@/helpers/atoms/Setting.atom'
import {
activeSettingInputBoxAtom,
activeThreadAtom,
getActiveThreadIdAtom,
isGeneratingResponseAtom,
threadStatesAtom,
isBlockingSendAtom,
} from '@/helpers/atoms/Thread.atom'
import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom'

const ChatInput = () => {
const activeThread = useAtomValue(activeThreadAtom)
const { stateModel } = useActiveModel()
const messages = useAtomValue(getCurrentChatMessagesAtom)
const spellCheck = useAtomValue(spellCheckAtom)

const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom)
Expand All @@ -67,8 +64,7 @@ const ChatInput = () => {
const fileInputRef = useRef<HTMLInputElement>(null)
const imageInputRef = useRef<HTMLInputElement>(null)
const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
const threadStates = useAtomValue(threadStatesAtom)
const isBlockingSend = useAtomValue(isBlockingSendAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const { stopInference } = useActiveModel()

Expand All @@ -77,10 +73,6 @@ const ChatInput = () => {
activeTabThreadRightPanelAtom
)

const isStreamingResponse = Object.values(threadStates).some(
(threadState) => threadState.waitingForResponse
)

const refAttachmentMenus = useClickOutside(() => setShowAttacmentMenus(false))
const [showRightPanel, setShowRightPanel] = useAtom(showRightPanelAtom)

Expand Down Expand Up @@ -302,7 +294,7 @@ const ChatInput = () => {
</div>
)}

{!isGeneratingResponse && !isStreamingResponse ? (
{!isBlockingSend ? (
<>
{currentPrompt.length !== 0 && (
<Button
Expand Down
8 changes: 6 additions & 2 deletions web/screens/Thread/ThreadCenterPanel/MessageToolbar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
} from '@/helpers/atoms/ChatMessage.atom'
import {
activeThreadAtom,
isBlockingSendAtom,
updateThreadAtom,
updateThreadStateLastMessageAtom,
} from '@/helpers/atoms/Thread.atom'
Expand All @@ -43,6 +44,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
const clipboard = useClipboard({ timeout: 1000 })
const updateThreadLastMessage = useSetAtom(updateThreadStateLastMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const isBlockingSend = useAtomValue(isBlockingSendAtom)

const onDeleteClick = useCallback(async () => {
deleteMessage(message.id ?? '')
Expand Down Expand Up @@ -91,7 +93,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
<div className="flex flex-row items-center">
<div className="flex gap-1 bg-[hsla(var(--app-bg))]">
{message.role === ChatCompletionRole.User &&
message.content[0]?.type === ContentType.Text && (
message.content[0]?.type === ContentType.Text &&
!isBlockingSend && (
<div
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
onClick={onEditClick}
Expand All @@ -110,7 +113,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {

{message.id === messages[messages.length - 1]?.id &&
!messages[messages.length - 1]?.metadata?.error &&
!messages[messages.length - 1].attachments?.length && (
!messages[messages.length - 1].attachments?.length &&
!isBlockingSend && (
<div
className="cursor-pointer rounded-lg border border-[hsla(var(--app-border))] p-2"
onClick={resendChatMessage}
Expand Down

0 comments on commit 3af34c0

Please sign in to comment.