From f2688771fff078c2e29a775784af645bf92cf613 Mon Sep 17 00:00:00 2001
From: hiro
Date: Tue, 12 Dec 2023 08:54:43 +0700
Subject: [PATCH 1/2] feat: Add Triton TRT-LLM engine for remote models

---
 core/src/types/index.ts    |   2 +-
 .../README.md              |  78 ++++++
 .../package.json           |  41 +++
 .../src/@types/global.d.ts |   7 +
 .../src/helpers/sse.ts     |  63 +++++
 .../src/index.ts           | 235 ++++++++++++++++++
 .../tsconfig.json          |  15 ++
 .../webpack.config.js      |  38 +++
 8 files changed, 478 insertions(+), 1 deletion(-)
 create mode 100644 extensions/inference-triton-trtllm-extension/README.md
 create mode 100644 extensions/inference-triton-trtllm-extension/package.json
 create mode 100644 extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
 create mode 100644 extensions/inference-triton-trtllm-extension/src/helpers/sse.ts
 create mode 100644 extensions/inference-triton-trtllm-extension/src/index.ts
 create mode 100644 extensions/inference-triton-trtllm-extension/tsconfig.json
 create mode 100644 extensions/inference-triton-trtllm-extension/webpack.config.js

diff --git a/core/src/types/index.ts b/core/src/types/index.ts
index 7314a4ae3f..2e19f61d81 100644
--- a/core/src/types/index.ts
+++ b/core/src/types/index.ts
@@ -174,7 +174,7 @@ export type ThreadState = {
 enum InferenceEngine {
   nitro = "nitro",
   openai = "openai",
-  nvidia_triton = "nvidia_triton",
+  triton_trtllm = "triton_trtllm",
   hf_endpoint = "hf_endpoint",
 }
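The enum rename above is load-bearing: the extension introduced below matches models against the string value of this enum member, so any stale `nvidia_triton` literal would silently stop matching. A minimal sketch of that guard, assuming only the `engine` field on `Model` that the extension itself reads later in this patch:

```typescript
import { Model } from "@janhq/core";

// Sketch: the engine check used throughout the new extension. The literal
// must match the renamed enum value ("triton_trtllm") exactly; the old
// "nvidia_triton" value no longer exists after this patch.
function isTritonTrtLlmModel(model: Model): boolean {
  return model.engine === "triton_trtllm";
}
```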
diff --git a/extensions/inference-triton-trtllm-extension/README.md b/extensions/inference-triton-trtllm-extension/README.md
new file mode 100644
index 0000000000..455783efb1
--- /dev/null
+++ b/extensions/inference-triton-trtllm-extension/README.md
@@ -0,0 +1,78 @@
+# Jan Inference Plugin
+
+Created using the Jan app example.
+
+# Create a Jan Plugin using TypeScript
+
+Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
+
+## Create Your Own Plugin
+
+To create your own plugin, you can use this repository as a template! Just follow the instructions below:
+
+1. Click the **Use this template** button at the top of the repository
+2. Select **Create a new repository**
+3. Select an owner and name for your new repository
+4. Click **Create repository**
+5. Clone your new repository
+
+## Initial Setup
+
+After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
+
+> [!NOTE]
+>
+> You'll need to have a reasonably modern version of
+> [Node.js](https://nodejs.org) handy. If you are using a version manager like
+> [`nodenv`](https://github.com/nodenv/nodenv) or
+> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
+> root of your repository to install the version specified in
+> [`package.json`](./package.json). Otherwise, 20.x or later should work!
+
+1. :hammer_and_wrench: Install the dependencies
+
+   ```bash
+   npm install
+   ```
+
+1. :building_construction: Package the TypeScript for distribution
+
+   ```bash
+   npm run bundle
+   ```
+
+1. :white_check_mark: Check your artifact
+
+   A `.tgz` file will now be present in your plugin directory.
+
+## Update the Plugin Metadata
+
+The [`package.json`](package.json) file defines metadata about your plugin, such as
+plugin name, main entry, description and version.
+
+When you copy this repository, update `package.json` with the name and description of your plugin.
+
+## Update the Plugin Code
+
+The [`src/`](./src/) directory is the heart of your plugin! This contains the
+source code that will be run when your plugin extension functions are invoked. You can replace the
+contents of this directory with your own code.
+
+There are a few things to keep in mind when writing your plugin code:
+
+- Most Jan Plugin Extension functions are processed asynchronously.
+  In `index.ts`, you will see that the extension function will return a `Promise`.
+
+  ```typescript
+  import { core } from "@janhq/core";
+
+  function onStart(): Promise<any> {
+    return core.invokePluginFunc(MODULE_PATH, "run", 0);
+  }
+  ```
+
+  For more information about the Jan Plugin Core module, see the
+  [documentation](https://github.com/janhq/jan/blob/main/core/README.md).
+
+So, what are you waiting for? Go ahead and start customizing your plugin!
+
diff --git a/extensions/inference-triton-trtllm-extension/package.json b/extensions/inference-triton-trtllm-extension/package.json
new file mode 100644
index 0000000000..862359fe61
--- /dev/null
+++ b/extensions/inference-triton-trtllm-extension/package.json
@@ -0,0 +1,41 @@
+{
+  "name": "@janhq/inference-triton-trt-llm-extension",
+  "version": "1.0.0",
+  "description": "Inference Engine for NVIDIA Triton with TensorRT-LLM Extension integration on Jan extension framework",
+  "main": "dist/index.js",
+  "module": "dist/module.js",
+  "author": "Jan ",
+  "license": "AGPL-3.0",
+  "scripts": {
+    "build": "tsc -b . && webpack --config webpack.config.js",
+    "build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
+  },
+  "exports": {
+    ".": "./dist/index.js",
+    "./main": "./dist/module.js"
+  },
+  "devDependencies": {
+    "cpx": "^1.5.0",
+    "rimraf": "^3.0.2",
+    "webpack": "^5.88.2",
+    "webpack-cli": "^5.1.4"
+  },
+  "dependencies": {
+    "@janhq/core": "file:../../core",
+    "fetch-retry": "^5.0.6",
+    "path-browserify": "^1.0.1",
+    "ts-loader": "^9.5.0",
+    "ulid": "^2.3.0"
+  },
+  "engines": {
+    "node": ">=18.0.0"
+  },
+  "files": [
+    "dist/*",
+    "package.json",
+    "README.md"
+  ],
+  "bundleDependencies": [
+    "fetch-retry"
+  ]
+}
diff --git a/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts b/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
new file mode 100644
index 0000000000..141284ad68
--- /dev/null
+++ b/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
@@ -0,0 +1,7 @@
+import { Model } from "@janhq/core";
+
+declare const MODULE: string;
+
+declare interface EngineSettings {
+  base_url?: string;
+}
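`EngineSettings` is the entire configuration surface of this extension: a single optional `base_url` pointing at the Triton server. The extension persists it to `engines/triton_trtllm.json` with an empty default (see `writeDefaultEngineSettings` in `src/index.ts` below). A hedged sketch of a filled-in settings object; the URL is a placeholder and not a value from this patch, though port 8000 is Triton's usual HTTP port:

```typescript
import { EngineSettings } from "./@types/global";

// Placeholder endpoint; use whatever your Triton deployment exposes.
const settings: EngineSettings = {
  base_url: "http://localhost:8000",
};

// This mirrors the JSON the extension writes to engines/triton_trtllm.json.
console.log(JSON.stringify(settings, null, 2));
```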
+ */ +export function requestInference( + recentMessages: any[], + engine: EngineSettings, + model: Model, + controller?: AbortController +): Observable { + return new Observable((subscriber) => { + const text_input = recentMessages.map((message) => message.text).join("\n"); + const requestBody = JSON.stringify({ + text_input: text_input, + max_tokens: 4096, + temperature: 0, + bad_words: "", + stop_words: "[DONE]", + stream: true + }); + fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + "Access-Control-Allow-Origin": "*", + }, + body: requestBody, + signal: controller?.signal, + }) + .then(async (response) => { + const stream = response.body; + const decoder = new TextDecoder("utf-8"); + const reader = stream?.getReader(); + let content = ""; + + while (true && reader) { + const { done, value } = await reader.read(); + if (done) { + break; + } + const text = decoder.decode(value); + const lines = text.trim().split("\n"); + for (const line of lines) { + if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { + const data = JSON.parse(line.replace("data: ", "")); + content += data.choices[0]?.delta?.content ?? ""; + subscriber.next(content); + } + } + } + subscriber.complete(); + }) + .catch((err) => subscriber.error(err)); + }); +} diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts new file mode 100644 index 0000000000..9e8d64bb25 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -0,0 +1,235 @@ +/** + * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + * @version 1.0.0 + * @module inference-nvidia-triton-trt-llm-extension/src/index + */ + +import { + ChatCompletionRole, + ContentType, + EventName, + MessageRequest, + MessageStatus, + ModelSettingParams, + ExtensionType, + ThreadContent, + ThreadMessage, + events, + fs, + Model, +} from "@janhq/core"; +import { InferenceExtension } from "@janhq/core"; +import { requestInference } from "./helpers/sse"; +import { ulid } from "ulid"; +import { join } from "path"; +import { EngineSettings } from "./@types/global"; + +/** + * A class that implements the InferenceExtension interface from the @janhq/core package. + * The class provides methods for initializing and stopping a model, and for making inference requests. + * It also subscribes to events emitted by the @janhq/core package and handles new message requests. + */ +export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension { + private static readonly _homeDir = 'engines' + private static readonly _engineMetadataFileName = 'triton_trtllm.json' + + static _currentModel: Model; + + static _engineSettings: EngineSettings = { + "base_url": "", + }; + + controller = new AbortController(); + isCancelled = false; + + /** + * Returns the type of the extension. + * @returns {ExtensionType} The type of the extension. + */ + // TODO: To fix + type(): ExtensionType { + return undefined; + } + /** + * Subscribes to events emitted by the @janhq/core package. 
+ */ + onLoad(): void { + fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir) + JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() + + // Events subscription + events.on(EventName.OnMessageSent, (data) => + JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this) + ); + + events.on(EventName.OnModelInit, (model: Model) => { + JanInferenceTritonTrtLLMExtension.handleModelInit(model); + }); + + events.on(EventName.OnModelStop, (model: Model) => { + JanInferenceTritonTrtLLMExtension.handleModelStop(model); + }); + } + + /** + * Stops the model inference. + */ + onUnload(): void {} + + /** + * Initializes the model with the specified file name. + * @param {string} modelId - The ID of the model to initialize. + * @returns {Promise} A promise that resolves when the model is initialized. + */ + async initModel( + modelId: string, + settings?: ModelSettingParams + ): Promise { + return + } + + static async writeDefaultEngineSettings() { + try { + const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName) + if (await fs.exists(engine_json)) { + JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json)) + } + else { + await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2)) + } + } catch (err) { + console.error(err) + } + } + /** + * Stops the model. + * @returns {Promise} A promise that resolves when the model is stopped. + */ + async stopModel(): Promise {} + + /** + * Stops streaming inference. + * @returns {Promise} A promise that resolves when the streaming is stopped. + */ + async stopInference(): Promise { + this.isCancelled = true; + this.controller?.abort(); + } + + /** + * Makes a single response inference request. + * @param {MessageRequest} data - The data for the inference request. + * @returns {Promise} A promise that resolves with the inference response. + */ + async inference(data: MessageRequest): Promise { + const timestamp = Date.now(); + const message: ThreadMessage = { + thread_id: data.threadId, + created: timestamp, + updated: timestamp, + status: MessageStatus.Ready, + id: "", + role: ChatCompletionRole.Assistant, + object: "thread.message", + content: [], + }; + + return new Promise(async (resolve, reject) => { + requestInference(data.messages ?? [], + JanInferenceTritonTrtLLMExtension._engineSettings, + JanInferenceTritonTrtLLMExtension._currentModel) + .subscribe({ + next: (_content) => {}, + complete: async () => { + resolve(message); + }, + error: async (err) => { + reject(err); + }, + }); + }); + } + + private static async handleModelInit(model: Model) { + if (model.engine !== 'triton_trtllm') { return } + else { + JanInferenceTritonTrtLLMExtension._currentModel = model + JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() + // Todo: Check model list with API key + events.emit(EventName.OnModelReady, model) + // events.emit(EventName.OnModelFail, model) + } + } + + private static async handleModelStop(model: Model) { + if (model.engine !== 'triton_trtllm') { return } + events.emit(EventName.OnModelStopped, model) + } + + /** + * Handles a new message request by making an inference request and emitting events. + * Function registered in event manager, should be static to avoid binding issues. + * Pass instance as a reference. + * @param {MessageRequest} data - The data for the new message request. 
+ */ + private static async handleMessageRequest( + data: MessageRequest, + instance: JanInferenceTritonTrtLLMExtension + ) { + if (data.model.engine !== 'triton_trtllm') { return } + + const timestamp = Date.now(); + const message: ThreadMessage = { + id: ulid(), + thread_id: data.threadId, + assistant_id: data.assistantId, + role: ChatCompletionRole.Assistant, + content: [], + status: MessageStatus.Pending, + created: timestamp, + updated: timestamp, + object: "thread.message", + }; + events.emit(EventName.OnMessageResponse, message); + + instance.isCancelled = false; + instance.controller = new AbortController(); + + requestInference( + data?.messages ?? [], + this._engineSettings, + JanInferenceTritonTrtLLMExtension._currentModel, + instance.controller + ).subscribe({ + next: (content) => { + const messageContent: ThreadContent = { + type: ContentType.Text, + text: { + value: content.trim(), + annotations: [], + }, + }; + message.content = [messageContent]; + events.emit(EventName.OnMessageUpdate, message); + }, + complete: async () => { + message.status = MessageStatus.Ready; + events.emit(EventName.OnMessageUpdate, message); + }, + error: async (err) => { + const messageContent: ThreadContent = { + type: ContentType.Text, + text: { + value: "Error occurred: " + err.message, + annotations: [], + }, + }; + message.content = [messageContent]; + message.status = MessageStatus.Ready; + events.emit(EventName.OnMessageUpdate, message); + }, + }); + } +} diff --git a/extensions/inference-triton-trtllm-extension/tsconfig.json b/extensions/inference-triton-trtllm-extension/tsconfig.json new file mode 100644 index 0000000000..b48175a169 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/tsconfig.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "target": "es2016", + "module": "ES6", + "moduleResolution": "node", + + "outDir": "./dist", + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": false, + "skipLibCheck": true, + "rootDir": "./src" + }, + "include": ["./src"] +} diff --git a/extensions/inference-triton-trtllm-extension/webpack.config.js b/extensions/inference-triton-trtllm-extension/webpack.config.js new file mode 100644 index 0000000000..57a0adb0a2 --- /dev/null +++ b/extensions/inference-triton-trtllm-extension/webpack.config.js @@ -0,0 +1,38 @@ +const path = require("path"); +const webpack = require("webpack"); +const packageJson = require("./package.json"); + +module.exports = { + experiments: { outputModule: true }, + entry: "./src/index.ts", // Adjust the entry point to match your project's main file + mode: "production", + module: { + rules: [ + { + test: /\.tsx?$/, + use: "ts-loader", + exclude: /node_modules/, + }, + ], + }, + plugins: [ + new webpack.DefinePlugin({ + MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), + }), + ], + output: { + filename: "index.js", // Adjust the output file name as needed + path: path.resolve(__dirname, "dist"), + library: { type: "module" }, // Specify ESM output format + }, + resolve: { + extensions: [".ts", ".js"], + fallback: { + path: require.resolve("path-browserify"), + }, + }, + optimization: { + minimize: false, + }, + // Add loaders and other configuration as needed for your project +}; From 587f5addfa663e915ccb349b83079a79cdfa6474 Mon Sep 17 00:00:00 2001 From: hiro Date: Wed, 13 Dec 2023 01:27:18 +0700 Subject: [PATCH 2/2] fix: Fix issues based on Louis comments --- extensions/inference-triton-trtllm-extension/package.json | 3 ++- 
From 587f5addfa663e915ccb349b83079a79cdfa6474 Mon Sep 17 00:00:00 2001
From: hiro
Date: Wed, 13 Dec 2023 01:27:18 +0700
Subject: [PATCH 2/2] fix: Fix issues based on Louis's comments

---
 extensions/inference-triton-trtllm-extension/package.json    | 3 ++-
 .../inference-triton-trtllm-extension/src/@types/global.d.ts | 2 --
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/extensions/inference-triton-trtllm-extension/package.json b/extensions/inference-triton-trtllm-extension/package.json
index 862359fe61..ff2d4cc8bb 100644
--- a/extensions/inference-triton-trtllm-extension/package.json
+++ b/extensions/inference-triton-trtllm-extension/package.json
@@ -25,7 +25,8 @@
     "fetch-retry": "^5.0.6",
     "path-browserify": "^1.0.1",
     "ts-loader": "^9.5.0",
-    "ulid": "^2.3.0"
+    "ulid": "^2.3.0",
+    "rxjs": "^7.8.1"
   },
   "engines": {
     "node": ">=18.0.0"
diff --git a/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts b/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
index 141284ad68..6224b8e68c 100644
--- a/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
+++ b/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts
@@ -1,7 +1,5 @@
 import { Model } from "@janhq/core";
 
-declare const MODULE: string;
-
 declare interface EngineSettings {
   base_url?: string;
 }
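For anyone wiring this up against a live server, the request that `helpers/sse.ts` issues can be reproduced standalone, which is handy for checking the Triton endpoint before involving Jan at all. A hedged smoke-test sketch; the base URL is a placeholder, and the body fields mirror the ones the helper sends:

```typescript
// Standalone check of the endpoint the extension depends on.
// Requires Node 18+ (global fetch), matching the "engines" field above.
async function smokeTestTriton(baseUrl = "http://localhost:8000") {
  const response = await fetch(`${baseUrl}/v2/models/ensemble/generate_stream`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Accept: "text/event-stream",
    },
    body: JSON.stringify({
      text_input: "Hello",
      max_tokens: 64,
      temperature: 0,
      bad_words: "",
      stop_words: "[DONE]",
      stream: true,
    }),
  });
  console.log(response.status, response.statusText); // expect 200 on success
}

smokeTestTriton().catch(console.error);
```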