Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 41 additions & 36 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
from retrieval import Retriever

def clean_and_parse_json(text: str) -> dict:
    """Extract the first JSON object embedded in *text* and return it as a dict.

    Model responses sometimes wrap the JSON payload in Markdown fences or
    surrounding prose. This scans for each ``{`` and tries to decode a
    complete JSON object starting there, so leading/trailing noise (including
    stray braces) does not prevent a valid object from being found.

    Args:
        text: Raw response text. Non-string input is tolerated.

    Returns:
        The first decodable JSON object as a ``dict``. An empty dict is
        returned when *text* is not a string or contains no valid JSON
        object; callers can therefore always use ``.get`` on the result.
    """
    if not isinstance(text, str):
        return {}

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses exactly one JSON value beginning at `start`
            # and ignores whatever follows it.
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next candidate.
            start = text.find("{", start + 1)
        else:
            # A value that starts with "{" can only decode to a dict.
            return parsed
    return {}
# LLM (Gemini) client setup
try:
from google import genai
Expand Down Expand Up @@ -142,7 +150,7 @@ async def call_gemini_for_keywords(query: str) -> List[str]:
response_mime_type="application/json",
)
resp = client.models.generate_content(model=FLASH_LITE_MODEL, contents=[prompt], config=cfg)
out = json.loads(resp.text or "{}")
out = clean_and_parse_json(resp.text or "{}")
kws = out.get("keywords", []) or []
normalized: List[str] = []
for k in kws:
Expand Down Expand Up @@ -222,7 +230,7 @@ async def call_gemini_detect_intents(query: str, history: List[str]) -> List[str
response_mime_type="application/json",
)
resp = client.models.generate_content(model=FLASH_LITE_MODEL, contents=[prompt], config=cfg)
out = json.loads(resp.text or "{}")
out = clean_and_parse_json(resp.text or "{}")
intents = [i for i in out.get("intents", []) if i in allowed]
return list(dict.fromkeys(intents or [QueryIntent.DATA_DISCOVERY.value]))[:6]

Expand Down Expand Up @@ -352,24 +360,12 @@ async def run(self, query: str, keywords: List[str], want: int = 45) -> dict:

class VectorSearchAgent:
    """Optional vector-similarity search backed by ``Retriever``.

    Vector search is disabled by default so that the retriever (and whatever
    heavy dependencies it loads) is never constructed. Pass ``enabled=True``
    to build the retriever and use its own ``is_enabled`` flag.

    Attributes set by ``__init__``:
        retriever: The ``Retriever`` instance, or ``None`` when disabled.
        is_enabled: Whether ``run`` will actually query the retriever.
    """

    def __init__(self, enabled: bool = False):
        # Always define both attributes so callers can inspect them safely
        # regardless of whether vector search is switched on.
        self.retriever = None
        self.is_enabled = False
        if enabled:
            self.retriever = Retriever()
            self.is_enabled = self.retriever.is_enabled
        else:
            print(" -> Vector Search disabled (Torch bypass active)")

    async def run(self, query: str, want: int, context: Optional[Dict] = None) -> List[dict]:
        """Return up to ``min(want, 50)`` vector hits for *query*.

        Returns an empty list when vector search is disabled or when the
        retriever raises; the error is printed rather than propagated so a
        failing vector backend does not abort the caller.

        ``context`` is accepted for interface compatibility; the retriever is
        always called with ``context={"raw": True}``.
        """
        if not self.is_enabled or self.retriever is None:
            return []
        try:
            # The retriever's search is synchronous; run it in a worker
            # thread so the event loop is not blocked.
            results = await asyncio.to_thread(
                self.retriever.search,
                query=query,
                top_k=min(want, 50),
                context={"raw": True},
            )
            # Normalise object-style hits to plain dicts; leave dicts as-is.
            return [item.__dict__ if hasattr(item, "__dict__") else item for item in results]
        except Exception as e:
            print(f"Vector search error: {e}")
            return []


async def extract_keywords_and_rewrite(state: AgentState) -> AgentState:
Expand Down Expand Up @@ -426,27 +422,36 @@ async def execute_search(state: AgentState) -> Dict[str, Any]:


def fuse_results(state: AgentState) -> AgentState:
    """Merge vector and keyword-search hits using Reciprocal Rank Fusion.

    Each result list contributes ``1 / (k + rank)`` to a document's score,
    where ``rank`` is the 1-based position of the document in that list.
    Documents are identified by ``"_id"`` (falling back to ``"id"``); a hit
    with neither key is kept under a synthetic per-list key so it is not
    silently dropped. Non-dict entries are ignored. A document that appears
    more than once in the same list is scored only at its first (best) rank.

    Args:
        state: Agent state; ``"vector_results"`` and ``"ks_results"`` are
            read and may be missing or ``None``.

    Returns:
        A copy of *state* with ``"all_results"`` (every fused document,
        highest ``"final_score"`` first) and ``"final_results"`` (the first
        page of those). Input result dicts are not mutated.
    """
    print("--- Node: Result Fusion (RRF) ---")
    rrf_k = 60       # RRF smoothing constant
    page_size = 15   # number of results exposed as "final_results"

    scores: Dict[str, float] = {}
    docs: Dict[str, dict] = {}

    # Vector hits first, keyword hits second: when both lists contain the
    # same document, the keyword-search copy is the one that is returned.
    ranked_lists = (
        ("vec", state.get("vector_results") or []),
        ("ks", state.get("ks_results") or []),
    )
    for prefix, results in ranked_lists:
        seen_in_list = set()
        for rank, res in enumerate(results, start=1):
            if not isinstance(res, dict):
                continue
            doc_id = res.get("_id") or res.get("id") or f"{prefix}_{rank}"
            if doc_id in seen_in_list:
                # A list contributes at most once per document.
                continue
            seen_in_list.add(doc_id)
            docs[doc_id] = res
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rrf_k + rank)

    # Build fresh dicts so the caller's result objects stay untouched.
    combined = [{**docs[doc_id], "final_score": score} for doc_id, score in scores.items()]
    all_sorted = sorted(combined, key=lambda doc: doc["final_score"], reverse=True)

    print(f"RRF Summary: Unique results={len(all_sorted)}")
    return {**state, "all_results": all_sorted, "final_results": all_sorted[:page_size]}


async def generate_final_response(state: AgentState) -> AgentState:
print("--- Node: Response Generation ---")
intents = state.get("intents", [QueryIntent.DATA_DISCOVERY.value])
Expand Down