Skip to content

Commit

Permalink
Add final touches
Browse files (browse the repository at this point in the history)
  • Loading branch information
Shreyanand committed May 3, 2024
1 parent 5c965d0 commit c0306f1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 37 deletions.
40 changes: 12 additions & 28 deletions recipes/natural_language_processing/rag/app/manage_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from langchain_community.vectorstores import Milvus
from pymilvus import MilvusClient
from pymilvus import connections, utility
import uuid

class VectorDB:
def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
Expand All @@ -17,7 +16,7 @@ def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
self.embedding_model = embedding_model

def connect(self):
# Connection logic here
# Connection logic
print(f"Connecting to {self.host}:{self.port}...")
if self.vector_vendor == "chromadb":
self.client = HttpClient(host=self.host,
Expand All @@ -28,26 +27,27 @@ def connect(self):
return self.client

def populate_db(self, documents):
# Implement logic to populate the VectorDB with vectors
# Logic to populate the VectorDB with vectors
e = SentenceTransformerEmbeddings(model_name=self.embedding_model)
print(f"Populating VectorDB with vectors...")
if self.vector_vendor == "chromadb":
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.embedding_model)
collection = self.client.get_or_create_collection(self.collection_name,
embedding_function=embedding_func)
if collection.count() < 1:
for doc in documents:
collection.add(
ids=[str(uuid.uuid1())],
metadatas=doc.metadata,
documents=doc.page_content
)
db = Chroma.from_documents(
documents=documents,
embedding=e,
collection_name=self.collection_name,
client=self.client
)
print("DB populated")
else:
db = Chroma(client=self.client,
collection_name=self.collection_name,
embedding_function=e,
)
print("DB already populated")
db = Chroma(client=self.client,
collection_name=self.collection_name,
embedding_function=e)

elif self.vector_vendor == "milvus":
connections.connect(host=self.host, port=self.port)
Expand Down Expand Up @@ -79,19 +79,3 @@ def clear_db(self):
print("Cleared DB")
except:
print("Couldn't clear the collection possibly because it doesn't exist")


# def has_collection(self):
# # Implement logic to check if the VectorDB has a collection
# print(f"Checking if collection {self.collection_name} exists in VectorDB...")
# if self.vector_vendor == "chromadb":
# hc = collection.count() < 1
# elif self.vector_vendor == "milvus":
# hc =
# return hc


def create_retriever(self):
# Implement logic to create and return a retriever object from the VectorDB
print("Creating retriever object from VectorDB...")

12 changes: 3 additions & 9 deletions recipes/natural_language_processing/rag/app/rag_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@
model_service = f"{model_service}/v1"
chunk_size = os.getenv("CHUNK_SIZE", 150)
embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
# vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
# vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
# vdb_port = os.getenv("VECTORDB_PORT", "8000")
# vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

# Use the following defaults if using milvus db
vdb_vendor = os.getenv("VECTORDB_VENDOR", "milvus")
vdb_host = os.getenv("VECTORDB_HOST", "127.0.0.1")
vdb_port = os.getenv("VECTORDB_PORT", "19530")
vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
vdb_port = os.getenv("VECTORDB_PORT", "8000")
vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
Expand Down

0 comments on commit c0306f1

Please sign in to comment.