Skip to content

Commit

Permalink
Merge branch 'feature' into moria97-patch-1
Browse files (browse the repository at this point in the history)
  • Loading branch information
moria97 authored Jun 3, 2024
2 parents b128023 + 8c13568 commit d27ea28
Show file tree
Hide file tree
Showing 12 changed files with 504 additions and 263 deletions.
24 changes: 24 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Multi-stage build: resolve and install dependencies with Poetry in a
# throwaway builder stage, then ship only the resulting virtualenv in a
# slim runtime image.
FROM python:3.10-slim AS builder

RUN pip3 install poetry

# POETRY_VIRTUALENVS_IN_PROJECT pins the venv to /app/.venv so the prod
# stage can copy it from a fixed, known path.
ENV POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_IN_PROJECT=1 \
    POETRY_VIRTUALENVS_CREATE=1 \
    POETRY_CACHE_DIR=/tmp/poetry_cache

WORKDIR /app

# Copy only the dependency manifests first so the slow dependency-install
# layer stays cached until pyproject.toml / poetry.lock actually change.
COPY pyproject.toml poetry.lock ./
RUN poetry install --no-root && rm -rf $POETRY_CACHE_DIR

# Now copy the sources and install the project itself (fast: deps are
# already present, this only installs the pai_rag entry points).
COPY . .
RUN poetry install && rm -rf $POETRY_CACHE_DIR

FROM python:3.10-slim AS prod
ENV VIRTUAL_ENV=/app/.venv \
    PATH="/app/.venv/bin:$PATH"

# libgl1 / libglib2.0-0 are native runtime deps of opencv-style wheels.
# --no-install-recommends and clearing the apt lists keep the image small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends libgl1 libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app
# NOTE(review): COPY . . pulls the whole build context into the image —
# confirm a .dockerignore excludes .git, caches, and local data.
COPY . .
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENTRYPOINT ["pai_rag", "run"]
35 changes: 0 additions & 35 deletions docker/Dockerfile

This file was deleted.

48 changes: 0 additions & 48 deletions docker/README.md

This file was deleted.

Empty file removed docker/data_loader/Dockerfile
Empty file.
584 changes: 429 additions & 155 deletions poetry.lock

Large diffs are not rendered by default.

22 changes: 18 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ readme = "README.md"
python = ">=3.10.0,<3.12"
fastapi = "^0.110.1"
uvicorn = "^0.29.0"
llama-index-core = "^0.10.29"
llama-index-core = ">=0.10.29,<=0.10.39"
llama-index-embeddings-openai = "^0.1.7"
llama-index-embeddings-azure-openai = "^0.1.7"
llama-index-embeddings-dashscope = "^0.1.3"
Expand All @@ -37,13 +37,22 @@ pytest = "^8.1.1"
llama-index-retrievers-bm25 = "^0.1.3"
jieba = "^0.42.1"
llama-index-embeddings-huggingface = "^0.2.0"
llama-index-postprocessor-flag-embedding-reranker = "^0.1.2"
flagembedding = {git = "https://github.com/FlagOpen/FlagEmbedding.git"}
llama-index-postprocessor-flag-embedding-reranker = "^0.1.3"
flagembedding = "^1.2.10"
sentencepiece = "^0.2.0"
oss2 = "^2.18.5"
asgi-correlation-id = "^4.3.1"
openinference-instrumentation-llama-index = "1.3.0"
torch = "2.2.2"
torch = [
{version = "2.3.0+cpu", source = "pytorch_cpu", platform = "linux"},
{version = "2.3.0+cpu", source = "pytorch_cpu", platform = "win32"},
{version = "2.2.2", platform = "darwin"}
]
torchvision = [
{version = "0.18.0+cpu", source = "pytorch_cpu", platform = "linux"},
{version = "0.18.0+cpu", source = "pytorch_cpu", platform = "win32"},
{version = "0.17.2", platform = "darwin"}
]
openpyxl = "^3.1.2"
pdf2image = "^1.17.0"
llama-index-storage-chat-store-redis = "^0.1.3"
Expand All @@ -67,5 +76,10 @@ load_data = "pai_rag.data.rag_datapipeline:run"
load_easyocr_model = "pai_rag.utils.download_easyocr_models:download_easyocr_models"
evaluation = "pai_rag.evaluations.batch_evaluator:run"

[[tool.poetry.source]]
name = "pytorch_cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"

[tool.pytest.ini_options]
asyncio_mode = "auto"
8 changes: 6 additions & 2 deletions src/pai_rag/app/web/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def respond(input_elements: List[Any]):
response = rag_client.query_vector(msg)
else:
response = rag_client.query(msg, session_id=current_session_id)
print("history======:", update_dict["include_history"])
if update_dict["include_history"]:
current_session_id = response.session_id
else:
Expand Down Expand Up @@ -384,7 +383,12 @@ def change_llm_src(value):
)
# similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)")
rerank_model = gr.Radio(
["No Rerank", "bge-reranker-base", "LLMRerank"],
[
"no-reranker",
"bge-reranker-base",
"bge-reranker-large",
"llm-reranker",
],
label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)",
elem_id="rerank_model",
value=view_model.rerank_model,
Expand Down
24 changes: 14 additions & 10 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ViewModel(BaseModel):

