Skip to content

Commit

Permalink
Merge pull request #3 from cpunion/gemini-sambanova-groq
Browse files Browse the repository at this point in the history
feat: add new AI providers and update existing ones
  • Loading branch information
cpunion authored Dec 10, 2024
2 parents 9e43822 + 02b9687 commit 4466085
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 31 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
SAMBANOVA_API_KEY: ${{ secrets.SAMBANOVA_API_KEY }}
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.env
_*
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ A cross-platform Go library for interacting with multiple AI providers' APIs, in
- Currently supports:
- OpenAI (via [go-openai](https://github.com/sashabaranov/go-openai))
- Anthropic (via [official SDK](https://github.com/anthropics/anthropic-sdk-go))
- Groq (via OpenAI-compatible API)
- Gemini (via OpenAI-compatible API)
- SambaNova (via OpenAI-compatible API)
- Carefully designed API that follows each provider's best practices
- Gradual and thoughtful addition of necessary interfaces and fields

Expand Down
31 changes: 23 additions & 8 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ import (

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
_ "github.com/cpunion/go-aisuite/providers/anthropic"
_ "github.com/cpunion/go-aisuite/providers/openai"
"github.com/cpunion/go-aisuite/providers/anthropic"
"github.com/cpunion/go-aisuite/providers/gemini"
"github.com/cpunion/go-aisuite/providers/groq"
"github.com/cpunion/go-aisuite/providers/openai"
"github.com/cpunion/go-aisuite/providers/sambanova"
)

const (
Expand All @@ -18,6 +21,9 @@ const (
type APIKey struct {
OpenAI string
Anthropic string
Sambanova string
Gemini string
Groq string
}

type AdaptiveClient struct {
Expand Down Expand Up @@ -52,12 +58,21 @@ func (c AdaptiveClient) getClientAndModel(model string) (aisuite.Client, string)
if !ok {
panic(fmt.Sprintf("%s: %s", ErrUnknownProvider, providerName))
}
var apiKey string
opts := providers.Options{}
switch providerName {
case "openai":
apiKey = c.apiKey.OpenAI
case "anthropic":
apiKey = c.apiKey.Anthropic
case openai.Name:
opts.Token = c.apiKey.OpenAI
case anthropic.Name:
opts.Token = c.apiKey.Anthropic
case gemini.Name:
opts.BaseURL = "https://generativelanguage.googleapis.com/v1beta/openai/"
opts.Token = c.apiKey.Gemini
case sambanova.Name:
opts.BaseURL = "https://api.sambanova.ai/v1/"
opts.Token = c.apiKey.Sambanova
case groq.Name:
opts.Token = c.apiKey.Groq
opts.BaseURL = "https://api.groq.com/openai/v1/"
}
return provider.NewClient(apiKey), toks[1]
return provider.NewClient(opts), toks[1]
}
8 changes: 7 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
Expand All @@ -13,6 +14,9 @@ import (
var testModels = []string{
"openai:gpt-4o-mini",
"anthropic:claude-3-5-haiku-20241022",
"gemini:gemini-1.5-flash-latest",
"sambanova:Meta-Llama-3.2-1B-Instruct",
"groq:llama-3.1-8b-instant",
}

type testCase struct {
Expand Down Expand Up @@ -43,7 +47,7 @@ func generateTestCases() []testCase {
cases = append(cases, testCase{
name: fmt.Sprintf("%s_normal_stop", strings.Split(model, ":")[1]),
model: model,
prompt: "Hi",
prompt: "Hi (shortly respond)",
maxTokens: 20,
wantFinishReason: aisuite.FinishReasonStop,
})
Expand Down Expand Up @@ -75,6 +79,8 @@ func TestChatCompletion(t *testing.T) {
models := testModels
for _, model := range models {
t.Run(model, func(t *testing.T) {
wd, _ := os.Getwd()
t.Logf("Working directory: %s", wd)
withTimeout(t, 10*time.Second, func(ctx context.Context) {
resp, err := client.ChatCompletion(ctx, aisuite.ChatCompletionRequest{
Model: model,
Expand Down
12 changes: 3 additions & 9 deletions providers/anthropic/anthropic_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package anthropic
import (
"context"
"log/slog"
"os"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
)

const (
Expand All @@ -19,14 +19,8 @@ type Client struct {
client *anthropic.Client
}

func NewClient(token string) *Client {
if token == "" {
token = os.Getenv("ANTHROPIC_API_KEY")
if token == "" {
panic("ANTHROPIC_API_KEY not found in environment variables")
}
}
return &Client{client: anthropic.NewClient(option.WithAPIKey(token))}
func NewClient(opts providers.Options) *Client {
return &Client{client: anthropic.NewClient(option.WithAPIKey(opts.Token))}
}

func (c *Client) ChatCompletion(ctx context.Context, req aisuite.ChatCompletionRequest) (*aisuite.ChatCompletionResponse, error) {
Expand Down
13 changes: 11 additions & 2 deletions providers/anthropic/anthropic_provider.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package anthropic

import (
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
)

const Name = "anthropic"
const apiKeyEnvVar = "ANTHROPIC_API_KEY"

func init() {
providers.RegisterProvider(Name, Provider{})
Expand All @@ -14,6 +17,12 @@ func init() {
type Provider struct {
}

func (p Provider) NewClient(apiKey string) aisuite.Client {
return NewClient(apiKey)
func (p Provider) NewClient(opts providers.Options) aisuite.Client {
if opts.Token == "" {
opts.Token = os.Getenv(apiKeyEnvVar)
if opts.Token == "" {
panic(apiKeyEnvVar + " not found in environment variables")
}
}
return NewClient(opts)
}
29 changes: 29 additions & 0 deletions providers/gemini/gemini_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package gemini

import (
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
"github.com/cpunion/go-aisuite/providers/openai"
)

const Name = "gemini"
const apiKeyEnvVar = "GEMINI_API_KEY"

func init() {
providers.RegisterProvider(Name, Provider{})
}

type Provider struct {
}

func (p Provider) NewClient(opts providers.Options) aisuite.Client {
if opts.Token == "" {
opts.Token = os.Getenv(apiKeyEnvVar)
if opts.Token == "" {
panic(apiKeyEnvVar + " not found in environment variables")
}
}
return openai.NewClient(opts)
}
29 changes: 29 additions & 0 deletions providers/groq/groq_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package groq

import (
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
"github.com/cpunion/go-aisuite/providers/openai"
)

const Name = "groq"
const apiKeyEnvVar = "GROQ_API_KEY"

func init() {
providers.RegisterProvider(Name, Provider{})
}

type Provider struct {
}

func (p Provider) NewClient(opts providers.Options) aisuite.Client {
if opts.Token == "" {
opts.Token = os.Getenv(apiKeyEnvVar)
if opts.Token == "" {
panic(apiKeyEnvVar + " not found in environment variables")
}
}
return openai.NewClient(opts)
}
14 changes: 6 additions & 8 deletions providers/openai/openai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@ package openai
import (
"context"
"log/slog"
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
ai "github.com/sashabaranov/go-openai"
)

type Client struct {
client *ai.Client
}

func NewClient(token string) *Client {
if token == "" {
token = os.Getenv("OPENAI_API_KEY")
if token == "" {
panic("OPENAI_API_KEY not found in environment variables")
}
func NewClient(opts providers.Options) *Client {
config := ai.DefaultConfig(opts.Token)
if opts.BaseURL != "" {
config.BaseURL = opts.BaseURL
}
return &Client{client: ai.NewClient(token)}
return &Client{client: ai.NewClientWithConfig(config)}
}

func (c *Client) ChatCompletion(ctx context.Context, req aisuite.ChatCompletionRequest) (*aisuite.ChatCompletionResponse, error) {
Expand Down
13 changes: 11 additions & 2 deletions providers/openai/openai_provider.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package openai

import (
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
)

const Name = "openai"
const apiKeyEnvVar = "OPENAI_API_KEY"

func init() {
providers.RegisterProvider(Name, Provider{})
Expand All @@ -14,6 +17,12 @@ func init() {
type Provider struct {
}

func (p Provider) NewClient(apiKey string) aisuite.Client {
return NewClient(apiKey)
func (p Provider) NewClient(opts providers.Options) aisuite.Client {
if opts.Token == "" {
opts.Token = os.Getenv(apiKeyEnvVar)
if opts.Token == "" {
panic(apiKeyEnvVar + " not found in environment variables")
}
}
return NewClient(opts)
}
23 changes: 22 additions & 1 deletion providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,29 @@ package providers

import "github.com/cpunion/go-aisuite"

type Options struct {
BaseURL string
Token string
}

type Option func(o Options) Options

func WithToken(token string) Option {
return func(o Options) Options {
o.Token = token
return o
}
}

func WithBaseURL(baseURL string) Option {
return func(o Options) Options {
o.BaseURL = baseURL
return o
}
}

type Provider interface {
NewClient(apiKey string) aisuite.Client
NewClient(options Options) aisuite.Client
}

var providers = make(map[string]Provider)
Expand Down
29 changes: 29 additions & 0 deletions providers/sambanova/sambanova_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package sambanova

import (
"os"

"github.com/cpunion/go-aisuite"
"github.com/cpunion/go-aisuite/providers"
"github.com/cpunion/go-aisuite/providers/openai"
)

const Name = "sambanova"
const apiKeyEnvVar = "SAMBANOVA_API_KEY"

func init() {
providers.RegisterProvider(Name, Provider{})
}

type Provider struct {
}

func (p Provider) NewClient(opts providers.Options) aisuite.Client {
if opts.Token == "" {
opts.Token = os.Getenv(apiKeyEnvVar)
if opts.Token == "" {
panic(apiKeyEnvVar + " not found in environment variables")
}
}
return openai.NewClient(opts)
}

0 comments on commit 4466085

Please sign in to comment.