From 1517ab1034582742eea17c58bd16655716d56547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 10 Feb 2024 11:21:45 +0100 Subject: [PATCH] Support more OpenAI embedding models --- embedding.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/embedding.go b/embedding.go index b038c63..a41f9a1 100644 --- a/embedding.go +++ b/embedding.go @@ -12,8 +12,15 @@ import ( ) const ( - baseURLOpenAI = "https://api.openai.com/v1" - embeddingModelOpenAI3Small = "text-embedding-3-small" + baseURLOpenAI = "https://api.openai.com/v1" +) + +type EmbeddingModelOpenAI string + +const ( + EmbeddingModelOpenAI2Ada EmbeddingModelOpenAI = "text-embedding-ada-002" + EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small" + EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large" ) type openAIResponse struct { @@ -28,13 +35,13 @@ type openAIResponse struct { // The API key is read from the environment variable "OPENAI_API_KEY". func CreateEmbeddingsDefault() EmbeddingFunc { apiKey := os.Getenv("OPENAI_API_KEY") - return CreateEmbeddingsOpenAI(apiKey) + return CreateEmbeddingsOpenAI(apiKey, EmbeddingModelOpenAI3Small) } // CreateEmbeddingsDefault returns a function that creates embeddings for a document // using OpenAI`s "text-embedding-3-small" model via their API. // The model supports a maximum document length of 8191 tokens. -func CreateEmbeddingsOpenAI(apiKey string) EmbeddingFunc { +func CreateEmbeddingsOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, // and it might have to be a long timeout, depending on the document size. @@ -44,7 +51,7 @@ func CreateEmbeddingsOpenAI(apiKey string) EmbeddingFunc { // Prepare the request body. reqBody, err := json.Marshal(map[string]string{ "input": document, - "model": embeddingModelOpenAI3Small, + "model": string(model), }) if err != nil { return nil, fmt.Errorf("couldn't marshal request body: %w", err)