Skip to content

Commit

Permalink
Merge pull request #2 from cpunion/chat
Browse files Browse the repository at this point in the history
switch to anthropic-sdk-go for better stream control
  • Loading branch information
cpunion authored Dec 6, 2024
2 parents 0fa3155 + 5fa8733 commit 9e43822
Show file tree
Hide file tree
Showing 10 changed files with 438 additions and 152 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ jobs:
run: go mod download

- name: Test examples
run: go run ./examples/chat
run: |
go run ./examples/chat
go run ./examples/stream
shell: bash
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
66 changes: 62 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ A cross-platform Go library for interacting with multiple AI providers' APIs, in
- Unified interface for multiple AI providers
- Currently supports:
- OpenAI (via [go-openai](https://github.com/sashabaranov/go-openai))
- Anthropic (via [go-anthropic](https://github.com/liushuangls/go-anthropic))
- Anthropic (via [official SDK](https://github.com/anthropics/anthropic-sdk-go))
- Carefully designed API that follows each provider's best practices
- Gradual and thoughtful addition of necessary interfaces and fields

Expand All @@ -31,7 +31,7 @@ go get github.com/cpunion/go-aisuite

See complete examples in the [examples](./examples) directory.

Basic usage:
### Chat

<!-- embedme examples/chat/main.go -->

Expand Down Expand Up @@ -60,7 +60,7 @@ func main() {
Model: "openai:gpt-4o-mini", // or "anthropic:claude-3-5-haiku-20241022"
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.User,
Role: aisuite.RoleUser,
Content: "Hello, how are you?",
},
},
Expand All @@ -75,6 +75,64 @@ func main() {

```

### Stream

<!-- embedme examples/stream/main.go -->

```go
package main

import (
"context"
"fmt"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/client"
)

func main() {
// Initialize client with API keys
c := client.New(&client.APIKey{
OpenAI: "", // Set your OpenAI API key or use OPENAI_API_KEY env
Anthropic: "", // Set your Anthropic API key or use ANTHROPIC_API_KEY env
})

// Create a streaming chat completion request
stream, err := c.StreamChatCompletion(context.Background(), aisuite.ChatCompletionRequest{
Model: "openai:gpt-4o-mini", // or "anthropic:claude-3-5-haiku-20241022"
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.RoleUser,
Content: "Hello, how are you?",
},
},
MaxTokens: 10,
})
if err != nil {
panic(err)
}
defer stream.Close()

// Read the response stream
for {
resp, err := stream.Recv()
if err != nil {
panic(err)
}
if len(resp.Choices) == 0 {
fmt.Println("No choices")
break
}
if resp.Choices[0].FinishReason != "" {
fmt.Printf("\nStream finished: %s\n", resp.Choices[0].FinishReason)
break
}
fmt.Print(resp.Choices[0].Delta.Content)
}
}

```

## Contributing

We welcome contributions! Please feel free to submit a Pull Request. We are carefully expanding the API surface area to maintain compatibility and usability across different providers.
Expand All @@ -87,4 +145,4 @@ MIT License

This project is inspired by [aisuite](https://github.com/andrewyng/aisuite) and builds upon the excellent work of:
- [go-openai](https://github.com/sashabaranov/go-openai)
- [go-anthropic](https://github.com/liushuangls/go-anthropic)
- [anthropic-sdk-go](https://github.com/anthropics/anthropic-sdk-go)
20 changes: 15 additions & 5 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@ type ToolCall struct {
Function FunctionCall
}

// FinishReason is the provider-agnostic reason a completion stopped
// producing tokens.
type FinishReason string

const (
	// FinishReasonNone means no finish reason has been reported yet
	// (e.g. a mid-stream delta).
	FinishReasonNone FinishReason = ""
	// FinishReasonStop means the model stopped on its own.
	FinishReasonStop FinishReason = "stop"
	// FinishReasonMaxTokens means generation hit the MaxTokens limit.
	FinishReasonMaxTokens FinishReason = "max_tokens"
	// FinishReasonContentFilter means output was cut short by a
	// content filter.
	FinishReasonContentFilter FinishReason = "content_filter"
	// FinishReasonUnknown means the provider reported a reason this
	// library does not map to one of the values above.
	FinishReasonUnknown FinishReason = "unknown"
)

// Role identifies the author of a chat completion message.
type Role string

const (
	// RoleUser marks a message written by the end user.
	RoleUser Role = "user"
	// RoleSystem marks a system / instruction message.
	RoleSystem Role = "system"
	// RoleAssistant marks a message produced by the model.
	RoleAssistant Role = "assistant"
)

type ChatCompletionMessage struct {
Expand All @@ -45,15 +55,15 @@ type ChatCompletionResponse struct {

// ChatCompletionStreamChoiceDelta is the incremental payload carried by
// one streamed chunk; only the fields present in that chunk are set.
type ChatCompletionStreamChoiceDelta struct {
	// Content is the next fragment of assistant text (may be empty).
	Content string
	// Role of the message author — presumably only populated on the
	// first chunk of a choice; TODO confirm against provider adapters.
	Role Role
	// FunctionCall carries a partial legacy function-call payload, if any.
	FunctionCall *FunctionCall
	// ToolCalls carries partial tool-call payloads, if any.
	ToolCalls []ToolCall
	// Refusal holds refusal text — presumably when the model declines
	// to answer; verify against the provider mapping.
	Refusal string
}

// ChatCompletionStreamChoice is one choice within a streamed response
// chunk.
type ChatCompletionStreamChoice struct {
	// Delta holds the incremental update for this choice.
	Delta ChatCompletionStreamChoiceDelta
	// FinishReason is non-empty once the choice has finished streaming.
	FinishReason FinishReason
}

type ChatCompletionStreamResponse struct {
Expand Down
179 changes: 122 additions & 57 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,84 +2,149 @@ package client

import (
"context"
"fmt"
"strings"
"testing"
"time"

"github.com/cpunion/go-aisuite"
)

func TestChatCompletion(t *testing.T) {
client := New(nil)
models := []string{
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
// testModels lists the provider-prefixed model IDs exercised by the
// integration tests in this file.
var testModels = []string{
	"openai:gpt-4o-mini",
	"anthropic:claude-3-5-haiku-20241022",
}

// testCase describes one completion scenario: a prompt sent to a model
// under a token budget, and the finish reason we expect back.
type testCase struct {
	name             string               // subtest name
	model            string               // provider-prefixed model ID, e.g. "openai:gpt-4o-mini"
	prompt           string               // user message to send
	maxTokens        int                  // MaxTokens for the request
	wantFinishReason aisuite.FinishReason // expected terminal finish reason
}

func generateTestCases() []testCase {
var cases []testCase

// Test cases for max_tokens finish reason
longStory := "Tell me a very long story about a magical adventure with dragons, wizards, and epic battles."
for _, model := range testModels {
cases = append(cases, testCase{
name: fmt.Sprintf("%s_maxtoken_test", strings.Split(model, ":")[1]),
model: model,
prompt: longStory,
maxTokens: 5,
wantFinishReason: aisuite.FinishReasonMaxTokens,
})
}
for _, model := range models {
resp, err := client.ChatCompletion(context.Background(), aisuite.ChatCompletionRequest{
Model: model,
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.User,
Content: "Hello",
},
},
MaxTokens: 10,

// Test cases for normal stop
for _, model := range testModels {
cases = append(cases, testCase{
name: fmt.Sprintf("%s_normal_stop", strings.Split(model, ":")[1]),
model: model,
prompt: "Hi",
maxTokens: 20,
wantFinishReason: aisuite.FinishReasonStop,
})
if err != nil {
t.Fatal(err)
}
t.Logf("resp: %#v", resp)
}

return cases
}

func TestStreamChatCompletion(t *testing.T) {
client := New(nil)
models := []string{
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
func withTimeout(t *testing.T, timeout time.Duration, fn func(ctx context.Context)) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

done := make(chan struct{})
go func() {
fn(ctx)
close(done)
}()

select {
case <-done:
return
case <-ctx.Done():
t.Fatal("test timeout")
}
}

func TestChatCompletion(t *testing.T) {
client := New(nil)
models := testModels
for _, model := range models {
t.Run(model, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

stream, err := client.StreamChatCompletion(ctx, aisuite.ChatCompletionRequest{
Model: model,
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.User,
Content: "Hello",
withTimeout(t, 10*time.Second, func(ctx context.Context) {
resp, err := client.ChatCompletion(ctx, aisuite.ChatCompletionRequest{
Model: model,
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.RoleUser,
Content: "Hi",
},
},
},
MaxTokens: 10,
MaxTokens: 30,
})
if err != nil {
t.Fatal(err)
}
if len(resp.Choices) == 0 {
t.Fatal("no choices")
}
fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()

// Read all chunks from the stream
for {
select {
case <-ctx.Done():
t.Fatal("context deadline exceeded")
return
default:
chunk, err := stream.Recv()
})
}
}

func TestStreamChatCompletion(t *testing.T) {
client := New(nil)
cases := generateTestCases()

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
withTimeout(t, 10*time.Second, func(ctx context.Context) {
stream, err := client.StreamChatCompletion(ctx, aisuite.ChatCompletionRequest{
Model: tc.model,
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.RoleUser,
Content: tc.prompt,
},
},
MaxTokens: tc.maxTokens,
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()

var content string
var finishReason aisuite.FinishReason
for {
resp, err := stream.Recv()
if err != nil {
if err.Error() == "EOF" {
return
}
t.Fatal(err)
}

t.Logf("chunk: %#v", chunk)
if len(chunk.Choices) > 0 && chunk.Choices[0].FinishReason != "" {
t.Logf("finish reason: %s", chunk.Choices[0].FinishReason)
return
if len(resp.Choices) == 0 {
fmt.Println("No choices")
break
}
if resp.Choices[0].FinishReason != "" {
finishReason = resp.Choices[0].FinishReason
break
}
content += resp.Choices[0].Delta.Content
}
}

if finishReason != tc.wantFinishReason {
t.Errorf("got finish reason %q, want %q", finishReason, tc.wantFinishReason)
}

fmt.Printf("Test case %s:\nPrompt: %s\nResponse: %s\nFinish reason: %s\n\n",
tc.name, tc.prompt, content, finishReason)
})
})
}
}
2 changes: 1 addition & 1 deletion examples/chat/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func main() {
Model: "openai:gpt-4o-mini", // or "anthropic:claude-3-5-haiku-20241022"
Messages: []aisuite.ChatCompletionMessage{
{
Role: aisuite.User,
Role: aisuite.RoleUser,
Content: "Hello, how are you?",
},
},
Expand Down
Loading

0 comments on commit 9e43822

Please sign in to comment.