Skip to content

Commit

Permalink
Add milvus vector database for rag recipe
Browse files Browse the repository at this point in the history
Signed-off-by: Shreyanand <[email protected]>
Co-authored-by: Michael Clifford <[email protected]>
Co-authored-by: greg pereira <[email protected]>
  • Loading branch information
3 people committed May 3, 2024
1 parent e2ead10 commit 705b5d1
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 101 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/rag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ on:
branches:
- main
paths:
- ./recipes/common/Makefile.common
- ./recipes/natural_language_processing/rag/**
- .github/workflows/rag.yaml
- 'recipes/common/Makefile.common'
- 'recipes/natural_language_processing/rag/**'
- '.github/workflows/rag.yaml'
push:
branches:
- main
paths:
- ./recipes/common/Makefile.common
- ./recipes/natural_language_processing/rag/**
- .github/workflows/rag.yaml
- 'recipes/common/Makefile.common'
- 'recipes/natural_language_processing/rag/**'
- '.github/workflows/rag.yaml'

workflow_dispatch:

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ recipes/common/bin/*
*/.venv/
training/cloud/examples
training/instructlab/instructlab
vector_dbs/milvus/volumes/milvus/*
1 change: 1 addition & 0 deletions recipes/natural_language_processing/rag/app/Containerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ COPY requirements.txt .
RUN pip install --upgrade pip
RUN pip install --no-cache-dir --upgrade -r /rag/requirements.txt
COPY rag_app.py .
COPY manage_vectordb.py .
EXPOSE 8501
ENV HF_HUB_CACHE=/rag/models/
ENTRYPOINT [ "streamlit", "run" ,"rag_app.py" ]
81 changes: 81 additions & 0 deletions recipes/natural_language_processing/rag/app/manage_vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from langchain_community.vectorstores import Chroma
from chromadb import HttpClient
from chromadb.config import Settings
import chromadb.utils.embedding_functions as embedding_functions
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Milvus
from pymilvus import MilvusClient
from pymilvus import connections, utility

class VectorDB:
    """Vendor-agnostic facade over the vector database used by the RAG recipe.

    Two backends are supported, selected by ``vector_vendor``:
    ``"chromadb"`` (via ``chromadb.HttpClient``) and ``"milvus"``
    (via ``pymilvus.MilvusClient``).
    """

    def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
        """Store connection settings; no network traffic happens here.

        :param vector_vendor: backend name, ``"chromadb"`` or ``"milvus"``
        :param host: vector DB host name or IP
        :param port: vector DB port
        :param collection_name: collection to create / query / drop
        :param embedding_model: SentenceTransformers model name used for embeddings
        """
        self.vector_vendor = vector_vendor
        self.host = host
        self.port = port
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        # Populated by connect(); declared here so the attribute always exists.
        self.client = None

    def connect(self):
        """Create, store and return the vendor-specific client.

        :raises ValueError: if the configured vendor is not supported
            (the original code silently returned ``None`` in that case).
        """
        print(f"Connecting to {self.host}:{self.port}...")
        if self.vector_vendor == "chromadb":
            self.client = HttpClient(host=self.host,
                                     port=self.port,
                                     settings=Settings(allow_reset=True))
        elif self.vector_vendor == "milvus":
            self.client = MilvusClient(uri=f"http://{self.host}:{self.port}")
        else:
            raise ValueError(f"Unsupported vector DB vendor: {self.vector_vendor}")
        return self.client

    def populate_db(self, documents):
        """Embed ``documents`` and insert them unless the collection already holds data.

        Returns a LangChain vector-store handle (``Chroma`` or ``Milvus``) for
        the collection, whether or not new vectors were inserted.

        :raises ValueError: if the configured vendor is not supported (the
            original code fell through and hit a NameError on ``db``).
        """
        embeddings = SentenceTransformerEmbeddings(model_name=self.embedding_model)
        print("Populating VectorDB with vectors...")
        if self.vector_vendor == "chromadb":
            embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_model)
            collection = self.client.get_or_create_collection(
                self.collection_name, embedding_function=embedding_func)
            # count() < 1 means the collection was just created (or is empty):
            # only then do the (slow) embedding + insert.
            if collection.count() < 1:
                db = Chroma.from_documents(
                    documents=documents,
                    embedding=embeddings,
                    collection_name=self.collection_name,
                    client=self.client,
                )
                print("DB populated")
            else:
                db = Chroma(client=self.client,
                            collection_name=self.collection_name,
                            embedding_function=embeddings,
                            )
                print("DB already populated")
        elif self.vector_vendor == "milvus":
            # LangChain's Milvus store manages its own connection pool, hence
            # the extra connections.connect() besides self.client.
            connections.connect(host=self.host, port=self.port)
            if not utility.has_collection(self.collection_name):
                db = Milvus.from_documents(
                    documents,
                    embeddings,
                    collection_name=self.collection_name,
                    connection_args={"host": self.host, "port": self.port},
                )
                print("DB populated")
            else:
                db = Milvus(
                    embeddings,
                    collection_name=self.collection_name,
                    connection_args={"host": self.host, "port": self.port},
                )
                print("DB already populated")
        else:
            raise ValueError(f"Unsupported vector DB vendor: {self.vector_vendor}")
        return db

    def clear_db(self):
        """Drop the configured collection.

        Best-effort by design: a missing collection is not fatal (e.g. the
        upload widget clears before anything was populated). Catches
        ``Exception`` instead of a bare ``except`` so KeyboardInterrupt /
        SystemExit still propagate, and reports the underlying error.
        """
        print("Clearing VectorDB...")
        try:
            if self.vector_vendor == "chromadb":
                self.client.delete_collection(self.collection_name)
            elif self.vector_vendor == "milvus":
                self.client.drop_collection(self.collection_name)
            print("Cleared DB")
        except Exception as err:
            print(f"Couldn't clear the collection possibly because it doesn't exist: {err}")
36 changes: 0 additions & 36 deletions recipes/natural_language_processing/rag/app/populate_vectordb.py

This file was deleted.

88 changes: 30 additions & 58 deletions recipes/natural_language_processing/rag/app/rag_app.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,68 @@
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema.document import Document
from chromadb import HttpClient
from chromadb.config import Settings
import chromadb.utils.embedding_functions as embedding_functions
import streamlit as st
from manage_vectordb import VectorDB
import tempfile
import uuid
import streamlit as st
import os

model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001")
model_service = f"{model_service}/v1"
chunk_size = os.getenv("CHUNK_SIZE", 150)
embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
vdb_port = os.getenv("VECTORDB_PORT", "8000")
vdb_name = os.getenv("VECTORDB_NAME", "test_collection")

vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
vectorDB_client = vdb.connect()
def split_docs(raw_documents):
text_splitter = CharacterTextSplitter(separator = ".",
chunk_size=int(chunk_size),
chunk_overlap=0)
docs = text_splitter.split_documents(raw_documents)
return docs

vectorDB_client = HttpClient(host=vdb_host,
port=vdb_port,
settings=Settings(allow_reset=True,))

def clear_vdb():
global vectorDB_client
try:
vectorDB_client.delete_collection(vdb_name)
print("Cleared DB")
except:
pass

def read_file(file):
file_type = file.type

if file_type == "application/pdf":
temp = tempfile.NamedTemporaryFile()
with open(temp.name, "wb") as f:
f.write(file.getvalue())
loader = PyPDFLoader(temp.name)
pages = loader.load()
text = "".join([p.page_content for p in pages])

if file_type == "text/plain":
text = file.read().decode()

return text
temp = tempfile.NamedTemporaryFile()
with open(temp.name, "wb") as f:
f.write(file.getvalue())
loader = TextLoader(temp.name)
raw_documents = loader.load()
return raw_documents

st.title("📚 RAG DEMO")
with st.sidebar:
file = st.file_uploader(label="📄 Upload Document",
type=[".txt",".pdf"],
on_change=clear_vdb
)
type=[".txt",".pdf"],
on_change=vdb.clear_db
)

### populate the DB ####
os.environ["TOKENIZERS_PARALLELISM"] = "false"

embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model)
e = SentenceTransformerEmbeddings(model_name=embedding_model)

collection = vectorDB_client.get_or_create_collection(vdb_name,
embedding_function=embedding_func)
if collection.count() < 1 and file != None:
print("populating db")
if file != None:
text = read_file(file)
raw_documents = [Document(page_content=text,
metadata={"":""})]
text_splitter = CharacterTextSplitter(separator = ".",
chunk_size=int(chunk_size),
chunk_overlap=0)
docs = text_splitter.split_documents(raw_documents)
for doc in docs:
collection.add(
ids=[str(uuid.uuid1())],
metadatas=doc.metadata,
documents=doc.page_content
)
if file == None:
print("Empty VectorDB")
documents = split_docs(text)
db = vdb.populate_db(documents)
retriever = db.as_retriever(threshold=0.75)
else:
print("DB already populated")
retriever = {}
print("Empty VectorDB")


########################

if "messages" not in st.session_state:
Expand All @@ -95,11 +72,6 @@ def read_file(file):
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])

db = Chroma(client=vectorDB_client,
collection_name=vdb_name,
embedding_function=e
)
retriever = db.as_retriever(threshold=0.75)

llm = ChatOpenAI(base_url=model_service,
api_key="EMPTY",
Expand All @@ -126,4 +98,4 @@ def read_file(file):
response = chain.invoke(prompt)
st.chat_message("assistant").markdown(response.content)
st.session_state.messages.append({"role": "assistant", "content": response.content})
st.rerun()
st.rerun()
2 changes: 1 addition & 1 deletion vector_dbs/Makefile → vector_dbs/chromadb/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ APPIMAGE ?= quay.io/ai-lab/${APP}:latest

.PHONY: build
build:
podman build -f chromadb/Containerfile -t ${APPIMAGE} .
podman build -f Containerfile -t ${APPIMAGE} .
2 changes: 2 additions & 0 deletions vector_dbs/milvus/Containerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Milvus standalone, pinned to a specific upstream build for reproducibility.
FROM docker.io/milvusdb/milvus:master-20240426-bed6363f
# Bake the embedded-etcd configuration into the image; the Makefile starts the
# container with ETCD_CONFIG_PATH pointing at this path.
# COPY instead of ADD: plain local file, no archive/URL semantics needed
# (Dockerfile best practice / hadolint DL3020).
COPY embedEtcd.yaml /milvus/configs/embedEtcd.yaml
55 changes: 55 additions & 0 deletions vector_dbs/milvus/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Build and run a standalone Milvus vector DB container with podman.
REGISTRY ?= quay.io
REGISTRY_ORG ?= ai-lab
COMPONENT = vector_dbs

IMAGE ?= $(REGISTRY)/$(REGISTRY_ORG)/$(COMPONENT)/milvus:latest

ARCH ?= $(shell uname -m)
PLATFORM ?= linux/$(ARCH)

# Milvus gRPC API port (name fixed: was misspelled gRCP_PORT).
GRPC_PORT := 19530
# Milvus REST/metrics port; also serves the /healthz endpoint used below.
REST_PORT := 9091
# Embedded etcd client port (matches embedEtcd.yaml).
CLIENT_PORT := 2379

# Host directory bind-mounted as the Milvus data volume.
# $(CURDIR) is the built-in absolute cwd — avoids a $(shell pwd) fork.
LIB_MILVUS_DIR_MOUNTPATH := $(CURDIR)/volumes/milvus

.PHONY: build
build:
	podman build --platform $(PLATFORM) -f Containerfile -t ${IMAGE} .

.PHONY: run
run:
	@mkdir -p $(LIB_MILVUS_DIR_MOUNTPATH)
	podman run -d \
		--name milvus-standalone \
		--security-opt seccomp:unconfined \
		-e ETCD_USE_EMBED=true \
		-e ETCD_CONFIG_PATH=/milvus/configs/embedEtcd.yaml \
		-e COMMON_STORAGETYPE=local \
		-v $(LIB_MILVUS_DIR_MOUNTPATH):/var/lib/milvus \
		-p $(GRPC_PORT):$(GRPC_PORT) \
		-p $(REST_PORT):$(REST_PORT) \
		-p $(CLIENT_PORT):$(CLIENT_PORT) \
		--health-cmd="curl -f http://localhost:$(REST_PORT)/healthz" \
		--health-interval=30s \
		--health-start-period=90s \
		--health-timeout=20s \
		--health-retries=3 \
		$(IMAGE) \
		milvus run standalone 1> /dev/null

# Leading '-': stopping an already-stopped/absent container is not an error.
.PHONY: stop
stop:
	-podman stop milvus-standalone

.PHONY: delete
delete:
	-podman rm milvus-standalone -f

# Remove every container (any name) running this IMAGE.
.PHONY: podman-clean
podman-clean:
	@container_ids=$$(podman ps --format "{{.ID}} {{.Image}}" | awk '$$2 == "$(IMAGE)" {print $$1}'); \
	echo "removing all containers with IMAGE=$(IMAGE)"; \
	for id in $$container_ids; do \
		echo "Removing container: $$id,"; \
		podman rm -f $$id; \
	done
5 changes: 5 additions & 0 deletions vector_dbs/milvus/embedEtcd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Configuration for the etcd instance embedded in Milvus standalone
# (the container is started with ETCD_USE_EMBED=true and ETCD_CONFIG_PATH
# pointing at this file — see the milvus Makefile and Containerfile).
listen-client-urls: http://0.0.0.0:2379
advertise-client-urls: http://0.0.0.0:2379
# Backend storage quota: 4294967296 bytes = 4 GiB.
quota-backend-bytes: 4294967296
# Revision-based auto-compaction, keeping the last 1000 revisions
# — NOTE(review): semantics assumed from etcd defaults; confirm against etcd docs.
auto-compaction-mode: revision
auto-compaction-retention: '1000'

0 comments on commit 705b5d1

Please sign in to comment.