similarity_top_k: int = 5
# similarity_cutoff: float = 0.3
rerank_model: str = "No Rerank"
rerank_model: str = "no-reranker"
retrieval_mode: str = "hybrid" # hybrid / embedding / keyword
query_engine_type: str = "RetrieverQueryEngine"
BM25_weight: float = 0.5
Expand Down Expand Up @@ -181,13 +181,15 @@ def sync_app_config(self, config):
# if "Similarity" in config["postprocessor"]:
# self.similarity_cutoff = config["postprocessor"].get("similarity_cutoff", 0.1)

rerank_model = config["postprocessor"].get("rerank_model", "No Rerank")
if rerank_model == "llm":
self.rerank_model = "LLMRerank"
elif rerank_model == "bge_reranker_base":
rerank_model = config["postprocessor"].get("rerank_model", "no-reranker")
if rerank_model == "llm-reranker":
self.rerank_model = "llm-reranker"
elif rerank_model == "bge-reranker-base":
self.rerank_model = "bge-reranker-base"
elif rerank_model == "bge-reranker-large":
self.rerank_model = "bge-reranker-large"
else:
self.rerank_model = "No Rerank"
self.rerank_model = "no-reranker"

self.synthesizer_type = config["synthesizer"].get("type", "SimpleSummarize")
self.text_qa_template = config["synthesizer"].get("text_qa_template", None)
Expand Down Expand Up @@ -278,13 +280,15 @@ def to_app_config(self):
config["retriever"]["retrieval_mode"] = "keyword"

# config["postprocessor"]["similarity_cutoff"] = self.similarity_cutoff
if self.rerank_model == "LLMRerank":
config["postprocessor"]["rerank_model"] = "llm"
config["postprocessor"]["top_n"] = 3
if self.rerank_model == "llm-reranker":
config["postprocessor"]["rerank_model"] = "llm-reranker"
elif self.rerank_model == "bge-reranker-base":
config["postprocessor"]["rerank_model"] = "bge-reranker-base"
elif self.rerank_model == "bge-reranker-large":
config["postprocessor"]["rerank_model"] = "bge-reranker-large"
else:
config["postprocessor"]["rerank_model"] = "no rerank"
config["postprocessor"]["rerank_model"] = "no-reranker"
config["postprocessor"]["top_n"] = 3

config["synthesizer"]["type"] = self.synthesizer_type
config["synthesizer"]["text_qa_template"] = self.text_qa_template
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ chunk_size = 500
chunk_overlap = 10

[rag.postprocessor]
rerank_model = "no rerank" # [no rerank, llm, bge_reranker_base]
rerank_model = "no-reranker" # [no-reranker, bge-reranker-base, bge-reranker-large, llm-reranker]
top_n = 2

[rag.query_engine]
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ContextDoc,
RetrievalResponse,
)
from llama_index.core.schema import QueryBundle

import logging

Expand Down Expand Up @@ -64,7 +65,8 @@ async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:

session_id = correlation_id.get() or DEFAULT_SESSION_ID
self.logger.info(f"Get session ID: {session_id}.")
node_results = await self.retriever.aretrieve(query.question)
query_bundle = QueryBundle(query.question)
node_results = await self.query_engine.aretrieve(query_bundle)

docs = [
ContextDoc(
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/core/rag_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ def update(self, new_value: Dynaconf):
def persist(self):
"""Save configuration to file."""
data = self.config.as_dict()
os.makedirs("output", exist_ok=True)
os.makedirs("localdata", exist_ok=True)
loaders.write(GENERATED_CONFIG_FILE_NAME, DynaBox(data).to_dict(), merge=True)
14 changes: 8 additions & 6 deletions src/pai_rag/modules/postprocessor/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,24 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
)

rerank_model = config.get("rerank_model", "")
if rerank_model == "llm":
if rerank_model == "llm-reranker":
top_n = config.get("top_n", DEFAULT_RANK_TOP_N)
logger.info(f"[PostProcessor]: LLMRerank used with top_n {top_n}.")
logger.info(f"[PostProcessor]: Llm reranker used with top_n {top_n}.")
post_processors.append(LLMRerank(top_n=top_n, llm=llm))

elif rerank_model == "bge_reranker_base":
elif (
rerank_model == "bge-reranker-base" or rerank_model == "bge-reranker-large"
):
model_dir = config.get("rerank_model_dir", DEFAULT_MODEL_DIR)
model_name = config.get("rerank_model_name", DEFAULT_RANK_MODEL)
model_name = config.get("rerank_model_name", rerank_model)
model = os.path.join(model_dir, model_name)
top_n = config.get("top_n", DEFAULT_RANK_TOP_N)
logger.info(
f"[PostProcessor]: Rerank used with top_n {top_n}, model {model_name}."
f"[PostProcessor]: Reranker model used with top_n {top_n}, model {model_name}."
)
post_processors.append(FlagEmbeddingReranker(model=model, top_n=top_n))

else:
logger.info("[PostProcessor]: No Rerank used.")
logger.info("[PostProcessor]: No Reranker used.")

return post_processors

0 comments on commit d27ea28

Please sign in to comment.