Skip to content

Commit

Permalink
Adds float or base64 encoding option to EmbeddingModel, set config to base64
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 committed Jan 20, 2025
1 parent 6ddc48a commit b1e45cf
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/wandbot/configs/vector_store_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
from pydantic_settings import BaseSettings

from typing import Literal
class VectorStoreConfig(BaseSettings):
# Vector Store
collection_name: str = "vectorstore"
Expand All @@ -14,5 +14,6 @@ class VectorStoreConfig(BaseSettings):
embeddings_dimensions: int = 512 # needed when using OpenAI embeddings
embeddings_query_input_type: str = "search_query" # needed when using Cohere embeddings
embeddings_document_input_type: str = "search_document" # needed when using Cohere embeddings
embeddings_encoding_format: Literal["float", "base64"] = "float"
# Ingestions
batch_size: int = 256 # used during ingestion when adding docs to vectorstore
15 changes: 12 additions & 3 deletions src/wandbot/models/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from wandbot.configs.vector_store_config import VectorStoreConfig
from wandbot.utils import get_logger

import base64
import numpy as np

logger = get_logger(__name__)

vector_store_config = VectorStoreConfig()
Expand All @@ -22,6 +25,7 @@ def __init__(self,
input_type:str = None,
dimensions:int = None,
n_parallel_api_calls:int = 50,
encoding_format:str = "float",
):
self.provider = provider.lower()
self.model_name = model_name
Expand Down Expand Up @@ -55,7 +59,7 @@ def __init__(self,
raise ValueError("`dimensions` needs to be specified when using OpenAI embeddings models")

self.n_parallel_api_calls = n_parallel_api_calls

self.encoding_format = encoding_format
@weave.op
@retry(
stop=stop_after_attempt(3),
Expand Down Expand Up @@ -83,10 +87,15 @@ async def get_single_openai_embedding(text):
response = await client.embeddings.create(
input=text,
model=self.model_name,
encoding_format="float",
encoding_format=self.encoding_format,
dimensions=self.dimensions
)
return response.data[0].embedding
if self.encoding_format == "base64":
decoded_embeddings = base64.b64decode(response.data[0].embedding)
embeddings = np.frombuffer(decoded_embeddings, dtype=np.float32).tolist()
return embeddings
else:
return response.data[0].embedding

return await asyncio.gather(*[get_single_openai_embedding(text) for text in inputs])
finally:
Expand Down
3 changes: 2 additions & 1 deletion src/wandbot/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, vector_store_config: VectorStoreConfig, chat_config: ChatConf
provider = self.vector_store_config.embeddings_provider,
model_name = self.vector_store_config.embeddings_model_name,
dimensions = self.vector_store_config.embeddings_dimensions,
input_type = self.vector_store_config.embeddings_query_input_type
input_type = self.vector_store_config.embeddings_query_input_type,
encoding_format = self.vector_store_config.embeddings_encoding_format
)
except Exception as e:
raise RuntimeError(f"Failed to initialize embedding model:\n{str(e)}\n") from e
Expand Down

0 comments on commit b1e45cf

Please sign in to comment.