
Commit daba1f5

Remove local storage and enable Elasticsearch hybrid query mode (#60)
* Add gpu dockerfile
* Fix bug
* Fix gb2312
* Update embedding batch size
* Set default embedding and llm model
* Update docker tag
* Fix hologres check
* Update registry
* Fix bug
* Fix tests
* Add queue
* Update batch size
* Add async interface
* Fix index conflict
* Add change index parameter for FAISS
* Fix batch size
* Update
1 parent ba1132a commit daba1f5

26 files changed (+1217, -180 lines)

src/pai_rag/app/api/models.py (+6, -4)

@@ -2,13 +2,16 @@
 from typing import List, Dict


+class VectorDbConfig(BaseModel):
+    faiss_path: str | None = None
+
+
 class RagQuery(BaseModel):
     question: str
     temperature: float | None = 0.1
-    vector_topk: int | None = 3
-    score_threshold: float | None = 0.5
     chat_history: List[Dict[str, str]] | None = None
     session_id: str | None = None
+    vector_db: VectorDbConfig | None = None


 class LlmQuery(BaseModel):

@@ -20,8 +23,7 @@ class LlmQuery(BaseModel):

 class RetrievalQuery(BaseModel):
     question: str
-    topk: int | None = 3
-    score_threshold: float | None = 0.5
+    vector_db: VectorDbConfig | None = None


 class RagResponse(BaseModel):

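The request models now let a caller point retrieval at a per-request FAISS store via the new vector_db field instead of tuning vector_topk / score_threshold. A minimal sketch of building such a query (the path value is purely illustrative, not from this commit):

from pai_rag.app.api.models import RagQuery, VectorDbConfig

# Scope this query's retrieval to a session-specific FAISS index.
# The faiss_path below is a placeholder value.
query = RagQuery(
    question="How does hybrid retrieval work?",
    vector_db=VectorDbConfig(faiss_path="/app/localdata/session_abc"),
)
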
src/pai_rag/core/rag_application.py (+43, -33)

@@ -1,5 +1,3 @@
-from pai_rag.data.rag_dataloader import RagDataLoader
-from pai_rag.utils.oss_cache import OssCache
 from pai_rag.modules.module_registry import module_registry
 from pai_rag.evaluations.batch_evaluator import BatchEvaluator
 from pai_rag.app.api.models import (

@@ -24,49 +22,34 @@ def uuid_generator() -> str:
 class RagApplication:
     def __init__(self):
         self.name = "RagApplication"
-        logging.basicConfig(level=logging.INFO)  # Set the log level to INFO
         self.logger = logging.getLogger(__name__)

     def initialize(self, config):
         self.config = config
-
         module_registry.init_modules(self.config)
-        self.index = module_registry.get_module("IndexModule")
-        self.llm = module_registry.get_module("LlmModule")
-        self.retriever = module_registry.get_module("RetrieverModule")
-        self.chat_store = module_registry.get_module("ChatStoreModule")
-        self.query_engine = module_registry.get_module("QueryEngineModule")
-        self.chat_engine_factory = module_registry.get_module("ChatEngineFactoryModule")
-        self.llm_chat_engine_factory = module_registry.get_module(
-            "LlmChatEngineFactoryModule"
-        )
-        self.data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
-        self.agent = module_registry.get_module("AgentModule")
-
-        oss_cache = None
-        if config.get("oss_cache", None):
-            oss_cache = OssCache(config.oss_cache)
-        node_parser = module_registry.get_module("NodeParserModule")
-
-        self.data_loader = RagDataLoader(
-            self.data_reader_factory, node_parser, self.index, oss_cache
-        )
         self.logger.info("RagApplication initialized successfully.")

     def reload(self, config):
         self.initialize(config)
         self.logger.info("RagApplication reloaded successfully.")

     # TODO: make adding of bulk uploaded files asynchronous
-    def load_knowledge(self, file_dir, enable_qa_extraction=False):
-        self.data_loader.load(file_dir, enable_qa_extraction)
+    async def load_knowledge(self, file_dir, enable_qa_extraction=False):
+        data_loader = module_registry.get_module_with_config(
+            "DataLoaderModule", self.config
+        )
+        await data_loader.aload(file_dir, enable_qa_extraction)

     async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
         if not query.question:
             return RetrievalResponse(docs=[])

         query_bundle = QueryBundle(query.question)
-        node_results = await self.query_engine.aretrieve(query_bundle)
+
+        query_engine = module_registry.get_module_with_config(
+            "QueryEngineModule", self.config
+        )
+        node_results = await query_engine.aretrieve(query_bundle)

         docs = [
             ContextDoc(

@@ -96,11 +79,24 @@ async def aquery(self, query: RagQuery) -> RagResponse:
                 answer="Empty query. Please input your question.", session_id=session_id
             )

-        query_chat_engine = self.chat_engine_factory.get_chat_engine(
+        sessioned_config = self.config
+        if query.vector_db and query.vector_db.faiss_path:
+            sessioned_config = self.config.copy()
+            sessioned_config.index.update({"persist_path": query.vector_db.faiss_path})
+            print(sessioned_config)
+
+        chat_engine_factory = module_registry.get_module_with_config(
+            "ChatEngineFactoryModule", sessioned_config
+        )
+        query_chat_engine = chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await query_chat_engine.achat(query.question)
-        self.chat_store.persist()
+
+        chat_store = module_registry.get_module_with_config(
+            "ChatStoreModule", sessioned_config
+        )
+        chat_store.persist()
         return RagResponse(answer=response.response, session_id=session_id)

     async def aquery_llm(self, query: LlmQuery) -> LlmResponse:

@@ -122,11 +118,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
                 answer="Empty query. Please input your question.", session_id=session_id
             )

-        llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
+        llm_chat_engine_factory = module_registry.get_module_with_config(
+            "LlmChatEngineFactoryModule", self.config
+        )
+        llm_chat_engine = llm_chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await llm_chat_engine.achat(query.question)
-        self.chat_store.persist()
+
+        chat_store = module_registry.get_module_with_config(
+            "ChatStoreModule", self.config
+        )
+        chat_store.persist()
         return LlmResponse(answer=response.response, session_id=session_id)

     async def aquery_agent(self, query: LlmQuery) -> LlmResponse:

@@ -143,11 +146,18 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         if not query.question:
             return LlmResponse(answer="Empty query. Please input your question.")

-        response = await self.agent.achat(query.question)
+        agent = module_registry.get_module_with_config("AgentModule", self.config)
+        response = await agent.achat(query.question)
         return LlmResponse(answer=response.response)

     async def batch_evaluate_retrieval_and_response(self, type):
-        batch_eval = BatchEvaluator(self.config, self.retriever, self.query_engine)
+        retriever = module_registry.get_module_with_config(
+            "RetrieverModule", self.config
+        )
+        query_engine = module_registry.get_module_with_config(
+            "QueryEngineModule", self.config
+        )
+        batch_eval = BatchEvaluator(self.config, retriever, query_engine)
         df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation(
             type=type, workers=2, save_to_file=True
         )

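The pattern repeated throughout this file: modules are no longer resolved once in initialize() and cached as attributes; each request asks module_registry for a module built against the current, possibly request-scoped, config. A condensed sketch of the idea (the helper name below is illustrative, not part of the commit):

def _resolve_chat_engine_factory(config, vector_db=None):
    # Overlay a request-scoped FAISS path on a copy of the global config,
    # mirroring the sessioned_config logic in aquery() above.
    sessioned_config = config
    if vector_db and vector_db.faiss_path:
        sessioned_config = config.copy()
        sessioned_config.index.update({"persist_path": vector_db.faiss_path})
    return module_registry.get_module_with_config(
        "ChatEngineFactoryModule", sessioned_config
    )
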
src/pai_rag/core/rag_service.py (+7, -3)

@@ -12,6 +12,9 @@
 from pai_rag.app.web.view_model import view_model
 from openinference.instrumentation import using_attributes
 from typing import Any, Dict
+import logging
+
+logger = logging.getLogger(__name__)


 def trace_correlation_id(function):

@@ -48,14 +51,15 @@ def reload(self, new_config: Any):
         self.rag.reload(self.rag_configuration.get_value())
         self.rag_configuration.persist()

-    def add_knowledge_async(
+    async def add_knowledge_async(
         self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
     ):
         self.tasks_status[task_id] = "processing"
         try:
-            self.rag.load_knowledge(file_dir, enable_qa_extraction)
+            await self.rag.load_knowledge(file_dir, enable_qa_extraction)
             self.tasks_status[task_id] = "completed"
-        except Exception:
+        except Exception as ex:
+            logger.error(f"Upload failed: {ex}")
             self.tasks_status[task_id] = "failed"

     def get_task_status(self, task_id: str) -> str:

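Since add_knowledge_async is now a coroutine, callers must schedule it rather than invoke it synchronously. A minimal sketch of one way to kick it off (the wrapper below is an assumption, not code from this commit):

import asyncio

async def start_knowledge_upload(rag_service, task_id: str, file_dir: str):
    # Fire-and-forget: progress is tracked in rag_service.tasks_status
    # and can be polled via get_task_status(task_id).
    asyncio.create_task(
        rag_service.add_knowledge_async(task_id, file_dir, enable_qa_extraction=False)
    )
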
src/pai_rag/data/rag_dataloader.py (+2, -2)

@@ -37,8 +37,8 @@ def __init__(
     ):
         self.datareader_factory = datareader_factory
         self.node_parser = node_parser
-        self.index = index
         self.oss_cache = oss_cache
+        self.index = index

         if use_local_qa_model:
             # This option is not yet supported via the API

@@ -111,7 +111,7 @@ async def aload(self, file_directory: str, enable_qa_extraction: bool):

         logger.info("[DataReader] Start inserting to index.")

-        self.index.insert_nodes(nodes)
+        await self.index.insert_nodes_async(nodes)
         self.index.storage_context.persist(persist_dir=store_path.persist_path)
         logger.info(f"Inserted {len(nodes)} nodes successfully.")
         return

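Ingestion now flows through the async aload path end to end. A small sketch of driving the new DataLoaderModule directly (the directory path is illustrative):

import asyncio
from pai_rag.modules.module_registry import module_registry

async def ingest(config):
    data_loader = module_registry.get_module_with_config("DataLoaderModule", config)
    # enable_qa_extraction toggles QA-pair extraction during parsing
    await data_loader.aload("/data/docs", enable_qa_extraction=False)

# asyncio.run(ingest(config))
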
src/pai_rag/data/rag_datapipeline.py (+2, -13)

@@ -2,14 +2,12 @@
 import click
 import os
 from pathlib import Path
-from pai_rag.data.rag_dataloader import RagDataLoader
 from pai_rag.core.rag_configuration import RagConfiguration
-from pai_rag.utils.oss_cache import OssCache
 from pai_rag.modules.module_registry import module_registry


 class RagDataPipeline:
-    def __init__(self, data_loader: RagDataLoader):
+    def __init__(self, data_loader):
         self.data_loader = data_loader

     async def ingest_from_folder(self, folder_path: str, enable_qa_extraction: bool):

@@ -23,16 +21,7 @@ def __init_data_pipeline(use_local_qa_model):
     config = RagConfiguration.from_file(config_file).get_value()
     module_registry.init_modules(config)

-    oss_cache = None
-    if config.get("oss_cache", None):
-        oss_cache = OssCache(config.oss_cache)
-    node_parser = module_registry.get_module("NodeParserModule")
-    index = module_registry.get_module("IndexModule")
-    data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
-
-    data_loader = RagDataLoader(
-        data_reader_factory, node_parser, index, oss_cache, use_local_qa_model
-    )
+    data_loader = module_registry.get_module_with_config("DataLoaderModule", config)
     return RagDataPipeline(data_loader)

src/pai_rag/evaluations/batch_evaluator.py (+2, -2)

@@ -189,8 +189,8 @@ def __init_evaluator_pipeline():
     config = RagConfiguration.from_file(config_file).get_value()
     module_registry.init_modules(config)

-    retriever = module_registry.get_module("RetrieverModule")
-    query_engine = module_registry.get_module("QueryEngineModule")
+    retriever = module_registry.get_module_with_config("RetrieverModule", config)
+    query_engine = module_registry.get_module_with_config("QueryEngineModule", config)

     return BatchEvaluator(config, retriever, query_engine)

src/pai_rag/evaluations/dataset_generation/generate_dataset.py (+20, -3)

@@ -1,5 +1,6 @@
 import logging
 import os
+from pathlib import Path
 from pai_rag.core.rag_configuration import RagConfiguration
 from pai_rag.modules.module_registry import module_registry
 from llama_index.core.prompts.prompt_type import PromptType

@@ -16,8 +17,13 @@
     DEFAULT_TEXT_QA_PROMPT_TMPL,
     DEFAULT_QUESTION_GENERATION_QUERY,
 )
+
 import json

+_BASE_DIR = Path(__file__).parent.parent.parent
+DEFAULT_EVAL_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
+DEFAULT_EVAL_DATA_FOLDER = "tests/testdata/paul_graham"
+

 class GenerateDatasetPipeline(ModifiedRagDatasetGenerator):
     def __init__(

@@ -29,11 +35,22 @@ def __init__(
         show_progress: Optional[bool] = True,
     ) -> None:
         self.name = "GenerateDatasetPipeline"
-        self.nodes = list(
-            module_registry.get_module("IndexModule").docstore.docs.values()
+        self.config = RagConfiguration.from_file(DEFAULT_EVAL_CONFIG_FILE).get_value()
+
+        # load nodes
+        module_registry.init_modules(self.config)
+        datareader_factory = module_registry.get_module_with_config(
+            "DataReaderFactoryModule", self.config
         )
+        self.node_parser = module_registry.get_module_with_config(
+            "NodeParserModule", self.config
+        )
+        reader = datareader_factory.get_reader(DEFAULT_EVAL_DATA_FOLDER)
+        docs = reader.load_data()
+        self.nodes = self.node_parser.get_nodes_from_documents(docs)
+
         self.num_questions_per_chunk = num_questions_per_chunk
-        self.llm = module_registry.get_module("LlmModule")
+        self.llm = module_registry.get_module_with_config("LlmModule", self.config)
         self.text_question_template = PromptTemplate(text_question_template_str)
         self.text_qa_template = PromptTemplate(
             text_qa_template_str, prompt_type=PromptType.QUESTION_ANSWER

src/pai_rag/modules/__init__.py (+5)

@@ -1,5 +1,6 @@
 from pai_rag.modules.embedding.embedding import EmbeddingModule
 from pai_rag.modules.llm.llm_module import LlmModule
+from pai_rag.modules.datareader.data_loader import DataLoaderModule
 from pai_rag.modules.datareader.datareader_factory import DataReaderFactoryModule
 from pai_rag.modules.index.index import IndexModule
 from pai_rag.modules.nodeparser.node_parser import NodeParserModule

@@ -12,10 +13,13 @@
 from pai_rag.modules.chat.chat_store import ChatStoreModule
 from pai_rag.modules.agent.agent import AgentModule
 from pai_rag.modules.tool.tool import ToolModule
+from pai_rag.modules.cache.oss_cache import OssCacheModule
+

 ALL_MODULES = [
     "EmbeddingModule",
     "LlmModule",
+    "DataLoaderModule",
     "DataReaderFactoryModule",
     "IndexModule",
     "NodeParserModule",

@@ -28,6 +32,7 @@
     "LlmChatEngineFactoryModule",
     "AgentModule",
     "ToolModule",
+    "OssCacheModule",
 ]

 __all__ = ALL_MODULES + ["ALL_MODULES"]

src/pai_rag/modules/base/configurable_module.py

@@ -1,19 +1,19 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Any
+import logging

 DEFAULT_INSTANCE_KEY = "__DEFAULT_INSTANCE__"


+logger = logging.getLogger(__name__)
+
+
 class ConfigurableModule(ABC):
     """Configurable Module

     Helps to create instances according to configuration.
     """

-    def __init__(self):
-        self.__params_map = {}
-        self.__instance_map = {}
-
     @abstractmethod
     def _create_new_instance(self, new_params: Dict[str, Any]):
         raise NotImplementedError

@@ -24,20 +24,4 @@ def get_dependencies() -> List[str]:
         raise NotImplementedError

     def get_or_create(self, new_params: Dict[str, Any]):
-        return self.get_or_create_by_name(new_params=new_params)
-
-    def get_or_create_by_name(
-        self, new_params: Dict[str, Any], name: str = DEFAULT_INSTANCE_KEY
-    ):
-        # Create new instance when initializing or config changed.
-        if (
-            self.__params_map.get(name, None) is None
-            or self.__params_map[name] != new_params
-        ):
-            print(f"{self.__class__.__name__} param changed, updating")
-            self.__instance_map[name] = self._create_new_instance(new_params)
-            self.__params_map[name] = new_params
-        else:
-            print(f"{self.__class__.__name__} param unchanged, skipping")
-
-        return self.__instance_map[name]
+        return self._create_new_instance(new_params)
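
With the memoization removed, get_or_create builds a fresh instance on every call, so module instances are now constructed per request rather than cached per parameter set. A minimal sketch of a module written against the slimmed-down base class (the module name and return value are illustrative only):

from typing import Any, Dict, List
from pai_rag.modules.base.configurable_module import ConfigurableModule

class EchoModule(ConfigurableModule):
    @staticmethod
    def get_dependencies() -> List[str]:
        return []

    def _create_new_instance(self, new_params: Dict[str, Any]):
        # Called on every get_or_create(); nothing is memoized anymore.
        return dict(new_params)
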
src/pai_rag/modules/cache/oss_cache.py (new file, +20)

@@ -0,0 +1,20 @@
+from typing import Any, Dict, List
+from pai_rag.utils.oss_cache import OssCache
+from pai_rag.modules.base.configurable_module import ConfigurableModule
+from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class OssCacheModule(ConfigurableModule):
+    @staticmethod
+    def get_dependencies() -> List[str]:
+        return []
+
+    def _create_new_instance(self, new_params: Dict[str, Any]):
+        cache_config = new_params[MODULE_PARAM_CONFIG]
+        if cache_config:
+            return OssCache(cache_config)
+        else:
+            return None
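
The new module exposes the OSS cache through the registry so other modules can obtain it by name instead of constructing OssCache themselves. A rough sketch of how it is exercised (the config keys shown are placeholders, not values confirmed by this diff):

from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
from pai_rag.modules.cache.oss_cache import OssCacheModule

module = OssCacheModule()
cache = module.get_or_create({MODULE_PARAM_CONFIG: {"bucket": "my-bucket"}})  # hypothetical keys
no_cache = module.get_or_create({MODULE_PARAM_CONFIG: None})  # no oss_cache config -> returns None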
