Commit f43967d: use message history

lx-0 committed Oct 28, 2024
1 parent 6b91272

Showing 9 changed files with 134 additions and 65 deletions.
2 changes: 2 additions & 0 deletions .cursorrules
@@ -244,3 +244,5 @@ Ensure all required dependencies are properly installed and typed:
- Implement proper error handling
- Add response validation
- Support function execution tracking
- Add message history support
- Handle message reconstruction
2 changes: 2 additions & 0 deletions README.md
@@ -28,6 +28,8 @@ A Next.js application that uses a large language model to control a computer thr
> - ✅ Base architecture
> - ✅ Model selection
> - ✅ Model tracking
> - ✅ Message history
> - 🔳 Context management
> - 🔳 Function calling
> - ⬜ Streaming support
> - ⬜ Computer use tooling
25 changes: 24 additions & 1 deletion src/app/api/llm/route.ts
@@ -1,4 +1,5 @@
import { LLMService } from '@/services/llm.service';
import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { NextRequest, NextResponse } from 'next/server';

export async function POST(req: NextRequest) {
@@ -9,8 +10,30 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ error: 'Model ID is required' }, { status: 400 });
}

// Reconstruct Langchain message instances
const history = options?.history
?.map((msg: any) => {
if (msg.type === 'constructor') {
switch (msg.id[2]) {
case 'HumanMessage':
return new HumanMessage(msg.kwargs);
case 'AIMessage':
return new AIMessage(msg.kwargs);
case 'SystemMessage':
return new SystemMessage(msg.kwargs);
default:
return null;
}
}
return null;
})
.filter(Boolean);

const llmService = LLMService.getInstance();
const response = await llmService.sendMessage(message, modelId, options);
const response = await llmService.sendMessage(message, modelId, {
...options,
history,
});

return NextResponse.json(response);
} catch (error) {
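
For context, a sketch of the wire shape this reconstruction switch expects (illustrative values, not taken from the diff): Langchain messages serialize via toJSON() into a "constructor" envelope whose id array ends in the class name, which is why the code dispatches on msg.id[2].

import { HumanMessage } from '@langchain/core/messages';

// Roughly what a HumanMessage becomes after JSON.stringify / JSON.parse:
const serialized = {
  lc: 1,
  type: 'constructor',
  id: ['langchain_core', 'messages', 'HumanMessage'], // id[2] is the class name
  kwargs: { content: 'Hello!' },
};

// The switch above rebuilds a live instance from the kwargs:
const rebuilt = new HumanMessage(serialized.kwargs);
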
11 changes: 8 additions & 3 deletions src/components/chat/ChatComponent.tsx
@@ -6,9 +6,10 @@ import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { ScrollArea } from '@/components/ui/scroll-area';
import { useChatMessages } from '@/hooks/useChatMessages';
import { useDockerHandlers } from '@/hooks/useDockerHandlers';
import { AVAILABLE_MODELS } from '@/lib/llm/types';
import { AVAILABLE_MODELS, convertToLangchainMessage } from '@/lib/llm/types';
import { cn } from '@/lib/utils';
import { LLMApiService } from '@/services/llm-api.service';
import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { Settings as SettingsIcon } from 'lucide-react';
import { useCallback, useEffect, useRef, useState } from 'react';
import ChatCopyButton from './ChatCopyButton';
@@ -80,15 +81,19 @@
const handleSendMessage = async () => {
if (!inputMessage.trim()) return;

// Find the selected model info
const selectedModelInfo = AVAILABLE_MODELS.find((m) => m.id === selectedModel);

const userMessageId = addChatMessage('user', inputMessage);
setInputMessage('');

try {
// Convert and filter out log messages and nulls
const history = chatMessages
.map(convertToLangchainMessage)
.filter((msg): msg is HumanMessage | AIMessage | SystemMessage => msg !== null);

const response = await llmApiService.sendMessage(inputMessage, selectedModel, {
stream: false,
history,
});
addChatMessage('assistant', response.content, undefined, undefined, selectedModelInfo);
} catch (error) {
21 changes: 12 additions & 9 deletions src/lib/llm/provider.ts
@@ -1,13 +1,14 @@
import { ChatAnthropic } from '@langchain/anthropic';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { HumanMessage } from '@langchain/core/messages';
import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
import { FunctionDefinition, LLMConfig, LLMResponse } from './types';
import { FunctionDefinition, LLMConfig, LLMRequestOptions, LLMResponse } from './types';

interface GenerateOptions {
export interface GenerateOptions {
functions?: string[];
stream?: boolean;
maxTokens?: number;
history?: Array<HumanMessage | AIMessage | SystemMessage>;
}

export class LLMProvider {
@@ -46,13 +47,14 @@ export class LLMProvider {
this.functions.set(definition.name, definition);
}

public async generateResponse(prompt: string, options?: GenerateOptions): Promise<LLMResponse> {
public async generateResponse(prompt: string, options?: LLMRequestOptions): Promise<LLMResponse> {
try {
const response = await this.model.invoke([
new HumanMessage({
content: prompt,
}),
]);
const messages = [
...(Array.isArray(options?.history) ? options.history : []),
new HumanMessage({ content: prompt }),
];

const response = await this.model.invoke(messages);

const content =
typeof response.content === 'string' ? response.content : JSON.stringify(response.content);
Expand All @@ -66,6 +68,7 @@ export class LLMProvider {
},
};
} catch (error) {
console.error('Provider error:', error);
throw new Error(
`Failed to generate response: ${error instanceof Error ? error.message : 'Unknown error'}`
);
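
A minimal usage sketch of the now history-aware generateResponse (the LLMConfig fields here are assumed from how llm.service.ts builds its config in this diff; the model id is hypothetical):

import { AIMessage, HumanMessage } from '@langchain/core/messages';
import { LLMProvider } from '@/lib/llm/provider';
import { LLMConfig } from '@/lib/llm/types';

const provider = new LLMProvider({
  model: 'gpt-4o', // hypothetical model id
  apiKey: process.env.OPENAI_API_KEY ?? '',
  temperature: 0.7,
} as LLMConfig);

const response = await provider.generateResponse('And doubled?', {
  history: [
    new HumanMessage({ content: 'What is 2 + 2?' }),
    new AIMessage({ content: '4' }),
  ],
});
// The model is invoked with [history..., HumanMessage('And doubled?')] in order.
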
60 changes: 60 additions & 0 deletions src/lib/llm/types.ts
@@ -1,3 +1,6 @@
import { ChatMessageData } from '@/components/chat/ChatMessage';
import { AIMessage, BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';

// Core LLM types
export type LLMProvider = 'openai' | 'anthropic' | 'local';

@@ -146,3 +149,60 @@ export const AVAILABLE_MODELS_SORTED = AVAILABLE_MODELS.sort((a, b) => {
// Then by context window size (larger first)
return b.contextWindow - a.contextWindow;
});

// Chat Memory Types
export interface ChatMemory {
messages: BaseMessage[]; // Changed from ChatMessage[] to BaseMessage[]
returnMessages: boolean;
maxTokens?: number;
}

export interface ChatMessageHistory {
addMessage(message: BaseMessage): Promise<void>;
getMessages(): Promise<BaseMessage[]>;
clear(): Promise<void>;
}

export interface LLMRequestOptions {
stream?: boolean;
functions?: string[];
history?: Array<HumanMessage | AIMessage | SystemMessage>;
maxTokens?: number;
}

// Convert our message types to Langchain message types
export function convertToLangchainMessage(
message: ChatMessageData
): HumanMessage | AIMessage | SystemMessage | null {
// Skip log messages
if (message.type === 'log') {
return null;
}

switch (message.type) {
case 'assistant':
return new AIMessage({ content: message.content });
case 'system':
return new SystemMessage({ content: message.content });
case 'user':
return new HumanMessage({ content: message.content });
default:
return null;
}
}

// Helper type for message roles
export type MessageRole = 'human' | 'assistant' | 'system';

// Helper function to create messages with proper typing
export function createMessage(content: string, role: MessageRole): BaseMessage {
switch (role) {
case 'assistant':
return new AIMessage(content);
case 'system':
return new SystemMessage(content);
case 'human':
default:
return new HumanMessage({ content });
}
}
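
A short sketch of convertToLangchainMessage in use, mirroring the ChatComponent change above (the ChatMessageData literals are trimmed to the fields this diff actually reads, hence the cast):

import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ChatMessageData } from '@/components/chat/ChatMessage';
import { convertToLangchainMessage } from '@/lib/llm/types';

const chatMessages = [
  { type: 'user', content: 'Hi' },
  { type: 'log', content: 'container started' }, // dropped: converter returns null
  { type: 'assistant', content: 'Hello!' },
] as ChatMessageData[];

const history = chatMessages
  .map(convertToLangchainMessage)
  .filter((msg): msg is HumanMessage | AIMessage | SystemMessage => msg !== null);
// history: [HumanMessage('Hi'), AIMessage('Hello!')]
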
15 changes: 10 additions & 5 deletions src/services/llm-api.service.ts
@@ -1,5 +1,6 @@
'use client';

import { GenerateOptions } from '@/lib/llm/provider';
import { LLMResponse } from '@/lib/llm/types';

export class LLMApiService {
@@ -15,18 +16,22 @@ export class LLMApiService {
public async sendMessage(
message: string,
modelId: string,
options?: {
stream?: boolean;
functions?: string[];
}
options?: GenerateOptions
): Promise<LLMResponse> {
const response = await fetch('/api/llm', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': process.env.NEXT_PUBLIC_API_KEY || '', // Ensure this is set
},
body: JSON.stringify({ message, modelId, options }),
body: JSON.stringify({
message,
modelId,
options: {
...options,
history: options?.history || [],
},
}),
});

if (!response.ok) {
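
One observation (not part of the diff): the history entries passed here are live Langchain message instances, and JSON.stringify serializes them through their toJSON() into the "constructor" envelope sketched earlier, which is exactly the shape the /api/llm route reconstructs.

import { HumanMessage } from '@langchain/core/messages';

const body = JSON.stringify({ history: [new HumanMessage({ content: 'Hi' })] });
// body is roughly:
// {"history":[{"lc":1,"type":"constructor",
//   "id":["langchain_core","messages","HumanMessage"],
//   "kwargs":{"content":"Hi"}}]}
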
27 changes: 16 additions & 11 deletions src/services/llm.service.ts
@@ -1,6 +1,12 @@
import { LLMProvider } from '@/lib/llm/provider';
import { FunctionRegistry } from '@/lib/llm/registry';
import { AVAILABLE_MODELS, FunctionDefinition, LLMConfig } from '@/lib/llm/types';
import {
AVAILABLE_MODELS,
FunctionDefinition,
LLMConfig,
LLMRequestOptions,
} from '@/lib/llm/types';
import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';

export class LLMService {
private static instance: LLMService;
@@ -31,7 +37,7 @@
model: model.id,
apiKey: this.getApiKey(model.provider),
temperature: 0.7,
maxTokens: model.maxOutputTokens, // Use the model-specific output token limit
maxTokens: model.maxOutputTokens,
};

this.providers.set(modelId, new LLMProvider(config));
@@ -51,24 +57,23 @@
return key;
}

public async sendMessage(
message: string,
modelId: string,
options?: {
stream?: boolean;
functions?: string[];
maxTokens?: number;
}
) {
public async sendMessage(message: string, modelId: string, options?: LLMRequestOptions) {
try {
const model = AVAILABLE_MODELS.find((m) => m.id === modelId);
if (!model) {
throw new Error(`Model ${modelId} not found`);
}

// Ensure history contains valid Langchain message types
const history = options?.history?.filter(
(msg) =>
msg instanceof HumanMessage || msg instanceof AIMessage || msg instanceof SystemMessage
);

const provider = this.getProvider(modelId);
return await provider.generateResponse(message, {
...options,
history,
maxTokens: model.maxOutputTokens,
});
} catch (error) {
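
A small sketch of why this instanceof filter pairs with the reconstruction step in route.ts (assumption: history parsed straight from JSON would be plain objects and get filtered out until rebuilt):

import { HumanMessage } from '@langchain/core/messages';

const fromJson = JSON.parse('{"kwargs":{"content":"Hi"}}'); // plain object, not a HumanMessage
const rebuilt = new HumanMessage({ content: 'Hi' });

const kept = [fromJson, rebuilt].filter((m) => m instanceof HumanMessage);
// kept contains only rebuilt, hence route.ts rebuilds instances before calling sendMessage.
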
36 changes: 0 additions & 36 deletions src/types/langchain.d.ts

This file was deleted.
