Skip to content

Commit

Permalink
Add switching between vdbs logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyanand committed Apr 24, 2024
1 parent d9da2f1 commit 6541325
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 206 deletions.
98 changes: 98 additions & 0 deletions recipes/natural_language_processing/rag/app/manage_vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from langchain_community.vectorstores import Chroma
from chromadb import HttpClient
from chromadb.config import Settings
import chromadb.utils.embedding_functions as embedding_functions
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Milvus
from pymilvus import MilvusClient
from pymilvus import connections, utility
import uuid

class VectorDB:
def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
self.vector_vendor = vector_vendor
self.host = host
self.port = port
self.collection_name = collection_name
self.embedding_model = embedding_model

def connect(self):
# Connection logic here
print(f"Connecting to {self.host}:{self.port}...")
if self.vector_vendor == "chromadb":
self.client = HttpClient(host=self.host,
port=self.port,
settings=Settings(allow_reset=True,))
elif self.vector_vendor == "milvus":
self.client = MilvusClient(uri=f"http://{self.host}:{self.port}")
return self.client

def populate_db(self, documents):
# Implement logic to populate the VectorDB with vectors
e = SentenceTransformerEmbeddings(model_name=self.embedding_model)
print(f"Populating VectorDB with vectors...")
if self.vector_vendor == "chromadb":
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.embedding_model)
collection = self.client.get_or_create_collection(self.collection_name,
embedding_function=embedding_func)
if collection.count() < 1:
for doc in documents:
collection.add(
ids=[str(uuid.uuid1())],
metadatas=doc.metadata,
documents=doc.page_content
)
print("DB populated")
else:
print("DB already populated")
db = Chroma(client=self.client,
collection_name=self.collection_name,
embedding_function=e)

elif self.vector_vendor == "milvus":
connections.connect(host=self.host, port=self.port)
if not utility.has_collection(self.collection_name):
print("Populating VectorDB with vectors...")
db = Milvus.from_documents(
documents,
e,
collection_name=self.collection_name,
connection_args={"host": self.host, "port": self.port},
)
print("DB populated")
else:
print("DB already populated")
db = Milvus(
e,
collection_name=self.collection_name,
connection_args={"host": self.host, "port": self.port},
)
return db

def clear_db(self):
# Implement logic to clear the entire VectorDB
print(f"Clearing the entire VectorDB...")
try:
if self.vector_vendor == "chromadb":
self.client.delete_collection(self.collection_name)
elif self.vector_vendor == "milvus":
self.client.drop_collection(self.collection_name)
print("Cleared DB")
except:
print("Couldn't clear the collection possibly because it doesn't exist")


# def has_collection(self):
# # Implement logic to check if the VectorDB has a collection
# print(f"Checking if collection {self.collection_name} exists in VectorDB...")
# if self.vector_vendor == "chromadb":
# hc = collection.count() < 1
# elif self.vector_vendor == "milvus":
# hc =
# return hc


def create_retriever(self):
# Implement logic to create and return a retriever object from the VectorDB
print("Creating retriever object from VectorDB...")

36 changes: 0 additions & 36 deletions recipes/natural_language_processing/rag/app/populate_vectordb.py

This file was deleted.

90 changes: 36 additions & 54 deletions recipes/natural_language_processing/rag/app/rag_app.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,61 @@
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain.schema.document import Document
from langchain_community.vectorstores import Chroma

from chromadb import HttpClient
from chromadb.config import Settings
import chromadb.utils.embedding_functions as embedding_functions

from langchain_community.document_loaders import TextLoader
from manage_vectordb import VectorDB
import streamlit as st

import uuid
import os
import argparse
import pathlib

model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001/v1")
model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001")
model_service = f"{model_service}/v1"
chunk_size = os.getenv("CHUNK_SIZE", 150)
embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
vdb_port = os.getenv("VECTORDB_PORT", "8000")
vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

# Use the following defaults if using milvus db
# vdb_vendor = os.getenv("VECTORDB_VENDOR", "milvus")
# vdb_host = os.getenv("VECTORDB_HOST", "127.0.0.1")
# vdb_port = os.getenv("VECTORDB_PORT", "19530")
# vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
vectorDB_client = vdb.connect()
def get_docs(filename):
loader = TextLoader(filename)
raw_documents = loader.load()
text_splitter = CharacterTextSplitter(separator = ".",
chunk_size=int(chunk_size),
chunk_overlap=0)
docs = text_splitter.split_documents(raw_documents)
return docs


vectorDB_client = HttpClient(host=vdb_host,
port=vdb_port,
settings=Settings(allow_reset=True,))
def create_tmp_file():
with open(data.name, mode='wb') as w:
w.write(data.getvalue())
print(f"{data.name} uploaded")

def clear_vdb():
global vectorDB_client
try:
vectorDB_client.delete_collection(vdb_name)
print("Cleared DB")
except:
pass

st.title("📚 RAG DEMO")
with st.sidebar:
data = st.file_uploader(label="📄 Upload Document",type=['txt'], on_change=clear_vdb)
data = st.file_uploader(label="📄 Upload Document",type=['txt'], on_change=vdb.clear_db)

### populate the DB ####
os.environ["TOKENIZERS_PARALLELISM"] = "false"

embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model)
e = SentenceTransformerEmbeddings(model_name=embedding_model)

collection = vectorDB_client.get_or_create_collection(vdb_name,
embedding_function=embedding_func)
if collection.count() < 1 and data != None:
print("populating db")
raw_documents = [Document(page_content=data.getvalue().decode("utf-8"),
metadata={"":""})]
text_splitter = CharacterTextSplitter(separator = ".",
chunk_size=int(chunk_size),
chunk_overlap=0)
docs = text_splitter.split_documents(raw_documents)
for doc in docs:
collection.add(
ids=[str(uuid.uuid1())],
metadatas=doc.metadata,
documents=doc.page_content
)
if data == None:
print("Empty VectorDB")
if data != None:
with open(data.name, mode='wb') as w:
w.write(data.getvalue())
print(f"{data.name} uploaded")
documents = get_docs(data.name)
else:
print("DB already populated")
documents = get_docs("../sample-data/fake_meeting.txt")
print("Empty VectorDB")
db = vdb.populate_db(documents)
retriever = db.as_retriever(threshold=0.75)
########################

if "messages" not in st.session_state:
Expand All @@ -78,11 +65,6 @@ def clear_vdb():
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])

db = Chroma(client=vectorDB_client,
collection_name=vdb_name,
embedding_function=e
)
retriever = db.as_retriever(threshold=0.75)

llm = ChatOpenAI(base_url=model_service,
api_key="EMPTY",
Expand All @@ -109,4 +91,4 @@ def clear_vdb():
response = chain.invoke(prompt)
st.chat_message("assistant").markdown(response.content)
st.session_state.messages.append({"role": "assistant", "content": response.content})
st.rerun()
st.rerun()
116 changes: 0 additions & 116 deletions recipes/natural_language_processing/rag/app/rag_app_milvus.py

This file was deleted.

0 comments on commit 6541325

Please sign in to comment.