diff --git a/src/wandbot/configs/vector_store_config.py b/src/wandbot/configs/vector_store_config.py
index 7d7ea7b..b9f4660 100644
--- a/src/wandbot/configs/vector_store_config.py
+++ b/src/wandbot/configs/vector_store_config.py
@@ -1,6 +1,6 @@
 import pathlib
 from pydantic_settings import BaseSettings
-
+from typing import Literal
 class VectorStoreConfig(BaseSettings):
     # Vector Store
     collection_name: str = "vectorstore"
@@ -14,5 +14,6 @@ class VectorStoreConfig(BaseSettings):
     embeddings_dimensions: int = 512 # needed when using OpenAI embeddings
     embeddings_query_input_type: str = "search_query" # needed when using Cohere embeddings
     embeddings_document_input_type: str = "search_document" # needed when using Cohere embeddings
+    embeddings_encoding_format: Literal["float", "base64"] = "float"
     # Ingestions
     batch_size: int = 256 # used during ingestion when adding docs to vectorstore
\ No newline at end of file
diff --git a/src/wandbot/models/embedding.py b/src/wandbot/models/embedding.py
index cc53612..a63b3ab 100644
--- a/src/wandbot/models/embedding.py
+++ b/src/wandbot/models/embedding.py
@@ -8,6 +8,9 @@
 from wandbot.configs.vector_store_config import VectorStoreConfig
 from wandbot.utils import get_logger
 
+import base64
+import numpy as np
+
 logger = get_logger(__name__)
 
 vector_store_config = VectorStoreConfig()
@@ -22,6 +25,7 @@ def __init__(self,
                  input_type:str = None,
                  dimensions:int = None,
                  n_parallel_api_calls:int = 50,
+                 encoding_format:str = "float",
                  ):
         self.provider = provider.lower()
         self.model_name = model_name
@@ -55,7 +59,7 @@
                 raise ValueError("`dimensions` needs to be specified when using OpenAI embeddings models")
 
         self.n_parallel_api_calls = n_parallel_api_calls
-
+        self.encoding_format = encoding_format
     @weave.op
     @retry(
         stop=stop_after_attempt(3),
@@ -83,10 +87,15 @@ async def get_single_openai_embedding(text):
                 response = await client.embeddings.create(
                     input=text,
                     model=self.model_name,
-                    encoding_format="float",
+                    encoding_format=self.encoding_format,
                     dimensions=self.dimensions
                 )
-                return response.data[0].embedding
+                if self.encoding_format == "base64":
+                    decoded_embeddings = base64.b64decode(response.data[0].embedding)
+                    embeddings = np.frombuffer(decoded_embeddings, dtype=np.float32).tolist()
+                    return embeddings
+                else:
+                    return response.data[0].embedding
 
             return await asyncio.gather(*[get_single_openai_embedding(text) for text in inputs])
         finally:
diff --git a/src/wandbot/retriever/base.py b/src/wandbot/retriever/base.py
index d2eb804..3f3d390 100644
--- a/src/wandbot/retriever/base.py
+++ b/src/wandbot/retriever/base.py
@@ -25,7 +25,8 @@ def __init__(self, vector_store_config: VectorStoreConfig, chat_config: ChatConf
                 provider = self.vector_store_config.embeddings_provider,
                 model_name = self.vector_store_config.embeddings_model_name,
                 dimensions = self.vector_store_config.embeddings_dimensions,
-                input_type = self.vector_store_config.embeddings_query_input_type
+                input_type = self.vector_store_config.embeddings_query_input_type,
+                encoding_format = self.vector_store_config.embeddings_encoding_format
             )
         except Exception as e:
             raise RuntimeError(f"Failed to initialize embedding model:\n{str(e)}\n") from e