Skip to content

Commit

Permalink
Fix update index (#252)
Browse files Browse the repository at this point in the history
* Fix update index

* Fix threshold bug
  • Loading branch information
moria97 authored Oct 22, 2024
1 parent 14cf611 commit c74428a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
9 changes: 9 additions & 0 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ async def aconfig():
return rag_service.get_config()


@router.get("/indexes/{index_name}")
async def get_index(index_name: str):
try:
return index_manager.get_index_by_name(index_name=index_name)
except Exception as ex:
logger.error(f"Get index '{index_name}' failed: {ex} {traceback.format_exc()}")
raise UserInputError(f"Get index '{index_name}' failed: {ex}")


@router.post("/indexes/{index_name}")
async def add_index(index_name: str, index_entry: RagIndexEntry):
try:
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def to_app_config(self):

config["postprocessor"]["reranker_type"] = self.reranker_type
config["postprocessor"]["reranker_model"] = self.reranker_model
if self.reranker_type == "no-rerank":
if self.reranker_type == "no-reranker":
config["postprocessor"]["similarity_threshold"] = self.similarity_threshold
else:
config["postprocessor"][
Expand Down
6 changes: 4 additions & 2 deletions src/pai_rag/core/rag_index_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self._lock = threading.Lock()

def add_default_index(self, rag_config: RagConfig):
if DEFAULT_INDEX_NAME not in self._index_map:
if DEFAULT_INDEX_NAME not in self._index_map.indexes:
self._index_map.indexes[DEFAULT_INDEX_NAME] = RagIndexEntry(
index_name=DEFAULT_INDEX_NAME,
vector_store_config=rag_config.index.vector_store,
Expand Down Expand Up @@ -116,7 +116,9 @@ def update_index(self, index_entry: RagIndexEntry):
), f"Index name '{index_entry.index_name}' not exists."
self._index_map.indexes[index_entry.index_name] = index_entry
self.save_index_map()
logger.info(f"Index '{index_entry.index_name}' created successfully.")
logger.info(
f"Index '{index_entry.index_name}' updated successfully {self._index_map}."
)

def delete_index(self, index_name: str):
with self._lock:
Expand Down
3 changes: 0 additions & 3 deletions src/pai_rag/core/rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def resolve_data_loader(config: RagConfig) -> RagDataLoader:
node_parser = resolve(cls=PaiNodeParser, parser_config=config.node_parser)

embed_model = resolve(cls=PaiEmbedding, embed_config=config.embedding)
Settings.embed_model = embed_model
multimodal_embed_model = None
if config.index.enable_multimodal:
multimodal_embed_model = resolve(
Expand Down Expand Up @@ -128,7 +127,6 @@ def resolve_llm(config: RagConfig) -> PaiLlm:
def resolve_data_analysis_tool(config: RagConfig) -> DataAnalysisTool:
llm = resolve_llm(config)
embed_model = resolve(cls=PaiEmbedding, embed_config=config.embedding)
Settings.embed_model = embed_model

return resolve(
cls=DataAnalysisTool,
Expand Down Expand Up @@ -167,7 +165,6 @@ def resolve_synthesizer(config: RagConfig) -> PaiSynthesizer:

def resolve_vector_index(config: RagConfig) -> PaiVectorStoreIndex:
embed_model = resolve(cls=PaiEmbedding, embed_config=config.embedding)
Settings.embed_model = embed_model
multimodal_embed_model = None
if config.index.enable_multimodal:
multimodal_embed_model = resolve(
Expand Down

0 comments on commit c74428a

Please sign in to comment.