@@ -139,31 +139,29 @@ def _log(self, completion_id, input, output):
139139 entry = {"input" : input , "output" : output , "tools" : self .tools }
140140 self .logger .log (f"{ completion_id } " , entry )
141141
142- def invoke (self , input , config = None ):
142+ def invoke (self , input , config = None , ** kwargs ):
143143 completion_id = str (uuid .uuid4 ())
144144
145- async def execute ():
145+ def execute ():
146146 result = self .llm .invoke (input , config = config )
147147 self ._log (completion_id , input , result )
148148 return result
149149
150150 result = execute ()
151151
152- if hasattr (result , "get" ) and result .get ("parsed" ):
153- return result .get ("parsed" )
154-
155- if hasattr (result , "tool_calls" ):
156- for tool_call in result .tool_calls :
152+ tool_calls = getattr (result , "tool_calls" , None )
153+ if tool_calls :
154+ for tool_call in tool_calls :
157155 if isinstance (tool_call ["args" ], str ):
158156 tool_call ["args" ] = json .loads (tool_call ["args" ])
159157
160158 if self .structured_output :
161159 return self .structured_output .model_validate (
162- result . tool_calls [0 ]["args" ] if result . tool_calls else None
160+ tool_calls [0 ]["args" ] if tool_calls else None
163161 )
164162 return result
165163
166- async def ainvoke (self , input , config = None ):
164+ async def ainvoke (self , input , config = None , ** kwargs ):
167165 completion_id = str (uuid .uuid4 ())
168166
169167 async def execute ():
@@ -178,17 +176,15 @@ async def execute():
178176
179177 result = await execute ()
180178
181- if hasattr (result , "get" ) and result .get ("parsed" ):
182- return result .get ("parsed" )
183-
184- if hasattr (result , "tool_calls" ):
185- for tool_call in result .tool_calls :
179+ tool_calls = getattr (result , "tool_calls" , None )
180+ if tool_calls :
181+ for tool_call in tool_calls :
186182 if isinstance (tool_call ["args" ], str ):
187183 tool_call ["args" ] = json .loads (tool_call ["args" ])
188184
189185 if self .structured_output :
190186 return self .structured_output .model_validate (
191- result . tool_calls [0 ]["args" ] if result . tool_calls else None
187+ tool_calls [0 ]["args" ] if tool_calls else None
192188 )
193189 return result
194190
@@ -220,23 +216,24 @@ def with_config(
220216 ):
221217 art_config = CURRENT_CONFIG .get ()
222218 self .logger = art_config ["logger" ]
223- max_tokens = config .get ("max_tokens" )
224219
225220 if hasattr (self .llm , "bound" ):
226- self .llm .bound = ChatOpenAI (
227- base_url = art_config ["base_url" ],
228- api_key = art_config ["api_key" ],
229- model = art_config ["model" ],
230- temperature = 1.0 ,
231- max_tokens = max_tokens ,
221+ setattr (
222+ self .llm ,
223+ "bound" ,
224+ ChatOpenAI (
225+ base_url = art_config ["base_url" ],
226+ api_key = art_config ["api_key" ],
227+ model = art_config ["model" ],
228+ temperature = 1.0 ,
229+ ),
232230 )
233231 else :
234232 self .llm = ChatOpenAI (
235233 base_url = art_config ["base_url" ],
236234 api_key = art_config ["api_key" ],
237235 model = art_config ["model" ],
238236 temperature = 1.0 ,
239- max_tokens = max_tokens ,
240237 )
241238
242239 return self
0 commit comments