Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve models and threads caching #3744

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions web/containers/Layout/RibbonPanel/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ import { mainViewStateAtom, showLeftPanelAtom } from '@/helpers/atoms/App.atom'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'

import { isDownloadALocalModelAtom } from '@/helpers/atoms/Model.atom'
import {
reduceTransparentAtom,
selectedSettingAtom,
} from '@/helpers/atoms/Setting.atom'
import {
isDownloadALocalModelAtom,
threadsAtom,
} from '@/helpers/atoms/Thread.atom'
import { threadsAtom } from '@/helpers/atoms/Thread.atom'

export default function RibbonPanel() {
const [mainViewState, setMainViewState] = useAtom(mainViewStateAtom)
Expand Down
10 changes: 5 additions & 5 deletions web/containers/ModelDropdown/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
)

const isModelSupportRagAndTools = useCallback((model: Model) => {
return (

Check warning on line 102 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

102 line is not covered with tests
model?.engine === InferenceEngine.openai ||
isLocalEngine(model?.engine as InferenceEngine)
)
Expand All @@ -115,24 +115,24 @@
if (searchFilter === 'local') {
return isLocalEngine(e.engine)
}
if (searchFilter === 'remote') {
return !isLocalEngine(e.engine)

Check warning on line 119 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

118-119 lines are not covered with tests
}
})
.sort((a, b) => a.name.localeCompare(b.name))

Check warning on line 122 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

122 line is not covered with tests
.sort((a, b) => {
const aInDownloadedModels = downloadedModels.some(
(item) => item.id === a.id

Check warning on line 125 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

124-125 lines are not covered with tests
)
const bInDownloadedModels = downloadedModels.some(
(item) => item.id === b.id

Check warning on line 128 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

127-128 lines are not covered with tests
)
if (aInDownloadedModels && !bInDownloadedModels) {
return -1
} else if (!aInDownloadedModels && bInDownloadedModels) {
return 1

Check warning on line 133 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

130-133 lines are not covered with tests
} else {
return 0

Check warning on line 135 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

135 line is not covered with tests
}
}),
[configuredModels, searchText, searchFilter, downloadedModels]
Expand All @@ -140,7 +140,7 @@

useEffect(() => {
if (open && searchInputRef.current) {
searchInputRef.current.focus()

Check warning on line 143 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

143 line is not covered with tests
}
}, [open])

Expand All @@ -157,11 +157,11 @@

