From 654132577d098cd3323bcd25e51ec2194dfa83f4 Mon Sep 17 00:00:00 2001
From: Shreyanand
Date: Wed, 24 Apr 2024 17:32:24 -0400
Subject: [PATCH] Add logic for switching between vector DBs

---
 .../rag/app/manage_vectordb.py   |  98 +++++++++++++++
 .../rag/app/populate_vectordb.py |  36 ------
 .../rag/app/rag_app.py           |  90 ++++++--------
 .../rag/app/rag_app_milvus.py    | 116 ------------------
 4 files changed, 134 insertions(+), 206 deletions(-)
 create mode 100644 recipes/natural_language_processing/rag/app/manage_vectordb.py
 delete mode 100644 recipes/natural_language_processing/rag/app/populate_vectordb.py
 delete mode 100644 recipes/natural_language_processing/rag/app/rag_app_milvus.py

diff --git a/recipes/natural_language_processing/rag/app/manage_vectordb.py b/recipes/natural_language_processing/rag/app/manage_vectordb.py
new file mode 100644
index 000000000..43f6fefd5
--- /dev/null
+++ b/recipes/natural_language_processing/rag/app/manage_vectordb.py
@@ -0,0 +1,98 @@
+from langchain_community.vectorstores import Chroma
+from chromadb import HttpClient
+from chromadb.config import Settings
+import chromadb.utils.embedding_functions as embedding_functions
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from langchain_community.vectorstores import Milvus
+from pymilvus import MilvusClient
+from pymilvus import connections, utility
+import uuid
+
+class VectorDB:
+    def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
+        self.vector_vendor = vector_vendor
+        self.host = host
+        self.port = port
+        self.collection_name = collection_name
+        self.embedding_model = embedding_model
+
+    def connect(self):
+        # Connect to the chosen vector DB backend and keep the client handle
+        print(f"Connecting to {self.host}:{self.port}...")
+        if self.vector_vendor == "chromadb":
+            self.client = HttpClient(host=self.host,
+                                     port=self.port,
+                                     settings=Settings(allow_reset=True,))
+        elif self.vector_vendor == "milvus":
+            self.client = MilvusClient(uri=f"http://{self.host}:{self.port}")
+        return self.client
+
+    def populate_db(self, documents):
+        # Embed the documents and load them into the vector DB,
+        # skipping the load if the collection is already populated
+        e = SentenceTransformerEmbeddings(model_name=self.embedding_model)
+        print("Populating VectorDB with vectors...")
+        if self.vector_vendor == "chromadb":
+            embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.embedding_model)
+            collection = self.client.get_or_create_collection(self.collection_name,
+                                                              embedding_function=embedding_func)
+            if collection.count() < 1:
+                for doc in documents:
+                    collection.add(
+                        ids=[str(uuid.uuid1())],
+                        metadatas=doc.metadata,
+                        documents=doc.page_content
+                    )
+                print("DB populated")
+            else:
+                print("DB already populated")
+            db = Chroma(client=self.client,
+                        collection_name=self.collection_name,
+                        embedding_function=e)
+
+        elif self.vector_vendor == "milvus":
+            connections.connect(host=self.host, port=self.port)
+            if not utility.has_collection(self.collection_name):
+                db = Milvus.from_documents(
+                    documents,
+                    e,
+                    collection_name=self.collection_name,
+                    connection_args={"host": self.host, "port": self.port},
+                )
+                print("DB populated")
+            else:
+                print("DB already populated")
+                db = Milvus(
+                    e,
+                    collection_name=self.collection_name,
+                    connection_args={"host": self.host, "port": self.port},
+                )
+        self.db = db
+        return db
+
+    def clear_db(self):
+        # Drop the collection from whichever vector DB backend is active
+        print("Clearing the entire VectorDB...")
+        try:
+            if self.vector_vendor == "chromadb":
+                self.client.delete_collection(self.collection_name)
+            elif self.vector_vendor == "milvus":
+                self.client.drop_collection(self.collection_name)
+            print("Cleared DB")
+        except Exception:
+            print("Couldn't clear the collection; it may not exist yet")
+
+    # def has_collection(self):
+    #     # Check whether the collection already exists in the VectorDB
+    #     print(f"Checking if collection {self.collection_name} exists in VectorDB...")
+    #     if self.vector_vendor == "chromadb":
+    #         hc = self.collection_name in [c.name for c in self.client.list_collections()]
+    #     elif self.vector_vendor == "milvus":
+    #         hc = utility.has_collection(self.collection_name)
+    #     return hc
+
+    def create_retriever(self):
+        # Create and return a retriever object from the populated vector store;
+        # requires populate_db() to have been called first
+        print("Creating retriever object from VectorDB...")
+        return self.db.as_retriever()
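
A minimal sketch (not part of the patch) of how the new VectorDB class is driven end to end, assuming the recipe's Chroma container is reachable on 0.0.0.0:8000; the document path and collection name below are illustrative only:

    from langchain.text_splitter import CharacterTextSplitter
    from langchain_community.document_loaders import TextLoader
    from manage_vectordb import VectorDB

    # Assumed values: host/port of the local Chroma server, sample document path
    vdb = VectorDB("chromadb", "0.0.0.0", "8000", "test_collection", "BAAI/bge-base-en-v1.5")
    vdb.connect()

    raw_docs = TextLoader("../sample-data/fake_meeting.txt").load()
    splitter = CharacterTextSplitter(separator=".", chunk_size=150, chunk_overlap=0)
    docs = splitter.split_documents(raw_docs)

    db = vdb.populate_db(docs)      # returns a langchain Chroma (or Milvus) store
    retriever = db.as_retriever()   # ready to plug into a RAG chain
    vdb.clear_db()                  # drops the collection again when done

The same five constructor arguments with vendor "milvus" exercise the second branch of connect() and populate_db(), which is the whole point of the switching logic.
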
diff --git a/recipes/natural_language_processing/rag/app/populate_vectordb.py b/recipes/natural_language_processing/rag/app/populate_vectordb.py
deleted file mode 100644
index 2bbb6efca..000000000
--- a/recipes/natural_language_processing/rag/app/populate_vectordb.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from langchain_community.document_loaders import TextLoader
-from langchain.text_splitter import CharacterTextSplitter
-import chromadb.utils.embedding_functions as embedding_functions
-import chromadb
-from chromadb.config import Settings
-import uuid
-import os
-import argparse
-import time
-
-parser = argparse.ArgumentParser()
-parser.add_argument("-d", "--docs", default="data/fake_meeting.txt")
-parser.add_argument("-c", "--chunk_size", default=150)
-parser.add_argument("-e", "--embedding_model", default="BAAI/bge-base-en-v1.5")
-parser.add_argument("-H", "--vdb_host", default="0.0.0.0")
-parser.add_argument("-p", "--vdb_port", default="8000")
-parser.add_argument("-n", "--name", default="test_collection")
-args = parser.parse_args()
-
-raw_documents = TextLoader(args.docs).load()
-text_splitter = CharacterTextSplitter(separator = ".", chunk_size=int(args.chunk_size), chunk_overlap=0)
-docs = text_splitter.split_documents(raw_documents)
-os.environ["TORCH_HOME"] = "./models/"
-
-embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=args.embedding_model)
-client = chromadb.HttpClient(host=args.vdb_host,
-                             port=args.vdb_port,
-                             settings=Settings(allow_reset=True,))
-collection = client.get_or_create_collection(args.name,
-                                             embedding_function=embedding_func)
-for doc in docs:
-    collection.add(
-        ids=[str(uuid.uuid1())],
-        metadatas=doc.metadata,
-        documents=doc.page_content
-    )
\ No newline at end of file
diff --git a/recipes/natural_language_processing/rag/app/rag_app.py b/recipes/natural_language_processing/rag/app/rag_app.py
index e71fd8a42..7db6ce1dc 100644
--- a/recipes/natural_language_processing/rag/app/rag_app.py
+++ b/recipes/natural_language_processing/rag/app/rag_app.py
@@ -1,74 +1,61 @@
 from langchain_openai import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnablePassthrough
-from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.callbacks import StreamlitCallbackHandler
-from langchain.schema.document import Document
-from langchain_community.vectorstores import Chroma
-
-from chromadb import HttpClient
-from chromadb.config import Settings
-import chromadb.utils.embedding_functions as embedding_functions
-
+from langchain_community.document_loaders import TextLoader
+from manage_vectordb import VectorDB
 import streamlit as st
-
-import uuid
 import os
-import argparse
-import pathlib
 
-model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001/v1")
+model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001")
 model_service = f"{model_service}/v1"
 chunk_size = os.getenv("CHUNK_SIZE", 150)
 embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
+vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
 vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
 vdb_port = os.getenv("VECTORDB_PORT", "8000")
 vdb_name = os.getenv("VECTORDB_NAME", "test_collection")
+# Use the following defaults if using the milvus db:
+# vdb_vendor = os.getenv("VECTORDB_VENDOR", "milvus")
+# vdb_host = os.getenv("VECTORDB_HOST", "127.0.0.1")
+# vdb_port = os.getenv("VECTORDB_PORT", "19530")
+# vdb_name = os.getenv("VECTORDB_NAME", "test_collection")
+
+vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
+vectorDB_client = vdb.connect()
+
+def get_docs(filename):
+    loader = TextLoader(filename)
+    raw_documents = loader.load()
+    text_splitter = CharacterTextSplitter(separator = ".",
+                                          chunk_size=int(chunk_size),
+                                          chunk_overlap=0)
+    docs = text_splitter.split_documents(raw_documents)
+    return docs
 
-vectorDB_client = HttpClient(host=vdb_host,
-                             port=vdb_port,
-                             settings=Settings(allow_reset=True,))
+def create_tmp_file(data):
+    # Write the uploaded file to a local temporary copy for TextLoader
+    with open(data.name, mode='wb') as w:
+        w.write(data.getvalue())
+    print(f"{data.name} uploaded")
 
-def clear_vdb():
-    global vectorDB_client
-    try:
-        vectorDB_client.delete_collection(vdb_name)
-        print("Cleared DB")
-    except:
-        pass
 
 st.title("📚 RAG DEMO")
 with st.sidebar:
-    data = st.file_uploader(label="📄 Upload Document",type=['txt'], on_change=clear_vdb)
+    data = st.file_uploader(label="📄 Upload Document",type=['txt'], on_change=vdb.clear_db)
 
 ### populate the DB ####
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model)
-e = SentenceTransformerEmbeddings(model_name=embedding_model)
-
-collection = vectorDB_client.get_or_create_collection(vdb_name,
-                                                      embedding_function=embedding_func)
-if collection.count() < 1 and data != None:
-    print("populating db")
-    raw_documents = [Document(page_content=data.getvalue().decode("utf-8"),
-                              metadata={"":""})]
-    text_splitter = CharacterTextSplitter(separator = ".",
-                                          chunk_size=int(chunk_size),
-                                          chunk_overlap=0)
-    docs = text_splitter.split_documents(raw_documents)
-    for doc in docs:
-        collection.add(
-            ids=[str(uuid.uuid1())],
-            metadatas=doc.metadata,
-            documents=doc.page_content
-        )
-if data == None:
-    print("Empty VectorDB")
+if data is not None:
+    create_tmp_file(data)
+    documents = get_docs(data.name)
 else:
-    print("DB already populated")
+    documents = get_docs("../sample-data/fake_meeting.txt")
+    print("No document uploaded; populating the DB with the sample document")
+db = vdb.populate_db(documents)
+retriever = db.as_retriever(threshold=0.75)
 ########################
 
 if "messages" not in st.session_state:
@@ -78,11 +65,6 @@ def clear_vdb():
 for msg in st.session_state.messages:
     st.chat_message(msg["role"]).write(msg["content"])
 
-db = Chroma(client=vectorDB_client,
-            collection_name=vdb_name,
-            embedding_function=e
-            )
-retriever = db.as_retriever(threshold=0.75)
 
 llm = ChatOpenAI(base_url=model_service,
                  api_key="EMPTY",
@@ -109,4 +91,4 @@ def clear_vdb():
     response = chain.invoke(prompt)
     st.chat_message("assistant").markdown(response.content)
     st.session_state.messages.append({"role": "assistant", "content": response.content})
-    st.rerun()
+    st.rerun()
\ No newline at end of file
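
Because rag_app.py now reads the backend entirely from the environment, it can be pointed at Milvus without editing the commented-out defaults. A sketch (not part of the patch), assuming the recipe's Milvus container listens on 127.0.0.1:19530:

    import os

    # Assumed values mirroring the commented-out Milvus block in rag_app.py
    os.environ["VECTORDB_VENDOR"] = "milvus"
    os.environ["VECTORDB_HOST"] = "127.0.0.1"
    os.environ["VECTORDB_PORT"] = "19530"

    # rag_app.py then resolves the backend from the environment:
    vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")    # -> "milvus"
    vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")         # -> "127.0.0.1"
    vdb_port = os.getenv("VECTORDB_PORT", "8000")            # -> "19530"
    vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

    from manage_vectordb import VectorDB
    vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, "BAAI/bge-base-en-v1.5")
    client = vdb.connect()  # MilvusClient here; a chromadb HttpClient for the default vendor

This is why rag_app_milvus.py below can be deleted outright: the vendor choice has moved from a separate file into configuration.
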
diff --git a/recipes/natural_language_processing/rag/app/rag_app_milvus.py b/recipes/natural_language_processing/rag/app/rag_app_milvus.py
deleted file mode 100644
index a66e2c088..000000000
--- a/recipes/natural_language_processing/rag/app/rag_app_milvus.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from langchain_openai import ChatOpenAI
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import RunnablePassthrough
-from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.callbacks import StreamlitCallbackHandler
-from langchain_community.vectorstores import Milvus
-from pymilvus import MilvusClient
-from pymilvus import utility
-import streamlit as st
-from pymilvus import connections, utility
-
-import uuid
-import os
-import argparse
-import pathlib
-
-model_service = os.getenv("MODEL_ENDPOINT","http://localhost:8000/v1")
-#model_service = f"{model_service}/v1"
-chunk_size = os.getenv("CHUNK_SIZE", 150)
-embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
-vdb_host = os.getenv("VECTORDB_HOST", "127.0.0.1")
-vdb_port = os.getenv("VECTORDB_PORT", "19530")
-vdb_name = os.getenv("VECTORDB_NAME", "test_collection")
-
-vectorDB_client = MilvusClient(uri=f"http://{vdb_host}:{vdb_port}")
-
-def clear_vdb():
-    global vectorDB_client
-    try:
-        vectorDB_client.drop_collection(vdb_name)
-        print("Cleared DB")
-    except:
-        pass
-
-st.title("📚 RAG DEMO")
-with st.sidebar:
-    data = st.file_uploader(label="📄 Upload Document",type=['txt'], on_change=clear_vdb)
-
-with open(data.name, mode='wb') as w:
-    w.write(data.getvalue())
-print(f"{data.name} uploaded")
-
-### populate the DB ####
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-#embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model)
-e = SentenceTransformerEmbeddings(model_name=embedding_model)
-
-connections.connect(
-    host=vdb_host,
-    port=vdb_port
-)
-if not utility.has_collection(vdb_name) and data != None:
-    print("populating db")
-    from langchain_community.document_loaders import TextLoader
-
-    loader = TextLoader(data.name)
-    raw_documents = loader.load()
-    text_splitter = CharacterTextSplitter(separator = ".",
-                                          chunk_size=int(chunk_size),
-                                          chunk_overlap=0)
-    docs = text_splitter.split_documents(raw_documents)
-    vector_db = Milvus.from_documents(
-        docs,
-        e,
-        collection_name=vdb_name,
-        connection_args={"host": vdb_host, "port": vdb_port},
-    )
-else:
-    print("DB already populated")
-    vector_db = Milvus(
-        e,
-        collection_name=vdb_name,
-        connection_args={"host": vdb_host, "port": vdb_port},
-    )
-if data == None:
-    print("Empty VectorDB")
-
-########################
-
-if "messages" not in st.session_state:
-    st.session_state["messages"] = [{"role": "assistant",
-                                     "content": "How can I help you?"}]
-
-for msg in st.session_state.messages:
-    st.chat_message(msg["role"]).write(msg["content"])
-
-retriever = vector_db.as_retriever(threshold=0.75)
-
-llm = ChatOpenAI(base_url=model_service,
-                 api_key="EMPTY",
-                 streaming=True,
-                 callbacks=[StreamlitCallbackHandler(st.container(),
-                                                     collapse_completed_thoughts=True)])
-
-prompt = ChatPromptTemplate.from_template("""Answer the question based only on the following context:
-{context}
-
-Question: {input}
-"""
-)
-
-chain = (
-    {"context": retriever, "input": RunnablePassthrough()}
-    | prompt
-    | llm
-)
-
-if prompt := st.chat_input():
-    st.session_state.messages.append({"role": "user", "content": prompt})
-    st.chat_message("user").markdown(prompt)
-    response = chain.invoke(prompt)
-    st.chat_message("assistant").markdown(response.content)
-    st.session_state.messages.append({"role": "assistant", "content": response.content})
-    st.rerun()
\ No newline at end of file
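
Both the removed Milvus app and the unified rag_app.py build the same retrieval chain. A standalone sketch (not part of the patch) of that chain with a stand-in retriever, assuming only a model server at rag_app.py's default endpoint; the canned context string is an assumption for illustration:

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnableLambda, RunnablePassthrough
    from langchain_openai import ChatOpenAI

    # Stand-in for db.as_retriever(); returns canned context so the chain
    # runs without any vector DB at all
    retriever = RunnableLambda(lambda question: "The meeting was moved to Friday.")

    llm = ChatOpenAI(base_url="http://0.0.0.0:8001/v1", api_key="EMPTY")
    prompt = ChatPromptTemplate.from_template("""Answer the question based only on the following context:
    {context}

    Question: {input}
    """)

    # The dict is coerced into a RunnableParallel that feeds both prompt slots
    chain = (
        {"context": retriever, "input": RunnablePassthrough()}
        | prompt
        | llm
    )
    print(chain.invoke("When is the meeting?").content)

Swapping the stand-in for the retriever returned by VectorDB.populate_db(...).as_retriever() reproduces the app's behavior against either backend.
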