-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrag.py
75 lines (58 loc) · 2.37 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from chunk_vector_store import ChunkVectorStore as cvs
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
import os
class Rag:
    """Retrieval-augmented generation (RAG) pipeline over uploaded documents.

    Lifecycle:
        feed()  -- chunk a file, embed it into the vector store, build the chain.
        ask()   -- answer a question through the chain (retrieve -> prompt -> LLM).
        clear() -- delete stored chunks and tear the chain down.
    """

    # Class-level defaults kept for backward compatibility; each instance
    # gets its own copies in __init__ so instances never share state.
    vector_store = None
    retriever = None
    chain = None

    def __init__(self) -> None:
        # Per-instance state (the original relied solely on shared class
        # attributes, which would leak state between Rag instances).
        self.vector_store = None
        self.retriever = None
        self.chain = None

        # Helper that handles chunking and vector-database persistence.
        self.csv_obj = cvs()

        # Prompt keeps the model grounded in the retrieved context and
        # tells it to decline irrelevant questions politely.
        self.prompt = PromptTemplate.from_template(
            """
            You are an intelligent assistant designed to answer user questions effectively.
            Use the provided context to answer the question as thoroughly as possible.
            Pay close attention to every detail to ensure the answer is accurate, clear, and
            complete. Keep the tone conversational and engaging while ensuring the information
            comes directly from the context. If a question is irrelevant, politely inform the user
            that the information is not present.
            Question: {question}
            Context: {context}
            Answer:
            """)

        # Deterministic output (temperature=0); key is read from the environment.
        self.model = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0,
            openai_api_key=os.getenv("OPENAI_API_KEY"))

    def set_retriever(self):
        """Build a top-5 similarity retriever (score threshold 0.5) over the store."""
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 5, "score_threshold": 0.5},
        )

    def augment(self):
        """Assemble the LCEL chain: retrieve context -> prompt -> model -> plain text."""
        self.chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.model
            | StrOutputParser()
        )

    def ask(self, query: str):
        """Answer *query* via the chain; ask the user to upload documents first if none exist."""
        if not self.chain:
            return "Please upload the documents to start the conversation!"
        return self.chain.invoke(query)

    def feed(self, file_path: str):
        """Ingest *file_path*: split into chunks, store embeddings, and (re)build the chain.

        Returns the ids of the chunks added, so callers can later pass them to clear().
        """
        chunks = self.csv_obj.split_into_chunks(file_path)
        self.vector_store, chunk_ids = self.csv_obj.store_to_vector_database(chunks)
        self.set_retriever()
        self.augment()
        return chunk_ids

    def clear(self, chunk_ids: list):
        """Delete *chunk_ids* from the vector store in batches and disable the chain.

        The vector_store reference itself is intentionally kept; feed() replaces it.
        """
        if self.vector_store:
            # 166 presumably matches the backend's per-call delete/embedding
            # batch limit — TODO confirm against the vector store's docs.
            batch_size = 166
            for i in range(0, len(chunk_ids), batch_size):
                self.vector_store.delete(chunk_ids[i:i + batch_size])
        self.chain = None
        self.retriever = None