From 541a37b8f1f2ab63025d973d505574874d3eb851 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 2 Dec 2024 14:04:13 +0700 Subject: [PATCH] fix: 4171 - Model loading gets stuck on stop --- .../bin/version.txt | 2 +- .../resources/default_settings.json | 58 +++++++++++----- .../rollup.config.ts | 2 +- .../src/@types/global.d.ts | 2 +- .../inference-cortex-extension/src/index.ts | 67 +++++++++++++++++++ web/containers/ErrorMessage/index.tsx | 2 +- .../LoadModelError/index.tsx | 2 +- .../ThreadCenterPanel/TextMessage/index.tsx | 2 +- 8 files changed, 115 insertions(+), 22 deletions(-) diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt index 40ac6bb0ee..070e5d6e98 100644 --- a/extensions/inference-cortex-extension/bin/version.txt +++ b/extensions/inference-cortex-extension/bin/version.txt @@ -1 +1 @@ -1.0.4-rc4 \ No newline at end of file +1.0.4-rc5 \ No newline at end of file diff --git a/extensions/inference-cortex-extension/resources/default_settings.json b/extensions/inference-cortex-extension/resources/default_settings.json index 09d014a12c..1e5ec8db68 100644 --- a/extensions/inference-cortex-extension/resources/default_settings.json +++ b/extensions/inference-cortex-extension/resources/default_settings.json @@ -1,33 +1,59 @@ [ { - "key": "test", - "title": "Test", - "description": "Test", + "key": "cont_batching", + "title": "Continuous batching", + "description": "Whether to enable continuous batching of requests", + "controllerType": "checkbox", + "controllerProps": { + "value": true + } + }, + { + "key": "n_parallel", + "title": "Parallel operations", + "description": "The number of parallel operations", "controllerType": "input", "controllerProps": { - "placeholder": "Test", - "value": "" + "value": "4", + "placeholder": "4" } }, { - "key": "embedding", - "title": "Embedding", - "description": "Whether to enable embedding.", + "key": "flash_attn", + "title": "Flash Attention enabled", + "description": "To 
enable Flash Attention, default is true", "controllerType": "checkbox", "controllerProps": { "value": true } }, + { - "key": "ctx_len", - "title": "Context Length", - "description": "The context length for model operations varies; the maximum depends on the specific model used.", - "controllerType": "slider", + "key": "caching_enabled", + "title": "Caching enabled", + "description": "To enable prompt caching or not", + "controllerType": "checkbox", "controllerProps": { - "min": 0, - "max": 4096, - "step": 128, - "value": 2048 + "value": true + } + }, + { + "key": "cache_type", + "title": "KV Cache Type", + "description": "KV cache type: f16, q8_0, q4_0, default is f16 (change this could break the model).", + "controllerType": "input", + "controllerProps": { + "placeholder": "f16", + "value": "f16" + } + }, + { + "key": "use_mmap", + "title": "To enable mmap", + "description": "To enable mmap, default is true", + "controllerType": "checkbox", + "controllerProps": { + "value": true } } ] diff --git a/extensions/inference-cortex-extension/rollup.config.ts b/extensions/inference-cortex-extension/rollup.config.ts index 2843868697..8fa61e91d8 100644 --- a/extensions/inference-cortex-extension/rollup.config.ts +++ b/extensions/inference-cortex-extension/rollup.config.ts @@ -117,7 +117,7 @@ export default [ qwen2572bJson, ]), NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), - DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson), + SETTINGS: JSON.stringify(defaultSettingJson), CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'), CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'), CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.40'), diff --git a/extensions/inference-cortex-extension/src/@types/global.d.ts b/extensions/inference-cortex-extension/src/@types/global.d.ts index 381a80f5e6..139d836a56 100644 --- a/extensions/inference-cortex-extension/src/@types/global.d.ts +++ b/extensions/inference-cortex-extension/src/@types/global.d.ts @@ -2,7 +2,7 @@ 
declare const NODE: string declare const CORTEX_API_URL: string declare const CORTEX_SOCKET_URL: string declare const CORTEX_ENGINE_VERSION: string -declare const DEFAULT_SETTINGS: Array +declare const SETTINGS: Array declare const MODELS: Array /** diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 15f7a02940..4e9ffd55a9 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -36,6 +36,15 @@ enum DownloadTypes { DownloadStarted = 'onFileDownloadStarted', } +export enum Settings { + n_parallel = 'n_parallel', + cont_batching = 'cont_batching', + caching_enabled = 'caching_enabled', + flash_attn = 'flash_attn', + cache_type = 'cache_type', + use_mmap = 'use_mmap', +} + /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. @@ -50,6 +59,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { shouldReconnect = true + /** Default Engine model load settings */ + n_parallel: number = 4 + cont_batching: boolean = true + caching_enabled: boolean = true + flash_attn: boolean = true + use_mmap: boolean = true + cache_type: string = 'f16' + /** * The URL for making inference requests. */ @@ -60,6 +77,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { */ socket?: WebSocket = undefined + abortControllers = new Map() + /** * Subscribes to events emitted by the @janhq/core package. */ @@ -70,6 +89,23 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { super.onLoad() + // Register Settings + this.registerSettings(SETTINGS) + + this.n_parallel = + Number(await this.getSetting(Settings.n_parallel, '4')) ?? 
4 + this.cont_batching = await this.getSetting( + Settings.cont_batching, + true + ) + this.caching_enabled = await this.getSetting( + Settings.caching_enabled, + true + ) + this.flash_attn = await this.getSetting(Settings.flash_attn, true) + this.use_mmap = await this.getSetting(Settings.use_mmap, true) + this.cache_type = await this.getSetting(Settings.cache_type, 'f16') + this.queue.add(() => this.clean()) // Run the process watchdog @@ -101,6 +137,22 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { super.onUnload() } + onSettingUpdate(key: string, value: T): void { + if (key === Settings.n_parallel && typeof value === 'string') { + this.n_parallel = Number(value) || 1 + } else if (key === Settings.cont_batching && typeof value === 'boolean') { + this.cont_batching = value as boolean + } else if (key === Settings.caching_enabled && typeof value === 'boolean') { + this.caching_enabled = value as boolean + } else if (key === Settings.flash_attn && typeof value === 'boolean') { + this.flash_attn = value as boolean + } else if (key === Settings.cache_type && typeof value === 'string') { + this.cache_type = value as string + } else if (key === Settings.use_mmap && typeof value === 'boolean') { + this.use_mmap = value as boolean + } + } + override async loadModel( model: Model & { file_path?: string } ): Promise { @@ -134,6 +186,10 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { const { mmproj, ...settings } = model.settings model.settings = settings } + const controller = new AbortController() + const { signal } = controller + + this.abortControllers.set(model.id, controller) return await this.queue.add(() => ky @@ -145,13 +201,21 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { model.engine === InferenceEngine.nitro // Legacy model cache ? 
InferenceEngine.cortex_llamacpp : model.engine, + cont_batching: this.cont_batching, + n_parallel: this.n_parallel, + caching_enabled: this.caching_enabled, + flash_attn: this.flash_attn, + cache_type: this.cache_type, + use_mmap: this.use_mmap, }, timeout: false, + signal, }) .json() .catch(async (e) => { throw (await e.response?.json()) ?? e }) + .finally(() => this.abortControllers.delete(model.id)) .then() ) } @@ -162,6 +226,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { json: { model: model.id }, }) .json() + .finally(() => { + this.abortControllers.get(model.id)?.abort() + }) .then() } diff --git a/web/containers/ErrorMessage/index.tsx b/web/containers/ErrorMessage/index.tsx index 532f02259c..4c97da14be 100644 --- a/web/containers/ErrorMessage/index.tsx +++ b/web/containers/ErrorMessage/index.tsx @@ -52,7 +52,7 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => { ) default: return ( -

+

{message.content[0]?.text?.value && ( )} diff --git a/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx b/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx index f17bf43c4c..569e93d62b 100644 --- a/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/LoadModelError/index.tsx @@ -49,7 +49,7 @@ const LoadModelError = () => { } else { return (

- {loadModelError &&

{loadModelError}

} + {loadModelError &&

{loadModelError}

}

{`Something's wrong.`} Access  = (props) => { ) : (