diff --git a/backend/agents.py b/backend/agents.py index 80743e4..b656204 100644 --- a/backend/agents.py +++ b/backend/agents.py @@ -10,7 +10,15 @@ from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search from retrieval import Retriever - +def clean_and_parse_json(text: str) -> dict: + try: + # Regex se sirf { ... } ke beech ka content nikalna + match = re.search(r'\{.*\}', text, re.DOTALL) + if match: + return json.loads(match.group()) + return json.loads(text) + except (json.JSONDecodeError, TypeError, AttributeError): + return {} # LLM (Gemini) client setup try: from google import genai @@ -142,7 +150,7 @@ async def call_gemini_for_keywords(query: str) -> List[str]: response_mime_type="application/json", ) resp = client.models.generate_content(model=FLASH_LITE_MODEL, contents=[prompt], config=cfg) - out = json.loads(resp.text or "{}") + out = clean_and_parse_json(resp.text or "{}") kws = out.get("keywords", []) or [] normalized: List[str] = [] for k in kws: @@ -222,7 +230,7 @@ async def call_gemini_detect_intents(query: str, history: List[str]) -> List[str response_mime_type="application/json", ) resp = client.models.generate_content(model=FLASH_LITE_MODEL, contents=[prompt], config=cfg) - out = json.loads(resp.text or "{}") + out = clean_and_parse_json(resp.text or "{}") intents = [i for i in out.get("intents", []) if i in allowed] return list(dict.fromkeys(intents or [QueryIntent.DATA_DISCOVERY.value]))[:6] @@ -352,24 +360,12 @@ async def run(self, query: str, keywords: List[str], want: int = 45) -> dict: class VectorSearchAgent: def __init__(self): - self.retriever = Retriever() - self.is_enabled = self.retriever.is_enabled - + # self.retriever = Retriever() + self.is_enabled = False + print(" -> Vector Search disabled (Torch bypass active)") async def run(self, query: str, want: int, context: Optional[Dict] = None) -> List[dict]: - if not self.is_enabled: - return [] - try: - # Run the synchronous search in a thread to make it async - results = await asyncio.to_thread( - self.retriever.search, - query=query, - top_k=min(want, 50), - context={"raw": True} - ) - return [item.__dict__ if hasattr(item, "__dict__") else item for item in results] - except Exception as e: - print(f"Vector search error: {e}") - return [] + + return [] async def extract_keywords_and_rewrite(state: AgentState) -> AgentState: @@ -426,27 +422,36 @@ async def execute_search(state: AgentState) -> Dict[str, Any]: def fuse_results(state: AgentState) -> AgentState: - print("--- Node: Result Fusion ---") + print("--- Node: Result Fusion (RRF) ---") ks_results = state.get("ks_results", []) vector_results = state.get("vector_results", []) - combined: Dict[str, dict] = {} - for res in vector_results: - if isinstance(res, dict): - doc_id = res.get("id") or res.get("_id") or f"vec_{len(combined)}" - combined[doc_id] = {**res, "final_score": res.get("similarity", 0) * 0.6} - for res in ks_results: - if isinstance(res, dict): - doc_id = res.get("_id") or res.get("id") or f"ks_{len(combined)}" - if doc_id in combined: - combined[doc_id]["final_score"] += res.get("_score", 0) * 0.4 - else: - combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4} - all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True) - print(f"Results summary: KS={len(ks_results)}, Vector={len(vector_results)}, Combined={len(all_sorted)}") + + rrf_scores: Dict[str, float] = {} + k = 60 + def get_id(res): + return res.get("_id") or res.get("id") + for rank, res in enumerate(vector_results, start=1): + doc_id = get_id(res) + if doc_id: + # RRF Formula: 1 / (k + rank) + rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + (1.0 / (k + rank)) + for rank, res in enumerate(ks_results, start=1): + doc_id = get_id(res) + if doc_id: + rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + (1.0 / (k + rank)) + all_docs = {get_id(r): r for r in vector_results + ks_results if get_id(r)} + + combined = [] + for doc_id, score in rrf_scores.items(): + doc = all_docs[doc_id].copy() + doc["final_score"] = score + combined.append(doc) + all_sorted = sorted(combined, key=lambda x: x.get("final_score", 0), reverse=True) + + print(f"RRF Summary: Unique results={len(all_sorted)}") page_size = 15 return {**state, "all_results": all_sorted, "final_results": all_sorted[:page_size]} - async def generate_final_response(state: AgentState) -> AgentState: print("--- Node: Response Generation ---") intents = state.get("intents", [QueryIntent.DATA_DISCOVERY.value])