Commit 7cb484e

Merge pull request #162 from athina-ai/vivek/7th-jan-bug-fixes
Vivek/7th jan bug fixes
2 parents b3c7b2f + 4e1525b commit 7cb484e

6 files changed: +146 −85 lines changed

athina/steps/api.py

+40 −22
@@ -8,14 +8,6 @@
 from jinja2 import Environment


-def prepare_input_data(data: Dict[str, Any]) -> Dict[str, Any]:
-    """Prepare input data by converting complex types to JSON strings."""
-    return {
-        key: json.dumps(value) if isinstance(value, (list, dict)) else value
-        for key, value in data.items()
-    }
-
-
 def prepare_template_data(
     env: Environment,
     template_dict: Optional[Dict[str, str]],
@@ -31,6 +23,19 @@ def prepare_template_data(
     return prepared_dict


+def debug_json_structure(body_str: str, error: json.JSONDecodeError) -> dict:
+    """Analyze JSON structure and identify problematic keys."""
+    lines = body_str.split("\n")
+    error_line_num = error.lineno - 1
+
+    return {
+        "original_body": body_str,
+        "problematic_line": (
+            lines[error_line_num] if error_line_num < len(lines) else None
+        ),
+    }
+
+
 def prepare_body(
     env: Environment, body_template: Optional[str], input_data: Dict[str, Any]
 ) -> Optional[str]:
@@ -112,31 +117,44 @@ async def execute_async(self, input_data: Any) -> Union[Dict[str, Any], None]:
             )
         # Prepare the environment and input data
         self.env = self._create_jinja_env()
-        prepared_input_data = prepare_input_data(input_data)

         # Prepare request components
-        prepared_body = prepare_body(self.env, self.body, prepared_input_data)
-        prepared_headers = prepare_template_data(
-            self.env, self.headers, prepared_input_data
-        )
-        prepared_params = prepare_template_data(
-            self.env, self.params, prepared_input_data
-        )
+        prepared_body = prepare_body(self.env, self.body, input_data)
+        prepared_headers = prepare_template_data(self.env, self.headers, input_data)
+        prepared_params = prepare_template_data(self.env, self.params, input_data)
+        # Prepare the URL by rendering the template
+        prepared_url = self.env.from_string(self.url).render(**input_data)

         timeout = aiohttp.ClientTimeout(total=self.timeout)

         for attempt in range(self.retries):
             try:
                 async with aiohttp.ClientSession(timeout=timeout) as session:
-                    json_body = (
-                        json.loads(prepared_body, strict=False)
-                        if prepared_body
-                        else None
-                    )
+                    try:
+                        json_body = (
+                            json.loads(prepared_body, strict=False)
+                            if prepared_body
+                            else None
+                        )
+                    except json.JSONDecodeError as e:
+                        debug_info = debug_json_structure(prepared_body, e)
+                        return self._create_step_result(
+                            status="error",
+                            data=json.dumps(
+                                {
+                                    "message": f"Failed to parse request body as JSON",
+                                    "error_type": "JSONDecodeError",
+                                    "error_details": str(e),
+                                    "debug_info": debug_info,
+                                },
+                                indent=2,
+                            ),
+                            start_time=start_time,
+                        )

                     async with session.request(
                         method=self.method,
-                        url=self.url,
+                        url=prepared_url,
                         headers=prepared_headers,
                         params=prepared_params,
                         json=json_body,
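
Note on this change: the body, headers, params, and now the URL are all rendered as Jinja templates against the raw input_data, and JSON parsing of the rendered body is wrapped so a JSONDecodeError comes back as a structured step error instead of propagating. Below is a minimal stdlib-plus-jinja2 sketch of that pattern; the sample URL, template, and variable names are illustrative only and not part of the commit.

import json

from jinja2 import Environment


def debug_json_structure(body_str: str, error: json.JSONDecodeError) -> dict:
    """Report the raw body plus the line the decoder choked on (as in the hunk above)."""
    lines = body_str.split("\n")
    error_line_num = error.lineno - 1
    return {
        "original_body": body_str,
        "problematic_line": (
            lines[error_line_num] if error_line_num < len(lines) else None
        ),
    }


env = Environment()

# URL templating as introduced above: render the URL against the input data.
prepared_url = env.from_string("https://api.example.com/users/{{ user_id }}").render(user_id=42)

# Hypothetical body template: the unquoted {{ name }} renders to invalid JSON.
body_template = '{\n  "name": {{ name }},\n  "age": {{ age }}\n}'
prepared_body = env.from_string(body_template).render(name="Alice", age=30)

try:
    json_body = json.loads(prepared_body, strict=False)
except json.JSONDecodeError as e:
    print(json.dumps(debug_json_structure(prepared_body, e), indent=2))
    # problematic_line -> '  "name": Alice,'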

athina/steps/chroma_retrieval.py

+19 −27
@@ -26,7 +26,7 @@ class ChromaRetrieval(Step):
         port (int): The port of the Chroma server.
         collection_name (str): The name of the Chroma collection.
         limit (int): The maximum number of results to fetch.
-        input_column (str): The column name in the input data.
+        user_query (str): the query which will be sent to chroma.
         openai_api_key (str): The OpenAI API key.
         auth_type (str): The authentication type for the Chroma server (e.g., "token" or "basic").
         auth_credentials (str): The authentication credentials for the Chroma server.
@@ -35,9 +35,8 @@ class ChromaRetrieval(Step):
     host: str
     port: int
     collection_name: str
-    key: str
     limit: int
-    input_column: str
+    user_query: str
     openai_api_key: str
     auth_type: Optional[AuthType] = None
     auth_credentials: Optional[str] = None
@@ -76,12 +75,6 @@ def __init__(self, *args, **kwargs):
         self._collection = self._client.get_collection(
             name=self.collection_name, embedding_function=self._embedding_function
         )
-        # Create a custom Jinja2 environment with double curly brace delimiters and PreserveUndefined
-        self.env = Environment(
-            variable_start_string="{{",
-            variable_end_string="}}",
-            undefined=PreserveUndefined,
-        )

     """Makes a call to chromadb collection to fetch relevant chunks"""

@@ -95,31 +88,30 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
                 start_time=start_time,
             )

-        query = input_data.get(self.input_column)
-        if query is None:
+        self.env = self._create_jinja_env()
+
+        query_text = self.env.from_string(self.user_query).render(**input_data)
+
+        if query_text is None:
             return self._create_step_result(
-                status="error",
-                data="Input column not found.",
-                start_time=start_time,
+                status="error", data="Query text is Empty.", start_time=start_time
             )

         try:
-            if isinstance(query, list) and isinstance(query[0], float):
-                response = self._collection.query(
-                    query_embeddings=[query],
-                    n_results=self.limit,
-                    include=["documents", "metadatas", "distances"],
-                )
-            else:
-                response = self._collection.query(
-                    query_texts=[query],
-                    n_results=self.limit,
-                    include=["documents", "metadatas", "distances"],
+            response = self._collection.query(
+                query_texts=[query_text],
+                n_results=self.limit,
+                include=["documents", "metadatas", "distances"],
+            )
+            result = [
+                {"text": text, "score": distance}
+                for text, distance in zip(
+                    response["documents"][0], response["distances"][0]
                 )
-
+            ]
             return self._create_step_result(
                 status="success",
-                data=response["documents"][0],
+                data=result,
                 start_time=start_time,
             )
         except Exception as e:
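
For reference, the step now treats user_query as a Jinja template rendered against the whole input dict, and reshapes Chroma's parallel documents/distances lists into {"text", "score"} pairs. A small sketch of that reshaping under those assumptions; the response dict below is hand-written in Chroma's result shape and its values are made up for illustration.

from jinja2 import Environment

# user_query is a Jinja template now, not a column name.
user_query = "{{ question }}"
input_data = {"question": "How do I rotate my API key?"}
query_text = Environment().from_string(user_query).render(**input_data)

# Chroma-style result for query_texts=[query_text]: one inner list per query.
response = {
    "documents": [["Rotate keys from the settings page.", "Keys expire yearly."]],
    "distances": [[0.12, 0.34]],
}

result = [
    {"text": text, "score": distance}
    for text, distance in zip(response["documents"][0], response["distances"][0])
]
print(result)  # [{'text': ..., 'score': 0.12}, {'text': ..., 'score': 0.34}]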

athina/steps/pinecone_retrieval.py

+35 −13
@@ -1,5 +1,3 @@
-# Step to make a call to pinecone index to fetch relevent chunks
-import pinecone
 from typing import Optional, Union, Dict, Any

 from pydantic import Field, PrivateAttr
@@ -9,6 +7,7 @@
 from llama_index.core import VectorStoreIndex
 from llama_index.core.retrievers import VectorIndexRetriever
 import time
+import traceback


 class PineconeRetrieval(Step):
@@ -22,36 +21,48 @@ class PineconeRetrieval(Step):
         metadata_filters: filters to apply to metadata.
         environment: pinecone environment.
         api_key: api key for the pinecone server
-        input_column: column name in the input data
+        user_query: the query which will be sent to pinecone
         env: jinja environment
     """

     index_name: str
     top_k: int
     api_key: str
-    input_column: str
+    user_query: str
     env: Environment = None
     metadata_filters: Optional[Dict[str, Any]] = None
     namespace: Optional[str] = None
     environment: Optional[str] = None
+    text_key: Optional[str] = None  # Optional parameter for text key
     _vector_store: PineconeVectorStore = PrivateAttr()
     _vector_index: VectorStoreIndex = PrivateAttr()
     _retriever: VectorIndexRetriever = PrivateAttr()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # Initialize base vector store arguments
         vector_store_args = {"api_key": self.api_key, "index_name": self.index_name}
+        # Add text_key only if specified by user
+        if self.text_key:
+            vector_store_args["text_key"] = self.text_key

+        # Only add environment if it's provided
         if self.environment is not None:
             vector_store_args["environment"] = self.environment

-        if self.namespace is not None:
+        # Only add namespace if it's provided and not None
+        if self.namespace:
             vector_store_args["namespace"] = self.namespace

+        # Initialize vector store with filtered arguments
         self._vector_store = PineconeVectorStore(**vector_store_args)
+
+        # Create vector index from store
         self._vector_index = VectorStoreIndex.from_vector_store(
             vector_store=self._vector_store
         )
+
+        # Initialize retriever with specified top_k
         self._retriever = VectorIndexRetriever(
             index=self._vector_index, similarity_top_k=self.top_k
         )
@@ -60,9 +71,10 @@ class Config:
         arbitrary_types_allowed = True

     def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
-        """makes a call to pinecone index to fetch relevent chunks"""
+        """Makes a call to pinecone index to fetch relevant chunks"""
         start_time = time.perf_counter()

+        # Validate input data
         if input_data is None:
             input_data = {}

@@ -73,26 +85,36 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
                 start_time=start_time,
             )

-        input_text = input_data.get(self.input_column, None)
+        # Create Jinja environment and render query
+        self.env = self._create_jinja_env()
+        query_text = self.env.from_string(self.user_query).render(**input_data)

-        if input_text is None:
+        if not query_text:
             return self._create_step_result(
                 status="error",
-                data="Input column not found.",
+                data="Query text is Empty.",
                 start_time=start_time,
             )

         try:
-            response = self._retriever.retrieve(input_text)
-            result = [node.get_content() for node in response]
+            # Perform retrieval
+            response = self._retriever.retrieve(query_text)
+            result = [
+                {
+                    "text": node.get_content(),
+                    "score": node.get_score(),
+                }
+                for node in response
+            ]
+            return self._create_step_result(
+                status="success", data=result, start_time=start_time
+            )
             return self._create_step_result(
                 status="success",
                 data=result,
                 start_time=start_time,
             )
         except Exception as e:
-            import traceback
-
             traceback.print_exc()
             print(f"Error during retrieval: {str(e)}")
             return self._create_step_result(
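
The constructor change above builds the PineconeVectorStore kwargs conditionally, so text_key, environment, and namespace are only passed when set. A hypothetical standalone helper mirroring that logic is sketched below; build_vector_store_args and the sample values are illustrative and not part of the commit.

from typing import Any, Dict, Optional


def build_vector_store_args(
    api_key: str,
    index_name: str,
    text_key: Optional[str] = None,
    environment: Optional[str] = None,
    namespace: Optional[str] = None,
) -> Dict[str, Any]:
    """Only include optional kwargs when they are actually set."""
    args: Dict[str, Any] = {"api_key": api_key, "index_name": index_name}
    if text_key:
        args["text_key"] = text_key
    if environment is not None:
        args["environment"] = environment
    if namespace:
        args["namespace"] = namespace
    return args


print(build_vector_store_args("pc-key", "docs-index", namespace="prod"))
# {'api_key': 'pc-key', 'index_name': 'docs-index', 'namespace': 'prod'}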

athina/steps/qdrant_retrieval.py

+11 −5
@@ -20,15 +20,15 @@ class QdrantRetrieval(Step):
         url: url of the qdrant server
         top_k: How many chunks to fetch.
         api_key: api key for the qdrant server
-        input_column: the query which will be sent to qdrant
+        user_query: the query which will be sent to qdrant
         env: jinja environment
     """

     collection_name: str
     url: str
     top_k: int
     api_key: str
-    input_column: str
+    user_query: str
     env: Environment = None
     _qdrant_client: qdrant_client.QdrantClient = PrivateAttr()
     _vector_store: QdrantVectorStore = PrivateAttr()
@@ -70,11 +70,11 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:

         self.env = self._create_jinja_env()

-        query_text = self.env.from_string(self.input_column).render(**input_data)
+        query_text = self.env.from_string(self.user_query).render(**input_data)

         if query_text is None:
             return self._create_step_result(
-                status="error", data="Query text not found.", start_time=start_time
+                status="error", data="Query text is Empty.", start_time=start_time
             )

         try:
@@ -84,7 +84,13 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
                 return self._create_step_result(
                     status="success", data=[], start_time=start_time
                 )
-            result = [node.get_content() for node in response]
+            result = [
+                {
+                    "text": node.get_content(),
+                    "score": node.get_score(),
+                }
+                for node in response
+            ]
             return self._create_step_result(
                 status="success", data=result, start_time=start_time
            )
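
As in the Pinecone step, results are now returned as {"text", "score"} dicts built from the retriever's nodes via get_content() and get_score(). A self-contained sketch of that shaping follows; FakeNodeWithScore is a stand-in for the node objects returned by the llama-index retriever, and its contents are made up for illustration.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeNodeWithScore:
    """Stand-in for the nodes returned by the retriever, for illustration only."""
    text: str
    score: float

    def get_content(self) -> str:
        return self.text

    def get_score(self) -> float:
        return self.score


response: List[FakeNodeWithScore] = [
    FakeNodeWithScore("Qdrant supports payload filtering.", 0.87),
    FakeNodeWithScore("Collections can be snapshotted.", 0.74),
]

# Same shaping as the hunk above: text plus similarity score per chunk.
result = [
    {
        "text": node.get_content(),
        "score": node.get_score(),
    }
    for node in response
]
print(result)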
