From c0306f13198e5e62bd8ca1892e92f5ef79617b2f Mon Sep 17 00:00:00 2001 From: Shreyanand Date: Fri, 3 May 2024 09:46:46 -0400 Subject: [PATCH] Add final touches --- .../rag/app/manage_vectordb.py | 40 ++++++------------- .../rag/app/rag_app.py | 12 ++---- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/recipes/natural_language_processing/rag/app/manage_vectordb.py b/recipes/natural_language_processing/rag/app/manage_vectordb.py index dd0e2d2d8..82566abdc 100644 --- a/recipes/natural_language_processing/rag/app/manage_vectordb.py +++ b/recipes/natural_language_processing/rag/app/manage_vectordb.py @@ -6,7 +6,6 @@ from langchain_community.vectorstores import Milvus from pymilvus import MilvusClient from pymilvus import connections, utility -import uuid class VectorDB: def __init__(self, vector_vendor, host, port, collection_name, embedding_model): @@ -17,7 +16,7 @@ def __init__(self, vector_vendor, host, port, collection_name, embedding_model): self.embedding_model = embedding_model def connect(self): - # Connection logic here + # Connection logic print(f"Connecting to {self.host}:{self.port}...") if self.vector_vendor == "chromadb": self.client = HttpClient(host=self.host, @@ -28,7 +27,7 @@ def connect(self): return self.client def populate_db(self, documents): - # Implement logic to populate the VectorDB with vectors + # Logic to populate the VectorDB with vectors e = SentenceTransformerEmbeddings(model_name=self.embedding_model) print(f"Populating VectorDB with vectors...") if self.vector_vendor == "chromadb": @@ -36,18 +35,19 @@ def populate_db(self, documents): collection = self.client.get_or_create_collection(self.collection_name, embedding_function=embedding_func) if collection.count() < 1: - for doc in documents: - collection.add( - ids=[str(uuid.uuid1())], - metadatas=doc.metadata, - documents=doc.page_content - ) + db = Chroma.from_documents( + documents=documents, + embedding=e, + collection_name=self.collection_name, + client=self.client + ) print("DB populated") else: + db = Chroma(client=self.client, + collection_name=self.collection_name, + embedding_function=e, + ) print("DB already populated") - db = Chroma(client=self.client, - collection_name=self.collection_name, - embedding_function=e) elif self.vector_vendor == "milvus": connections.connect(host=self.host, port=self.port) @@ -79,19 +79,3 @@ def clear_db(self): print("Cleared DB") except: print("Couldn't clear the collection possibly because it doesn't exist") - - - # def has_collection(self): - # # Implement logic to check if the VectorDB has a collection - # print(f"Checking if collection {self.collection_name} exists in VectorDB...") - # if self.vector_vendor == "chromadb": - # hc = collection.count() < 1 - # elif self.vector_vendor == "milvus": - # hc = - # return hc - - - def create_retriever(self): - # Implement logic to create and return a retriever object from the VectorDB - print("Creating retriever object from VectorDB...") - diff --git a/recipes/natural_language_processing/rag/app/rag_app.py b/recipes/natural_language_processing/rag/app/rag_app.py index 158d43e02..52097cc8d 100644 --- a/recipes/natural_language_processing/rag/app/rag_app.py +++ b/recipes/natural_language_processing/rag/app/rag_app.py @@ -14,15 +14,9 @@ model_service = f"{model_service}/v1" chunk_size = os.getenv("CHUNK_SIZE", 150) embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5") -# vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb") -# vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0") -# vdb_port = os.getenv("VECTORDB_PORT", "8000") -# vdb_name = os.getenv("VECTORDB_NAME", "test_collection") - -# Use the following defaults if using milvus db -vdb_vendor = os.getenv("VECTORDB_VENDOR", "milvus") -vdb_host = os.getenv("VECTORDB_HOST", "127.0.0.1") -vdb_port = os.getenv("VECTORDB_PORT", "19530") +vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb") +vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0") +vdb_port = os.getenv("VECTORDB_PORT", "8000") vdb_name = os.getenv("VECTORDB_NAME", "test_collection") vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)