feat: Add support for MLX #5872

Open · wants to merge 5 commits into base: master

170 changes: 170 additions & 0 deletions apps/desktop/src/lib/ai/mlxClient.ts
@@ -0,0 +1,170 @@
import {
LONG_DEFAULT_BRANCH_TEMPLATE,
LONG_DEFAULT_COMMIT_TEMPLATE,
SHORT_DEFAULT_PR_TEMPLATE
} from '$lib/ai/prompts';
import { MessageRole, type PromptMessage, type AIClient, type Prompt } from '$lib/ai/types';
import { andThen, buildFailureFromAny, ok, wrap, wrapAsync, type Result } from '$lib/result';
import { isNonEmptyObject } from '@gitbutler/ui/utils/typeguards';
import { fetch } from '@tauri-apps/plugin-http';

export const DEFAULT_MLX_ENDPOINT = 'http://localhost:8080';
export const DEFAULT_MLX_MODEL_NAME = 'mlx-community/Llama-3.2-3B-Instruct-4bit';

enum MLXApiEndpoint {
Chat = 'v1/chat/completions',
}

interface MLXRequestOptions {
/**
* The temperature of the model.
* Increasing the temperature will make the model answer more creatively. (Default: 0.8)
*/
temperature: number;
repetition_penalty: number;
top_p: number;
max_tokens: number;
}

interface MLXChatRequest {
model: string;
stream: boolean;
messages: Prompt;
options?: MLXRequestOptions;
}

interface MLXChatResponse {
choices: [MLXChatResponseChoice];
}

interface MLXChatResponseChoice {
message: PromptMessage;
}

interface MLXChatMessageFormat {
result: string;
}

const MLX_CHAT_MESSAGE_FORMAT_SCHEMA = {
type: 'object',
properties: {
result: { type: 'string' }
},
required: ['result'],
additionalProperties: false
};

function isMLXChatMessageFormat(message: unknown): message is MLXChatMessageFormat {
return isNonEmptyObject(message) && message.result !== undefined;
}

function isMLXChatResponse(response: unknown): response is MLXChatResponse {
if (!isNonEmptyObject(response)) {
return false;
}

return Array.isArray(response.choices) && response.choices.length > 0 && response.choices[0].message !== undefined;
}

export class MLXClient implements AIClient {
defaultCommitTemplate = LONG_DEFAULT_COMMIT_TEMPLATE;
defaultBranchTemplate = LONG_DEFAULT_BRANCH_TEMPLATE;
defaultPRTemplate = SHORT_DEFAULT_PR_TEMPLATE;

constructor(
private endpoint: string,
private modelName: string
) {}

async evaluate(prompt: Prompt): Promise<Result<string, Error>> {
const messages = this.formatPrompt(prompt);

const options = {
temperature: 1.0,
repetition_penalty: 1.5,
top_p: 1.0,
max_tokens: 512
};
const responseResult = await this.chat(messages, options);

return andThen(responseResult, (response) => {
const choice = response.choices[0];
const rawResponseResult = wrap<unknown, Error>(() => JSON.parse(choice.message.content));

return andThen(rawResponseResult, (rawResponse) => {
if (!isMLXChatMessageFormat(rawResponse)) {
return buildFailureFromAny('Invalid response: ' + choice.message.content);
}

return ok(rawResponse.result);
});
});
}

private formatPrompt(prompt: Prompt): Prompt {
const withFormattedResponses = prompt.map((promptMessage) => {
if (promptMessage.role === MessageRole.Assistant) {
return {
role: MessageRole.Assistant,
content: JSON.stringify({ result: promptMessage.content })
};
} else {
return promptMessage;
}
});

return [
{
role: MessageRole.System,
content: `You are an expert in software development. Answer the given user prompts following the specified instructions.
Return your response in JSON and only use the following JSON schema:
<json schema>
${JSON.stringify(MLX_CHAT_MESSAGE_FORMAT_SCHEMA.properties, null, 2)}
</json schema>
EXAMPLE:
<json>
{"result": "Your content here"}
</json>
Ensure that your response is valid JSON and adheres to the provided JSON schema.
`

},
...withFormattedResponses
];
}

private async fetchChat(request: MLXChatRequest): Promise<Result<any, Error>> {
const url = new URL(MLXApiEndpoint.Chat, this.endpoint);
const body = JSON.stringify(request);
return await wrapAsync(
async () =>
await fetch(url.toString(), {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body
}).then(async (response) => await response.json())
);
}

private async chat(
messages: Prompt,
options?: MLXRequestOptions
): Promise<Result<MLXChatResponse, Error>> {
const result = await this.fetchChat({
model: this.modelName,
stream: false,
messages,
options,
});

return andThen(result, (result) => {
if (!isMLXChatResponse(result)) {
return buildFailureFromAny('Invalid response\n' + JSON.stringify(result));
}

return ok(result);
});
}
}
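Reviewer note (not part of the diff): the client above targets an OpenAI-compatible /v1/chat/completions endpoint. Below is a minimal sketch of the request body fetchChat() sends and the response shape isMLXChatResponse() accepts, using the defaults from this file; the message contents and commit text are illustrative only, and role values are shown as their serialized strings.

// Request POSTed to http://localhost:8080/v1/chat/completions:
const exampleRequest = {
  model: 'mlx-community/Llama-3.2-3B-Instruct-4bit',
  stream: false,
  messages: [
    { role: 'system', content: '…the JSON-schema instructions built by formatPrompt()…' },
    { role: 'user', content: 'Write a commit message for this diff: …' }
  ],
  options: { temperature: 1.0, repetition_penalty: 1.5, top_p: 1.0, max_tokens: 512 }
};

// Response the client expects back (abridged). evaluate() JSON-parses
// choices[0].message.content and returns its "result" field:
const exampleResponse = {
  choices: [
    { message: { role: 'assistant', content: '{"result": "fix: correct pagination offset"}' } }
  ]
};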
34 changes: 32 additions & 2 deletions apps/desktop/src/lib/ai/service.ts
@@ -9,6 +9,11 @@ import {
} from './types';
import { AnthropicAIClient } from '$lib/ai/anthropicClient';
import { ButlerAIClient } from '$lib/ai/butlerClient';
import {
DEFAULT_MLX_ENDPOINT,
DEFAULT_MLX_MODEL_NAME,
MLXClient
} from '$lib/ai/mlxClient';
import {
DEFAULT_OLLAMA_ENDPOINT,
DEFAULT_OLLAMA_MODEL_NAME,
@@ -44,7 +49,9 @@ export enum GitAIConfigKey {
AnthropicModelName = 'gitbutler.aiAnthropicModelName',
DiffLengthLimit = 'gitbutler.diffLengthLimit',
OllamaEndpoint = 'gitbutler.aiOllamaEndpoint',
OllamaModelName = 'gitbutler.aiOllamaModelName'
OllamaModelName = 'gitbutler.aiOllamaModelName',
MlxModelName = 'gitbutler.aiMlxModelName',
MlxEndpoint = 'gitbutler.aiMlxEndpoint'
}

interface BaseAIServiceOpts {
@@ -182,6 +189,20 @@ export class AIService {
);
}

async getMlxEndpoint() {
return await this.gitConfig.getWithDefault<string>(
GitAIConfigKey.MlxEndpoint,
DEFAULT_MLX_ENDPOINT
);
}

async getMlxModelName() {
return await this.gitConfig.getWithDefault<string>(
GitAIConfigKey.MlxModelName,
DEFAULT_MLX_MODEL_NAME
);
}

async usingGitButlerAPI() {
const modelKind = await this.getModelKind();
const openAIKeyOption = await this.getOpenAIKeyOption();
@@ -201,16 +222,19 @@
const anthropicKey = await this.getAnthropicKey();
const ollamaEndpoint = await this.getOllamaEndpoint();
const ollamaModelName = await this.getOllamaModelName();
const mlxEndpoint = await this.getMlxEndpoint();
const mlxModelName = await this.getMlxModelName();

if (await this.usingGitButlerAPI()) return !!get(this.tokenMemoryService.token);

const openAIActiveAndKeyProvided = modelKind === ModelKind.OpenAI && !!openAIKey;
const anthropicActiveAndKeyProvided = modelKind === ModelKind.Anthropic && !!anthropicKey;
const ollamaActiveAndEndpointProvided =
modelKind === ModelKind.Ollama && !!ollamaEndpoint && !!ollamaModelName;
const mlxActiveAndEndpointProvided = modelKind === ModelKind.MLX && !!mlxEndpoint && !!mlxModelName;

return (
openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided
openAIActiveAndKeyProvided || anthropicActiveAndKeyProvided || ollamaActiveAndEndpointProvided || mlxActiveAndEndpointProvided
);
}

@@ -238,6 +262,12 @@ export class AIService {
return ok(new OllamaClient(ollamaEndpoint, ollamaModelName));
}

if (modelKind === ModelKind.MLX) {
const mlxEndpoint = await this.getMlxEndpoint();
const mlxModelName = await this.getMlxModelName();
return ok(new MLXClient(mlxEndpoint, mlxModelName));
}

if (modelKind === ModelKind.OpenAI) {
const openAIModelName = await this.getOpenAIModleName();
const openAIKey = await this.getOpenAIKey();
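Note (not part of the diff): the new keys behave like the existing Ollama ones, i.e. plain git config values that fall back to the defaults exported from mlxClient.ts when unset. A short sketch of the resulting flow, assuming an AIService instance named aiService:

// gitbutler.aiMlxEndpoint   -> defaults to 'http://localhost:8080'
// gitbutler.aiMlxModelName  -> defaults to 'mlx-community/Llama-3.2-3B-Instruct-4bit'
const mlxEndpoint = await aiService.getMlxEndpoint();
const mlxModelName = await aiService.getMlxModelName();
// With ModelKind.MLX selected and both values present, the configuration check above
// passes and the ModelKind.MLX branch constructs the client:
const client = new MLXClient(mlxEndpoint, mlxModelName);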
3 changes: 2 additions & 1 deletion apps/desktop/src/lib/ai/types.ts
@@ -5,7 +5,8 @@ import type { Persisted } from '@gitbutler/shared/persisted';
export enum ModelKind {
OpenAI = 'openai',
Anthropic = 'anthropic',
Ollama = 'ollama'
Ollama = 'ollama',
MLX = 'mlx'
}

// https://platform.openai.com/docs/models
38 changes: 38 additions & 0 deletions apps/desktop/src/routes/settings/ai/+page.svelte
@@ -36,6 +36,8 @@
let diffLengthLimit: number | undefined = $state();
let ollamaEndpoint: string | undefined = $state();
let ollamaModel: string | undefined = $state();
let mlxEndpoint: string | undefined = $state();
let mlxModel: string | undefined = $state();

async function setConfiguration(key: GitAIConfigKey, value: string | undefined) {
if (!initialized) return;
@@ -63,6 +65,9 @@
ollamaEndpoint = await aiService.getOllamaEndpoint();
ollamaModel = await aiService.getOllamaModelName();

mlxEndpoint = await aiService.getMlxEndpoint();
mlxModel = await aiService.getMlxModelName();

// Ensure reactive declarations have finished running before we set initialized to true
await tick();

@@ -158,6 +163,12 @@
run(() => {
setConfiguration(GitAIConfigKey.OllamaModelName, ollamaModel);
});
run(() => {
setConfiguration(GitAIConfigKey.MlxEndpoint, mlxEndpoint);
});
run(() => {
setConfiguration(GitAIConfigKey.MlxModelName, mlxModel);
});
run(() => {
if (form) form.modelKind.value = modelKind;
});
@@ -335,6 +346,33 @@
</div>
</SectionCard>
{/if}

<SectionCard
roundedBottom={modelKind !== ModelKind.MLX}
orientation="row"
labelFor="mlx"
bottomBorder={modelKind !== ModelKind.MLX}
>
{#snippet title()}
MLX
{/snippet}
{#snippet actions()}
<RadioButton name="modelKind" id="custom" value={ModelKind.MLX} />
{/snippet}
</SectionCard>
{#if modelKind === ModelKind.MLX}
<SectionCard roundedTop={false} orientation="row" topDivider>
<div class="inputs-group">
<Textbox
label="Endpoint"
bind:value={mlxEndpoint}
placeholder="http://localhost:8080"
/>

<Textbox label="Model" bind:value={mlxModel} placeholder="mlx-community/Llama-3.2-3B-Instruct-4bit" />
</div>
</SectionCard>
{/if}
</form>

<Spacer />
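Usage note (not part of the diff): the defaults in this form assume an OpenAI-compatible server is already listening on http://localhost:8080, for example one started with mlx-lm's mlx_lm.server (an assumption; any server exposing /v1/chat/completions works). A small reachability sketch using the same Tauri HTTP plugin as mlxClient.ts; the function name and ping message are illustrative:

import { fetch } from '@tauri-apps/plugin-http';

// Returns true if something OpenAI-compatible answers on the configured endpoint.
async function mlxServerReachable(endpoint = 'http://localhost:8080'): Promise<boolean> {
  try {
    const response = await fetch(new URL('v1/chat/completions', endpoint).toString(), {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model: 'mlx-community/Llama-3.2-3B-Instruct-4bit',
        stream: false,
        messages: [{ role: 'user', content: 'ping' }]
      })
    });
    return response.ok;
  } catch {
    return false;
  }
}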