From 74dadd0bced4e38b4c8f591b1166524f64b757c6 Mon Sep 17 00:00:00 2001
From: Alberto Sigismondi
Date: Sun, 22 Dec 2024 18:54:59 +0100
Subject: [PATCH 1/2] MLX support for AI service: implement an MLXClient class
 that mirrors the Ollama client functionality, using the MLX chat API
 endpoint and types.

---
 apps/desktop/src/lib/ai/mlxClient.ts          | 170 ++++++++++++++++++
 apps/desktop/src/lib/ai/service.ts            |  34 +++-
 apps/desktop/src/lib/ai/types.ts              |   3 +-
 .../src/lib/commit/CommitMessageInput.svelte  |   2 +-
 .../src/routes/settings/ai/+page.svelte       |  38 ++++
 5 files changed, 243 insertions(+), 4 deletions(-)
 create mode 100644 apps/desktop/src/lib/ai/mlxClient.ts

diff --git a/apps/desktop/src/lib/ai/mlxClient.ts b/apps/desktop/src/lib/ai/mlxClient.ts
new file mode 100644
index 0000000000..de1515d1d1
--- /dev/null
+++ b/apps/desktop/src/lib/ai/mlxClient.ts
@@ -0,0 +1,170 @@
+import {
+	LONG_DEFAULT_BRANCH_TEMPLATE,
+	LONG_DEFAULT_COMMIT_TEMPLATE,
+	SHORT_DEFAULT_PR_TEMPLATE
+} from '$lib/ai/prompts';
+import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
+import { andThen, buildFailureFromAny, ok, wrap, wrapAsync, type Result } from '$lib/result';
+import { isNonEmptyObject } from '@gitbutler/ui/utils/typeguards';
+import { fetch } from '@tauri-apps/plugin-http';
+
+export const DEFAULT_MLX_ENDPOINT = 'http://localhost:8080';
+export const DEFAULT_MLX_MODEL_NAME = 'mlx-community/Llama-3.2-3B-Instruct-4bit';
+
+enum MLXApiEndpoint {
+	Chat = 'v1/chat/completions',
+}
+
+interface MLXRequestOptions {
+	/**
+	 * The temperature of the model.
+	 * Increasing the temperature will make the model answer more creatively. (Default: 0.8)
+	 */
+	temperature: number;
+	repetition_penalty: number;
+	top_p: number;
+	max_tokens: number;
+}
+
+interface MLXChatRequest {
+	model: string;
+	stream: boolean;
+	messages: Prompt;
+	options?: MLXRequestOptions;
+}
+
+interface MLXChatResponse {
+	choices: [MLXChatResponseChoice];
+}
+
+interface MLXChatResponseChoice {
+	message: PromptMessage;
+}
+
+interface MLXChatMessageFormat {
+	result: string;
+}
+
+const MLX_CHAT_MESSAGE_FORMAT_SCHEMA = {
+	type: 'object',
+	properties: {
+		result: { type: 'string' }
+	},
+	required: ['result'],
+	additionalProperties: false
+};
+
+function isMLXChatMessageFormat(message: unknown): message is MLXChatMessageFormat {
+	return isNonEmptyObject(message) && message.result !== undefined;
+}
+
+function isMLXChatResponse(response: MLXChatResponse): response is MLXChatResponse {
+	if (!isNonEmptyObject(response)) {
+		return false;
+	}
+
+	return response.choices.length > 0 && response.choices[0].message !== undefined;
+}
+
+export class MLXClient implements AIClient {
+	defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE;
+	defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE;
+	defaultPRTemplate = SHORT_DEFAULT_PR_TEMPLATE;
+
+	constructor(
+		private endpoint: string,
+		private modelName: string
+	) {}
+
+	async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
+		const messages = this.formatPrompt(prompt);
+
+		const options = {
+			temperature: 1.0,
+			repetition_penalty: 1.5,
+			top_p: 1.0,
+			max_tokens: 512
+		};
+		const responseResult = await this.chat(messages, options);
+
+		return andThen(responseResult, (response) => {
+			const choice = response.choices[0];
+			const rawResponseResult = wrap<unknown, Error>(() => JSON.parse(choice.message.content));
+
+			return andThen(rawResponseResult, (rawResponse) => {
+				if (!isMLXChatMessageFormat(rawResponse)) {
+					return buildFailureFromAny('Invalid response: ' + choice.message.content);
+				}
+
+				return ok(rawResponse.result);
+			});
+		});
+	}
+
+	private formatPrompt(prompt: Prompt): Prompt {
+		const withFormattedResponses = prompt.map((promptMessage) => {
+			if (promptMessage.role === MessageRole.Assistant) {
+				return {
+					role: MessageRole.Assistant,
+					content: JSON.stringify({ result: promptMessage.content })
+				};
+			} else {
+				return promptMessage;
+			}
+		});
+
+		return [
+			{
+				role: MessageRole.System,
+				content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
+Return your response in JSON and only use the following JSON schema:
+
+${JSON.stringify(MLX_CHAT_MESSAGE_FORMAT_SCHEMA.properties, null, 2)}
+
+EXAMPLE:
+
+{"result": "Your content here"}
+
+Ensure that your response is valid JSON and adheres to the provided JSON schema.
+`
+			},
+			...withFormattedResponses
+		];
+	}
+
+	private async fetchChat(request: MLXChatRequest): Promise<Result<any, Error>> {
+		const url = new URL(MLXApiEndpoint.Chat, this.endpoint);
+		const body = JSON.stringify(request);
+		return await wrapAsync<any, Error>(
+			async () =>
+				await fetch(url.toString(), {
+					method: 'POST',
+					headers: {
+						'Content-Type': 'application/json'
+					},
+					body
+				}).then(async (response) => await response.json())
+		);
+	}
+
+	private async chat(
+		messages: Prompt,
+		options?: MLXRequestOptions
+	): Promise<Result<MLXChatResponse, Error>> {
+		const result = await this.fetchChat({
+			model: this.modelName,
+			stream: false,
+			messages,
+			options
+		});
+
+		return andThen(result, (result) => {
+			if (!isMLXChatResponse(result)) {
+				return buildFailureFromAny('Invalid response\n' + JSON.stringify(result.data));
+			}
+
+			return ok(result);
+		});
+	}
+}
diff --git a/apps/desktop/src/lib/ai/service.ts b/apps/desktop/src/lib/ai/service.ts
index fd5c46c776..69279176fb 100644
--- a/apps/desktop/src/lib/ai/service.ts
+++ b/apps/desktop/src/lib/ai/service.ts
@@ -9,6 +9,11 @@
 } from './types';
 import { AnthropicAIClient } from '$lib/ai/anthropicClient';
 import { ButlerAIClient } from '$lib/ai/butlerClient';
+import {
+	DEFAULT_MLX_ENDPOINT,
+	DEFAULT_MLX_MODEL_NAME,
+	MLXClient
+} from '$lib/ai/mlxClient';
 import {
 	DEFAULT_OLLAMA_ENDPOINT,
 	DEFAULT_OLLAMA_MODEL_NAME,
@@ -44,7 +49,9 @@ export enum GitAIConfigKey {
 	AnthropicModelName = 'gitbutler.aiAnthropicModelName',
 	DiffLengthLimit = 'gitbutler.diffLengthLimit',
 	OllamaEndpoint = 'gitbutler.aiOllamaEndpoint',
-	OllamaModelName = 'gitbutler.aiOllamaModelName'
+	OllamaModelName = 'gitbutler.aiOllamaModelName',
+	MlxModelName = 'gitbutler.aiMlxModelName',
+	MlxEndpoint = 'gitbutler.aiMlxEndpoint'
 }
 
 interface BaseAIServiceOpts {
@@ -182,6 +189,20 @@
 		);
 	}
 
+	async getMlxEndpoint() {
+		return await this.gitConfig.getWithDefault(
+			GitAIConfigKey.MlxEndpoint,
+			DEFAULT_MLX_ENDPOINT
+		);
+	}
+
+	async getMlxModelName() {
+		return await this.gitConfig.getWithDefault(
+			GitAIConfigKey.MlxModelName,
+			DEFAULT_MLX_MODEL_NAME
+		);
+	}
+
 	async usingGitButlerAPI() {
 		const modelKind = await this.getModelKind();
 		const openAIKeyOption = await this.getOpenAIKeyOption();
@@ -201,6 +222,8 @@
 		const anthropicKey = await this.getAnthropicKey();
 		const ollamaEndpoint = await this.getOllamaEndpoint();
 		const ollamaModelName = await this.getOllamaModelName();
+		const mlxEndpoint = await this.getMlxEndpoint();
+		const mlxModelName = await this.getMlxModelName();
 
 		if (await this.usingGitButlerAPI()) return !!get(this.tokenMemoryService.token);
 
@@ -208,9 +231,10 @@
 		const anthropicActiveAndKeyProvided =
 			modelKind === ModelKind.Anthropic && !!anthropicKey;
 		const ollamaActiveAndEndpointProvided =
 			modelKind === ModelKind.Ollama && !!ollamaEndpoint && !!ollamaModelName;
+		const mlxActiveAndEndpointProvided = modelKind === ModelKind.MLX && !!mlxEndpoint && !!mlxModelName;
 
 		return (
-			openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided
+			openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided || mlxActiveAndEndpointProvided
 		);
 	}
 
@@ -238,6 +262,12 @@
 			return ok(new OllamaClient(ollamaEndpoint, ollamaModelName));
 		}
 
+		if (modelKind === ModelKind.MLX) {
+			const mlxEndpoint = await this.getMlxEndpoint();
+			const mlxModelName = await this.getMlxModelName();
+			return ok(new MLXClient(mlxEndpoint, mlxModelName));
+		}
+
 		if (modelKind === ModelKind.OpenAI) {
 			const openAIModelName = await this.getOpenAIModleName();
 			const openAIKey = await this.getOpenAIKey();
diff --git a/apps/desktop/src/lib/ai/types.ts b/apps/desktop/src/lib/ai/types.ts
index 9d7fb9df45..5c537cf85d 100644
--- a/apps/desktop/src/lib/ai/types.ts
+++ b/apps/desktop/src/lib/ai/types.ts
@@ -5,7 +5,8 @@ import type { Persisted } from '@gitbutler/shared/persisted';
 export enum ModelKind {
 	OpenAI = 'openai',
 	Anthropic = 'anthropic',
-	Ollama = 'ollama'
+	Ollama = 'ollama',
+	MLX = 'mlx'
 }
 
 // https://platform.openai.com/docs/models
diff --git a/apps/desktop/src/lib/commit/CommitMessageInput.svelte b/apps/desktop/src/lib/commit/CommitMessageInput.svelte
index 6267b5ee65..8b78da89ca 100644
--- a/apps/desktop/src/lib/commit/CommitMessageInput.svelte
+++ b/apps/desktop/src/lib/commit/CommitMessageInput.svelte
@@ -273,7 +273,7 @@
 					style="ghost"
 					outline
 					icon="ai-small"
-					disabled={!($aiGenEnabled && aiConfigurationValid)}
+					disabled={!(aiConfigurationValid)}
 					loading={aiLoading}
 					menuPosition="top"
 					onclick={async () => await generateCommitMessage()}
diff --git a/apps/desktop/src/routes/settings/ai/+page.svelte b/apps/desktop/src/routes/settings/ai/+page.svelte
index b1f2917ec5..4a2bef730b 100644
--- a/apps/desktop/src/routes/settings/ai/+page.svelte
+++ b/apps/desktop/src/routes/settings/ai/+page.svelte
@@ -36,6 +36,8 @@
 	let diffLengthLimit: number | undefined = $state();
 	let ollamaEndpoint: string | undefined = $state();
 	let ollamaModel: string | undefined = $state();
+	let mlxEndpoint: string | undefined = $state();
+	let mlxModel: string | undefined = $state();
 
 	async function setConfiguration(key: GitAIConfigKey, value: string | undefined) {
 		if (!initialized) return;
@@ -63,6 +65,9 @@
 		ollamaEndpoint = await aiService.getOllamaEndpoint();
 		ollamaModel = await aiService.getOllamaModelName();
+		mlxEndpoint = await aiService.getMlxEndpoint();
+		mlxModel = await aiService.getMlxModelName();
+
 		// Ensure reactive declarations have finished running before we set initialized to true
 		await tick();
 
@@ -158,6 +163,12 @@
 	run(() => {
 		setConfiguration(GitAIConfigKey.OllamaModelName, ollamaModel);
 	});
+	run(() => {
+		setConfiguration(GitAIConfigKey.MlxEndpoint, mlxEndpoint);
+	});
+	run(() => {
+		setConfiguration(GitAIConfigKey.MlxModelName, mlxModel);
+	});
 	run(() => {
 		if (form) form.modelKind.value = modelKind;
 	});
@@ -335,6 +346,33 @@
 		{/if}
+
+			{#snippet title()}
+				MLX
+			{/snippet}
+			{#snippet actions()}
+
+			{/snippet}
+
+		{#if modelKind === ModelKind.MLX}
+
+
+
+		{/if}

From 540978b01d3c55566e07dbb292de2438d26a4c60 Mon Sep 17 00:00:00 2001
From: Alberto Sigismondi
Date: Fri, 27 Dec 2024 11:46:43 +0100
Subject: [PATCH 2/2] Reverted disabled attribute for Generate message.

---
 apps/desktop/src/lib/commit/CommitMessageInput.svelte | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/apps/desktop/src/lib/commit/CommitMessageInput.svelte b/apps/desktop/src/lib/commit/CommitMessageInput.svelte
index 1a8d21f5b4..456a4e5d1b 100644
--- a/apps/desktop/src/lib/commit/CommitMessageInput.svelte
+++ b/apps/desktop/src/lib/commit/CommitMessageInput.svelte
@@ -292,7 +292,7 @@
 					style="ghost"
 					outline
 					icon="ai-small"
-					disabled={!(aiConfigurationValid)}
+					disabled={!($aiGenEnabled && aiConfigurationValid)}
 					loading={aiLoading}
 					menuPosition="top"
 					onclick={async () => await generateCommitMessage()}
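
The sketch below is not part of either patch; it is a minimal, standalone TypeScript illustration of the round trip the new MLXClient performs, assuming an OpenAI-compatible MLX server (for example one started with mlx_lm.server, which listens on http://localhost:8080 by default) and the default model name from mlxClient.ts. All names in it are illustrative, and it uses top-level OpenAI-style sampling parameters, whereas the patch nests similar values under an `options` field to mirror the Ollama client.

// Standalone sketch (illustrative only): ask a local MLX server for a chat
// completion and unwrap the {"result": "..."} envelope that the MLXClient's
// system prompt requests.
interface SketchChatResponse {
	choices: { message: { role: string; content: string } }[];
}

export async function askMlx(
	endpoint: string,
	model: string,
	userPrompt: string
): Promise<string> {
	// Same route the patch targets: <endpoint>/v1/chat/completions.
	const url = new URL('v1/chat/completions', endpoint);

	const response = await fetch(url, {
		method: 'POST',
		headers: { 'Content-Type': 'application/json' },
		body: JSON.stringify({
			model,
			stream: false,
			messages: [
				{
					role: 'system',
					content:
						'Answer the user and reply only with JSON of the form {"result": "<your answer>"}.'
				},
				{ role: 'user', content: userPrompt }
			],
			// Top-level OpenAI-style parameters; the patch nests these under `options` instead.
			temperature: 1.0,
			max_tokens: 512
		})
	});
	if (!response.ok) {
		throw new Error('MLX server returned HTTP ' + response.status);
	}

	const data = (await response.json()) as SketchChatResponse;
	const content = data.choices[0]?.message.content ?? '';

	// Mirror MLXClient.evaluate(): unwrap the {"result": "..."} envelope, fall back to raw text.
	try {
		const parsed = JSON.parse(content) as { result?: string };
		return parsed.result ?? content;
	} catch {
		return content;
	}
}

// Example usage with the defaults declared in mlxClient.ts:
// await askMlx('http://localhost:8080', 'mlx-community/Llama-3.2-3B-Instruct-4bit', 'Summarise this change in one sentence');

Unlike MLXClient.evaluate(), which reports an "Invalid response" failure when the model does not return the JSON envelope, this sketch simply falls back to the raw message text.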