refactor(llm): async completion -> async acompletion
idiotWu committed Jan 25, 2025
1 parent 737d40d commit e2f2a85
Showing 8 changed files with 14 additions and 13 deletions.
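
For callers, the change is a rename plus a new synchronous path: code that previously awaited ctx.llm.completion(...) now awaits ctx.llm.acompletion(...), and completion() becomes a plain blocking call that replaces the old completion_sync(). A minimal sketch of the resulting call pattern is below; the helper functions and the llm instance are hypothetical, only the method names and the OpenAI-style response shape come from this repository and litellm.

# Hypothetical helpers illustrating the renamed API. "llm" is an instance of
# the LLM class from npiai/llm/llm.py; its construction is not shown here.


async def ask(llm, question: str) -> str:
    messages = [{"role": "user", "content": question}]
    # before this commit: response = await llm.completion(messages=messages)
    response = await llm.acompletion(messages=messages, max_tokens=4096)
    return response.choices[0].message.content


def ask_blocking(llm, question: str) -> str:
    messages = [{"role": "user", "content": question}]
    # completion() is now the synchronous entry point; the old
    # completion_sync(), which wrapped asyncio.run(), is removed
    response = llm.completion(messages=messages, max_tokens=4096)
    return response.choices[0].message.content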
2 changes: 1 addition & 1 deletion examples/invoice_organizer/invoice_organizer.py
@@ -127,7 +127,7 @@ async def _identify_invoice_message(self, ctx: Context, message: Message):
             ),
         ]

-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=prompts,
             tools=self._processor.tools,
             tool_choice="auto",
2 changes: 1 addition & 1 deletion examples/invoice_organizer/search_query_configs.py
@@ -146,7 +146,7 @@ async def save_configs(self, ctx: Context, **composed_configs: Dict[str, Any]):
             ),
         ]

-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
4 changes: 2 additions & 2 deletions npiai/context/configurator.py
@@ -123,7 +123,7 @@ async def _parse_instruction(self, ctx: Context, instruction: str):
             ),
         ]

-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
@@ -174,7 +174,7 @@ async def _finalize_configs(self, ctx: Context, configs: Dict[str, Any]):
             ),
         ]

-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
2 changes: 1 addition & 1 deletion npiai/core/browser/_navigator.py
@@ -263,7 +263,7 @@ async def chat(self, ctx: Context, instruction: str) -> str:
         return f"Maximum number of steps reached. Last response was: {response_str}"

     async def _call_llm(self, ctx: Context, task: Task) -> str:
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=task.conversations(),
             max_tokens=4096,
         )
2 changes: 1 addition & 1 deletion npiai/core/tool/_agent.py
@@ -82,7 +82,7 @@ async def chat(

     async def _call_llm(self, ctx: Context, task: Task) -> str:
         while True:
-            response = await ctx.llm.completion(
+            response = await ctx.llm.acompletion(
                 messages=task.conversations(),
                 tools=self._tool.tools,
                 tool_choice="auto",
11 changes: 6 additions & 5 deletions npiai/llm/llm.py
@@ -1,8 +1,7 @@
 from enum import Enum
 from typing import List, Union
 import os
-import asyncio
-from litellm import acompletion, ModelResponse, CustomStreamWrapper, drop_params
+from litellm import completion, acompletion, ModelResponse, CustomStreamWrapper


 class Provider(Enum):
@@ -27,13 +26,15 @@ def get_provider(self) -> Provider:
         return self.provider

     # TODO: kwargs typings
-    async def completion(self, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
+    async def acompletion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
         return await acompletion(
             model=self.model, api_key=self.api_key, drop_params=True, **kwargs
         )

-    def completion_sync(self, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
-        return asyncio.run(self.completion(**kwargs))
+    def completion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
+        return completion(
+            model=self.model, api_key=self.api_key, drop_params=True, **kwargs
+        )


 class OpenAI(LLM):
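
Read together, the two hunks above leave the LLM class in npiai/llm/llm.py with an async acompletion() and a sync completion(), both forwarding to litellm with the same model, api_key and drop_params=True arguments. A rough reconstruction of just these two methods follows; the constructor is an assumption inferred from the self.model and self.api_key references, not something this diff shows.

from litellm import completion, acompletion, ModelResponse, CustomStreamWrapper


class LLM:
    def __init__(self, model: str, api_key: str):
        # assumed constructor: the hunks only show that self.model and
        # self.api_key exist on the instance
        self.model = model
        self.api_key = api_key

    # TODO: kwargs typings
    async def acompletion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
        # async path: awaits litellm's acompletion coroutine
        return await acompletion(
            model=self.model, api_key=self.api_key, drop_params=True, **kwargs
        )

    def completion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
        # sync path: calls litellm's blocking completion() directly, replacing
        # the old completion_sync() that ran the async method via asyncio.run()
        return completion(
            model=self.model, api_key=self.api_key, drop_params=True, **kwargs
        )

Keeping the async method named acompletion mirrors litellm's own naming, and synchronous callers no longer spin up an event loop for every request.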
2 changes: 1 addition & 1 deletion npiai/utils/llm_summarize.py
@@ -18,7 +18,7 @@ async def llm_summarize(
     final_response_content = ""

     while True:
-        response = await llm.completion(
+        response = await llm.acompletion(
             messages=messages_copy,
             max_tokens=4096,
             # use fixed temperature and seed to ensure deterministic results
2 changes: 1 addition & 1 deletion npiai/utils/llm_tool_call.py
@@ -18,7 +18,7 @@ async def llm_tool_call(
     if fn_reg.model is None:
         raise RuntimeError("Unable to modeling tool function")

-    response = await llm.completion(
+    response = await llm.acompletion(
         messages=messages,
         tools=[fn_reg.get_tool_param()],
         max_tokens=4096,
