1
- from pai_rag .data .rag_dataloader import RagDataLoader
2
- from pai_rag .utils .oss_cache import OssCache
3
1
from pai_rag .modules .module_registry import module_registry
4
2
from pai_rag .evaluations .batch_evaluator import BatchEvaluator
5
3
from pai_rag .app .api .models import (
@@ -24,49 +22,34 @@ def uuid_generator() -> str:
24
22
class RagApplication :
25
23
def __init__ (self ):
26
24
self .name = "RagApplication"
27
- logging .basicConfig (level = logging .INFO ) # 将日志级别设置为INFO
28
25
self .logger = logging .getLogger (__name__ )
29
26
30
27
def initialize(self, config):
    """Remember ``config`` and initialize the shared module registry from it.

    Args:
        config: Application configuration. It is stored on ``self.config``
            and handed to ``module_registry.init_modules``.
    """
    self.config = config
    # Individual modules are no longer cached on the instance; callers fetch
    # them from the registry when they are needed.
    module_registry.init_modules(config)
    self.logger.info("RagApplication initialized successfully.")
55
31
56
32
def reload(self, config):
    """Re-run ``initialize`` with ``config`` and log that the reload finished.

    Args:
        config: The new configuration, passed straight to ``initialize``.
    """
    self.initialize(config)
    self.logger.info("RagApplication reloaded successfully.")
59
35
60
36
# TODO: add files asynchronously when a large number of files is uploaded
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
    """Load knowledge files through the registry's data loader.

    Looks up ``"DataLoaderModule"`` in the module registry for the current
    ``self.config`` and awaits its ``aload`` coroutine.

    Args:
        file_dir: Forwarded unchanged to ``aload`` (presumably the location
            of the files to ingest; confirm against the loader).
        enable_qa_extraction: Forwarded unchanged to ``aload``; defaults to
            ``False``.
    """
    loader = module_registry.get_module_with_config("DataLoaderModule", self.config)
    await loader.aload(file_dir, enable_qa_extraction)
63
42
64
43
async def aquery_retrieval (self , query : RetrievalQuery ) -> RetrievalResponse :
65
44
if not query .question :
66
45
return RetrievalResponse (docs = [])
67
46
68
47
query_bundle = QueryBundle (query .question )
69
- node_results = await self .query_engine .aretrieve (query_bundle )
48
+
49
+ query_engine = module_registry .get_module_with_config (
50
+ "QueryEngineModule" , self .config
51
+ )
52
+ node_results = await query_engine .aretrieve (query_bundle )
70
53
71
54
docs = [
72
55
ContextDoc (
@@ -96,11 +79,24 @@ async def aquery(self, query: RagQuery) -> RagResponse:
96
79
answer = "Empty query. Please input your question." , session_id = session_id
97
80
)
98
81
99
- query_chat_engine = self .chat_engine_factory .get_chat_engine (
82
+ sessioned_config = self .config
83
+ if query .vector_db and query .vector_db .faiss_path :
84
+ sessioned_config = self .config .copy ()
85
+ sessioned_config .index .update ({"persist_path" : query .vector_db .faiss_path })
86
+ print (sessioned_config )
87
+
88
+ chat_engine_factory = module_registry .get_module_with_config (
89
+ "ChatEngineFactoryModule" , sessioned_config
90
+ )
91
+ query_chat_engine = chat_engine_factory .get_chat_engine (
100
92
session_id , query .chat_history
101
93
)
102
94
response = await query_chat_engine .achat (query .question )
103
- self .chat_store .persist ()
95
+
96
+ chat_store = module_registry .get_module_with_config (
97
+ "ChatStoreModule" , sessioned_config
98
+ )
99
+ chat_store .persist ()
104
100
return RagResponse (answer = response .response , session_id = session_id )
105
101
106
102
async def aquery_llm (self , query : LlmQuery ) -> LlmResponse :
@@ -122,11 +118,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
122
118
answer = "Empty query. Please input your question." , session_id = session_id
123
119
)
124
120
125
- llm_chat_engine = self .llm_chat_engine_factory .get_chat_engine (
121
+ llm_chat_engine_factory = module_registry .get_module_with_config (
122
+ "LlmChatEngineFactoryModule" , self .config
123
+ )
124
+ llm_chat_engine = llm_chat_engine_factory .get_chat_engine (
126
125
session_id , query .chat_history
127
126
)
128
127
response = await llm_chat_engine .achat (query .question )
129
- self .chat_store .persist ()
128
+
129
+ chat_store = module_registry .get_module_with_config (
130
+ "ChatStoreModule" , self .config
131
+ )
132
+ chat_store .persist ()
130
133
return LlmResponse (answer = response .response , session_id = session_id )
131
134
132
135
async def aquery_agent (self , query : LlmQuery ) -> LlmResponse :
@@ -143,11 +146,18 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
143
146
if not query .question :
144
147
return LlmResponse (answer = "Empty query. Please input your question." )
145
148
146
- response = await self .agent .achat (query .question )
149
+ agent = module_registry .get_module_with_config ("AgentModule" , self .config )
150
+ response = await agent .achat (query .question )
147
151
return LlmResponse (answer = response .response )
148
152
149
153
async def batch_evaluate_retrieval_and_response (self , type ):
150
- batch_eval = BatchEvaluator (self .config , self .retriever , self .query_engine )
154
+ retriever = module_registry .get_module_with_config (
155
+ "RetrieverModule" , self .config
156
+ )
157
+ query_engine = module_registry .get_module_with_config (
158
+ "QueryEngineModule" , self .config
159
+ )
160
+ batch_eval = BatchEvaluator (self .config , retriever , query_engine )
151
161
df , eval_res_avg = await batch_eval .batch_retrieval_response_aevaluation (
152
162
type = type , workers = 2 , save_to_file = True
153
163
)
0 commit comments