From e2f2a85432f4edd48e8b4fb2b62adfd40ed0e7b1 Mon Sep 17 00:00:00 2001
From: Daofeng Wu
Date: Sun, 26 Jan 2025 00:20:10 +0900
Subject: [PATCH] refactor(llm): async completion -> async acompletion

---
 examples/invoice_organizer/invoice_organizer.py    |  2 +-
 examples/invoice_organizer/search_query_configs.py |  2 +-
 npiai/context/configurator.py                      |  4 ++--
 npiai/core/browser/_navigator.py                   |  2 +-
 npiai/core/tool/_agent.py                          |  2 +-
 npiai/llm/llm.py                                   | 11 ++++++-----
 npiai/utils/llm_summarize.py                       |  2 +-
 npiai/utils/llm_tool_call.py                       |  2 +-
 8 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/examples/invoice_organizer/invoice_organizer.py b/examples/invoice_organizer/invoice_organizer.py
index 90a39cd4..0afb439b 100644
--- a/examples/invoice_organizer/invoice_organizer.py
+++ b/examples/invoice_organizer/invoice_organizer.py
@@ -127,7 +127,7 @@ async def _identify_invoice_message(self, ctx: Context, message: Message):
             ),
         ]
 
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=prompts,
             tools=self._processor.tools,
             tool_choice="auto",
diff --git a/examples/invoice_organizer/search_query_configs.py b/examples/invoice_organizer/search_query_configs.py
index a402d285..bc18f1ca 100644
--- a/examples/invoice_organizer/search_query_configs.py
+++ b/examples/invoice_organizer/search_query_configs.py
@@ -146,7 +146,7 @@ async def save_configs(self, ctx: Context, **composed_configs: Dict[str, Any]):
             ),
         ]
 
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
diff --git a/npiai/context/configurator.py b/npiai/context/configurator.py
index a05417bd..b4ace7d5 100644
--- a/npiai/context/configurator.py
+++ b/npiai/context/configurator.py
@@ -123,7 +123,7 @@ async def _parse_instruction(self, ctx: Context, instruction: str):
             ),
         ]
 
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
@@ -174,7 +174,7 @@ async def _finalize_configs(self, ctx: Context, configs: Dict[str, Any]):
             ),
         ]
 
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
             messages=messages,
             tools=[fn_reg.get_tool_param()],
             tool_choice="required",
diff --git a/npiai/core/browser/_navigator.py b/npiai/core/browser/_navigator.py
index e5fea4c9..5000653a 100644
--- a/npiai/core/browser/_navigator.py
+++ b/npiai/core/browser/_navigator.py
@@ -263,7 +263,7 @@ async def chat(self, ctx: Context, instruction: str) -> str:
         return f"Maximum number of steps reached. Last response was: {response_str}"
 
     async def _call_llm(self, ctx: Context, task: Task) -> str:
-        response = await ctx.llm.completion(
+        response = await ctx.llm.acompletion(
            messages=task.conversations(),
            max_tokens=4096,
         )
diff --git a/npiai/core/tool/_agent.py b/npiai/core/tool/_agent.py
index c7255425..18970a58 100644
--- a/npiai/core/tool/_agent.py
+++ b/npiai/core/tool/_agent.py
@@ -82,7 +82,7 @@ async def chat(
 
     async def _call_llm(self, ctx: Context, task: Task) -> str:
         while True:
-            response = await ctx.llm.completion(
+            response = await ctx.llm.acompletion(
                 messages=task.conversations(),
                 tools=self._tool.tools,
                 tool_choice="auto",
diff --git a/npiai/llm/llm.py b/npiai/llm/llm.py
index dd99fd8b..3269174d 100644
--- a/npiai/llm/llm.py
+++ b/npiai/llm/llm.py
@@ -1,8 +1,7 @@
 from enum import Enum
-from typing import List, Union
 import os
 import asyncio
-from litellm import acompletion, ModelResponse, CustomStreamWrapper, drop_params
+from litellm import completion, acompletion, ModelResponse, CustomStreamWrapper
 
 
 class Provider(Enum):
@@ -27,13 +26,15 @@ def get_provider(self) -> Provider:
         return self.provider
 
     # TODO: kwargs typings
-    async def completion(self, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
+    async def acompletion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
         return await acompletion(
             model=self.model, api_key=self.api_key, drop_params=True, **kwargs
         )
 
-    def completion_sync(self, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
-        return asyncio.run(self.completion(**kwargs))
+    def completion(self, **kwargs) -> ModelResponse | CustomStreamWrapper:
+        return completion(
+            model=self.model, api_key=self.api_key, drop_params=True, **kwargs
+        )
 
 
 class OpenAI(LLM):
diff --git a/npiai/utils/llm_summarize.py b/npiai/utils/llm_summarize.py
index f8014c42..467d9c18 100644
--- a/npiai/utils/llm_summarize.py
+++ b/npiai/utils/llm_summarize.py
@@ -18,7 +18,7 @@ async def llm_summarize(
     final_response_content = ""
 
     while True:
-        response = await llm.completion(
+        response = await llm.acompletion(
             messages=messages_copy,
             max_tokens=4096,
             # use fixed temperature and seed to ensure deterministic results
diff --git a/npiai/utils/llm_tool_call.py b/npiai/utils/llm_tool_call.py
index c9261160..ce290344 100644
--- a/npiai/utils/llm_tool_call.py
+++ b/npiai/utils/llm_tool_call.py
@@ -18,7 +18,7 @@ async def llm_tool_call(
     if fn_reg.model is None:
         raise RuntimeError("Unable to modeling tool function")
 
-    response = await llm.completion(
+    response = await llm.acompletion(
         messages=messages,
         tools=[fn_reg.get_tool_param()],
         max_tokens=4096,
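
Note: a minimal usage sketch of the API shape after this patch. The import
path and the OpenAI constructor arguments below are assumptions for
illustration (the diff does not show __init__); only the acompletion() /
completion() calls mirror the methods defined in npiai/llm/llm.py above.

    import asyncio

    from npiai.llm import OpenAI  # assumed import path (npiai/llm/llm.py)

    llm = OpenAI(model="gpt-4o", api_key="sk-...")  # hypothetical constructor args

    # Async call sites (agents, navigator, summarizer) now use acompletion():
    async def main() -> None:
        response = await llm.acompletion(
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=4096,
        )
        print(response.choices[0].message.content)

    asyncio.run(main())

    # completion() is now the synchronous entry point. It wraps
    # litellm.completion directly instead of asyncio.run(), so it no longer
    # raises RuntimeError when called on a thread that already has a
    # running event loop.
    response = llm.completion(
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=4096,
    )

The rename also aligns the method names with litellm's own completion /
acompletion convention, so sync and async callers pick the matching name.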