const onClickModelItem = useCallback(
async (modelId: string) => {
const model = downloadedModels.find((m) => m.id === modelId)
setSelectedModel(model)
setOpen(false)

Check warning on line 162 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

160-162 lines are not covered with tests

if (activeThread) {

Check warning on line 164 in web/containers/ModelDropdown/index.tsx

View workflow job for this annotation

GitHub Actions / coverage-check

164 line is not covered with tests
// Change assistand tools based on model support RAG
updateThreadMetadata({
...activeThread,
Expand Down Expand Up @@ -513,7 +513,7 @@
const isDownloading = downloadingModels.some(
(md) => md.id === model.id
)
const isdDownloaded = downloadedModels.some(
const isDownloaded = downloadedModels.some(
(c) => c.id === model.id
)
return (
Expand All @@ -528,7 +528,7 @@
onClick={() => {
if (!apiKey && !isLocalEngine(model.engine))
return null
if (isdDownloaded) {
if (isDownloaded) {
onClickModelItem(model.id)
}
}}
Expand All @@ -537,7 +537,7 @@
<p
className={twMerge(
'line-clamp-1',
!isdDownloaded &&
!isDownloaded &&
'text-[hsla(var(--text-secondary))]'
)}
title={model.name}
Expand All @@ -547,12 +547,12 @@
<ModelLabel metadata={model.metadata} compact />
</div>
<div className="flex items-center gap-2 text-[hsla(var(--text-tertiary))]">
{!isdDownloaded && (
{!isDownloaded && (
<span className="font-medium">
{toGibibytes(model.metadata.size)}
</span>
)}
{!isDownloading && !isdDownloaded ? (
{!isDownloading && !isDownloaded ? (
<DownloadCloudIcon
size={18}
className="cursor-pointer text-[hsla(var(--app-link))]"
Expand Down
Empty file.
78 changes: 78 additions & 0 deletions web/helpers/atoms/Extension.atom.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Extension.atom.test.ts

import { act, renderHook } from '@testing-library/react'
import * as ExtensionAtoms from './Extension.atom'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'

describe('Extension.atom.ts', () => {
afterEach(() => {
jest.clearAllMocks()
})

describe('installingExtensionAtom', () => {
it('should initialize as an empty array', () => {
const { result } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))
expect(result.current).toEqual([])
})
})

describe('setInstallingExtensionAtom', () => {
it('should add a new installing extension', () => {
const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))

act(() => {
setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
})

expect(getAtom.current).toEqual([{ extensionId: 'ext1', percentage: 0 }])
})

it('should update an existing installing extension', () => {
const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))

act(() => {
setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
setAtom.current('ext1', { extensionId: 'ext1', percentage: 50 })
})

expect(getAtom.current).toEqual([{ extensionId: 'ext1', percentage: 50 }])
})
})

describe('removeInstallingExtensionAtom', () => {
it('should remove an installing extension', () => {
const { result: setAtom } = renderHook(() => useSetAtom(ExtensionAtoms.setInstallingExtensionAtom))
const { result: removeAtom } = renderHook(() => useSetAtom(ExtensionAtoms.removeInstallingExtensionAtom))
const { result: getAtom } = renderHook(() => useAtomValue(ExtensionAtoms.installingExtensionAtom))

act(() => {
setAtom.current('ext1', { extensionId: 'ext1', percentage: 0 })
setAtom.current('ext2', { extensionId: 'ext2', percentage: 50 })
removeAtom.current('ext1')
})

expect(getAtom.current).toEqual([{ extensionId: 'ext2', percentage: 50 }])
})
})

describe('inActiveEngineProviderAtom', () => {
it('should initialize as an empty array', () => {
const { result } = renderHook(() => useAtomValue(ExtensionAtoms.inActiveEngineProviderAtom))
expect(result.current).toEqual([])
})

it('should persist value in storage', () => {
const { result } = renderHook(() => useAtom(ExtensionAtoms.inActiveEngineProviderAtom))

act(() => {
result.current[1](['provider1', 'provider2'])
})

// Simulate a re-render to check if the value persists
const { result: newResult } = renderHook(() => useAtomValue(ExtensionAtoms.inActiveEngineProviderAtom))
expect(newResult.current).toEqual(['provider1', 'provider2'])
})
})
})
7 changes: 1 addition & 6 deletions web/helpers/atoms/Model.atom.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { act, renderHook, waitFor } from '@testing-library/react'
import { act, renderHook } from '@testing-library/react'
import * as ModelAtoms from './Model.atom'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'

Expand All @@ -24,11 +24,6 @@ describe('Model.atom.ts', () => {
})
})
})
describe('activeAssistantModelAtom', () => {
it('should initialize as undefined', () => {
expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined()
})
})

describe('selectedModelAtom', () => {
it('should initialize as undefined', () => {
Expand Down
95 changes: 75 additions & 20 deletions web/helpers/atoms/Model.atom.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,59 @@
import { ImportingModel, InferenceEngine, Model, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'

/**
* Enum for the keys used to store models in the local storage.
*/
enum ModelStorageAtomKeys {
DownloadedModels = 'downloadedModels',
AvailableModels = 'availableModels',
}
//// Models Atom
/**
* Downloaded Models Atom
* This atom stores the list of models that have been downloaded.
*/
export const downloadedModelsAtom = atomWithStorage<ModelFile[]>(
ModelStorageAtomKeys.DownloadedModels,
[]
)

/**
* Configured Models Atom
* This atom stores the list of models that have been configured and available to download
*/
export const configuredModelsAtom = atomWithStorage<ModelFile[]>(
ModelStorageAtomKeys.AvailableModels,
[]
)
louis-jan marked this conversation as resolved.
Show resolved Hide resolved

export const removeDownloadedModelAtom = atom(
null,
(get, set, modelId: string) => {
const downloadedModels = get(downloadedModelsAtom)

set(
downloadedModelsAtom,
downloadedModels.filter((e) => e.id !== modelId)
)
}
)

/**
* Atom to store the selected model (from ModelDropdown)
*/
export const selectedModelAtom = atom<ModelFile | undefined>(undefined)

/**
* Atom to store the expanded engine sections (from ModelDropdown)
*/
export const showEngineListModelAtom = atom<string[]>([InferenceEngine.nitro])

/// End Models Atom
/// Model Download Atom

export const stateModel = atom({ state: 'start', loading: false, model: '' })
export const activeAssistantModelAtom = atom<Model | undefined>(undefined)

/**
* Stores the list of models which are being downloaded.
Expand Down Expand Up @@ -30,28 +81,20 @@ export const removeDownloadingModelAtom = atom(
}
)

export const downloadedModelsAtom = atom<ModelFile[]>([])

export const removeDownloadedModelAtom = atom(
null,
(get, set, modelId: string) => {
const downloadedModels = get(downloadedModelsAtom)

set(
downloadedModelsAtom,
downloadedModels.filter((e) => e.id !== modelId)
)
}
)

export const configuredModelsAtom = atom<ModelFile[]>([])

export const defaultModelAtom = atom<Model | undefined>(undefined)
/// End Model Download Atom
/// Model Import Atom

/// TODO: move this part to another atom
// store the paths of the models that are being imported
export const importingModelsAtom = atom<ImportingModel[]>([])

// DEPRECATED: Remove when moving to cortex.cpp
// Default model template when importing
export const defaultModelAtom = atom<Model | undefined>(undefined)

/**
* Importing progress Atom
*/
export const updateImportingModelProgressAtom = atom(
null,
(get, set, importId: string, percentage: number) => {
Expand All @@ -69,6 +112,9 @@ export const updateImportingModelProgressAtom = atom(
}
)

/**
* Importing error Atom
*/
export const setImportingModelErrorAtom = atom(
null,
(get, set, importId: string, error: string) => {
Expand All @@ -87,6 +133,9 @@ export const setImportingModelErrorAtom = atom(
}
)

/**
* Importing success Atom
*/
export const setImportingModelSuccessAtom = atom(
null,
(get, set, importId: string, modelId: string) => {
Expand All @@ -105,6 +154,9 @@ export const setImportingModelSuccessAtom = atom(
}
)

/**
* Update importing model metadata Atom
*/
export const updateImportingModelAtom = atom(
null,
(
Expand All @@ -131,6 +183,9 @@ export const updateImportingModelAtom = atom(
}
)

export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
/// End Model Import Atom

export const showEngineListModelAtom = atom<string[]>([InferenceEngine.nitro])
/// ModelDropdown States Atom
export const isDownloadALocalModelAtom = atom<boolean>(false)
export const isAnyRemoteModelConfiguredAtom = atom<boolean>(false)
/// End ModelDropdown States Atom
Loading
Loading