Skip to content

Commit c100697

Browse files
author
nik
committed
Add json model for custom llm endpoints
1 parent 6450e1a commit c100697

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

adala/runtimes/_litellm.py

+12
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ def init_runtime(self) -> "Runtime":
174174
raise ValueError(
175175
f'Failed to check availability of requested model "{self.model}": {e}'
176176
)
177+
177178
return self
178179

179180
def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
@@ -343,6 +344,7 @@ def init_runtime(self) -> "Runtime":
343344
raise ValueError(
344345
f'Failed to check availability of requested model "{self.model}": {e}'
345346
)
347+
346348
return self
347349

348350
@field_validator("concurrency", mode="before")
@@ -355,6 +357,10 @@ def check_concurrency(cls, value) -> int:
355357
)
356358
return value
357359

360+
@property
def is_custom_openai_endpoint(self) -> bool:
    """Report whether this runtime targets a custom OpenAI-style endpoint.

    Returns:
        True when ``self.model`` starts with ``"openai/"`` and a non-empty
        ``base_url`` entry is present in ``self.model_extra``; False
        otherwise.
    """
    # NOTE(review): model_extra is presumably pydantic's extra-field dict,
    # which can be None when extra fields are not stored - confirm against
    # the class config. Treat a missing dict as "no extras".
    extra = self.model_extra or {}
    # Coerce to bool so the annotated return type holds; the bare `and`
    # expression would otherwise hand callers the base_url string or None.
    return bool(self.model.startswith("openai/") and extra.get("base_url"))
363+
358364
async def batch_to_batch(
359365
self,
360366
batch: InternalDataFrame,
@@ -383,6 +389,12 @@ async def batch_to_batch(
383389
).tolist()
384390

385391
retries = AsyncRetrying(**RETRY_POLICY)
392+
if self.is_custom_openai_endpoint:
393+
# TODO: most custom OpenAI endpoints do not support tools mode, only JSON mode;
394+
# we should also make this more performant by not creating an instructor client on every request
395+
async_instructor_client = instructor.from_litellm(
396+
litellm.acompletion, mode=instructor.Mode.JSON
397+
)
386398

387399
tasks = [
388400
asyncio.ensure_future(

adala/skills/collection/entity_extraction.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -267,7 +267,7 @@ def extract_indices(self, df):
267267
input_field_name = self._get_input_field_name()
268268
output_field_name = self._get_output_field_name()
269269
for i, row in df.iterrows():
270-
if row.get('_adala_error'):
270+
if row.get("_adala_error"):
271271
logger.warning(f"Error in row {i}: {row['_adala_message']}")
272272
continue
273273
text = row[input_field_name]

0 commit comments

Comments
 (0)