Skip to content

Commit ba771f7

Browse files
authored
Langgraph type fixes (#361)
1 parent 37b2d91 commit ba771f7

File tree

2 files changed

+44
-37
lines changed

2 files changed

+44
-37
lines changed

src/art/langgraph/llm_wrapper.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -139,31 +139,29 @@ def _log(self, completion_id, input, output):
139139
entry = {"input": input, "output": output, "tools": self.tools}
140140
self.logger.log(f"{completion_id}", entry)
141141

142-
def invoke(self, input, config=None):
142+
def invoke(self, input, config=None, **kwargs):
143143
completion_id = str(uuid.uuid4())
144144

145-
async def execute():
145+
def execute():
146146
result = self.llm.invoke(input, config=config)
147147
self._log(completion_id, input, result)
148148
return result
149149

150150
result = execute()
151151

152-
if hasattr(result, "get") and result.get("parsed"):
153-
return result.get("parsed")
154-
155-
if hasattr(result, "tool_calls"):
156-
for tool_call in result.tool_calls:
152+
tool_calls = getattr(result, "tool_calls", None)
153+
if tool_calls:
154+
for tool_call in tool_calls:
157155
if isinstance(tool_call["args"], str):
158156
tool_call["args"] = json.loads(tool_call["args"])
159157

160158
if self.structured_output:
161159
return self.structured_output.model_validate(
162-
result.tool_calls[0]["args"] if result.tool_calls else None
160+
tool_calls[0]["args"] if tool_calls else None
163161
)
164162
return result
165163

166-
async def ainvoke(self, input, config=None):
164+
async def ainvoke(self, input, config=None, **kwargs):
167165
completion_id = str(uuid.uuid4())
168166

169167
async def execute():
@@ -178,17 +176,15 @@ async def execute():
178176

179177
result = await execute()
180178

181-
if hasattr(result, "get") and result.get("parsed"):
182-
return result.get("parsed")
183-
184-
if hasattr(result, "tool_calls"):
185-
for tool_call in result.tool_calls:
179+
tool_calls = getattr(result, "tool_calls", None)
180+
if tool_calls:
181+
for tool_call in tool_calls:
186182
if isinstance(tool_call["args"], str):
187183
tool_call["args"] = json.loads(tool_call["args"])
188184

189185
if self.structured_output:
190186
return self.structured_output.model_validate(
191-
result.tool_calls[0]["args"] if result.tool_calls else None
187+
tool_calls[0]["args"] if tool_calls else None
192188
)
193189
return result
194190

@@ -220,23 +216,24 @@ def with_config(
220216
):
221217
art_config = CURRENT_CONFIG.get()
222218
self.logger = art_config["logger"]
223-
max_tokens = config.get("max_tokens")
224219

225220
if hasattr(self.llm, "bound"):
226-
self.llm.bound = ChatOpenAI(
227-
base_url=art_config["base_url"],
228-
api_key=art_config["api_key"],
229-
model=art_config["model"],
230-
temperature=1.0,
231-
max_tokens=max_tokens,
221+
setattr(
222+
self.llm,
223+
"bound",
224+
ChatOpenAI(
225+
base_url=art_config["base_url"],
226+
api_key=art_config["api_key"],
227+
model=art_config["model"],
228+
temperature=1.0,
229+
),
232230
)
233231
else:
234232
self.llm = ChatOpenAI(
235233
base_url=art_config["base_url"],
236234
api_key=art_config["api_key"],
237235
model=art_config["model"],
238236
temperature=1.0,
239-
max_tokens=max_tokens,
240237
)
241238

242239
return self

src/art/langgraph/message_utils.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from openai.types.chat.chat_completion_function_message_param import (
2020
ChatCompletionFunctionMessageParam,
2121
)
22+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
2223
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
2324
from openai.types.chat.chat_completion_system_message_param import (
2425
ChatCompletionSystemMessageParam,
@@ -50,7 +51,7 @@ def make_message_param(role: str, **kwargs) -> ChatCompletionMessageParam:
5051
return cls(**kwargs)
5152

5253

53-
def langchain_msg_to_openai(msg: BaseMessage) -> Message:
54+
def langchain_msg_to_openai(msg: BaseMessage):
5455
if isinstance(msg, HumanMessage):
5556
role = "user"
5657
elif isinstance(msg, AIMessage):
@@ -71,23 +72,30 @@ def langchain_msg_to_openai(msg: BaseMessage) -> Message:
7172
result = {"role": role, "content": content}
7273

7374
# Handle tool calls or function call if present
74-
if hasattr(msg, "tool_calls") and msg.tool_calls:
75-
result["tool_calls"] = msg.tool_calls
76-
if hasattr(msg, "tool_call_id"):
77-
result["tool_call_id"] = msg.tool_call_id
78-
if hasattr(msg, "function_call") and msg.function_call:
79-
result["function_call"] = msg.function_call
75+
tool_calls = getattr(msg, "tool_calls", None)
76+
if tool_calls:
77+
result["tool_calls"] = tool_calls
78+
79+
tool_call_id = getattr(msg, "tool_call_id", None)
80+
if tool_call_id:
81+
result["tool_call_id"] = tool_call_id
82+
83+
function_call = getattr(msg, "function_call", None)
84+
if function_call:
85+
result["function_call"] = function_call
8086

8187
return result
8288

8389

8490
def convert_langgraph_messages(messages: List[object]) -> MessagesAndChoices:
85-
converted: MessagesAndChoices = []
91+
converted = []
8692

8793
for msg in messages:
88-
if hasattr(msg, "response_metadata") and "logprobs" in msg.response_metadata:
89-
if msg.tool_calls:
90-
for tool_call in msg.tool_calls:
94+
response_metadata = getattr(msg, "response_metadata")
95+
if response_metadata and "logprobs" in response_metadata:
96+
tool_calls = getattr(msg, "tool_calls", None)
97+
if tool_calls:
98+
for tool_call in tool_calls:
9199
tool_call["function"] = {
92100
"arguments": json.dumps(tool_call["args"]),
93101
"name": tool_call["name"],
@@ -96,11 +104,13 @@ def convert_langgraph_messages(messages: List[object]) -> MessagesAndChoices:
96104

97105
converted.append(
98106
Choice(
99-
message=ChatCompletionAssistantMessageParam(
100-
role="assistant", content=msg.content, tool_calls=msg.tool_calls
107+
message=ChatCompletionMessage(
108+
role="assistant",
109+
content=getattr(msg, "content"),
110+
tool_calls=tool_calls,
101111
),
102112
index=0,
103-
**msg.response_metadata,
113+
**response_metadata,
104114
)
105115
)
106116
elif isinstance(msg, BaseMessage):

0 commit comments

Comments
 (0)