From d33b82a8a855e95c602aa8acdfb4df2b911fedb1 Mon Sep 17 00:00:00 2001 From: Anirban Basu Date: Thu, 15 Aug 2024 22:54:31 +0900 Subject: [PATCH] temporary: Separate coder works. --- src/coder_agent.py | 62 +++++++++++++++++++++++++++------------------- src/constants.py | 9 ++++--- src/webapp.py | 22 +++++++--------- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/src/coder_agent.py b/src/coder_agent.py index be3b374..f77aaac 100644 --- a/src/coder_agent.py +++ b/src/coder_agent.py @@ -23,9 +23,7 @@ from langgraph.graph.message import AnyMessage, add_messages from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableConfig -from langchain.schema.output_parser import StrOutputParser -from langchain_core.tools import BaseTool from code_executor import CodeExecutor from langgraph.checkpoint.sqlite import SqliteSaver @@ -64,7 +62,6 @@ class CoderInput(BaseModel): examples: str = Field( ..., description="Examples of similar challenges and their solutions.", - default=constants.EMPTY_STRING, ) @@ -72,36 +69,41 @@ class CoderOutput(BaseModel): reasoning: str = Field(..., description="Reasoning for the conceptual solution.") pseudocode: str = Field(..., description="Pseudocode for the solution.") code: str = Field(..., description="Python code implementation for the solution.") + summary: str = Field( + ..., description="A short one sentence summary of the solution." + ) -class CoderTool(BaseTool): - name = "coder_tool" - description = "Generate Python code to solve the given problem." - args_schema = CoderOutput - +class Coder: def __init__(self, llm: BaseChatModel): - self._llm = llm.with_structured_output(CoderOutput) + self._llm = llm - def _run(self, challenge: str, examples: str) -> CoderOutput: + def solve( + self, challenge: str, examples: str = constants.EMPTY_STRING + ) -> CoderOutput: """Run the tool""" messages = [ ("system", constants.ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT), ("human", "{input}"), ] - chain = ( - ChatPromptTemplate.from_messages(messages=messages) - | self._llm - | StrOutputParser() - ) - return chain.invoke({"input": challenge, "examples": examples}) + chain = ChatPromptTemplate.from_messages( + messages=messages + ) | self._llm.with_structured_output(CoderOutput) + raw_result = chain.invoke({"input": challenge, "examples": examples}) + ic(raw_result) + return raw_result class MultiAgentOrchestrator: def __init__(self, llm: BaseChatModel, prompt: ChatPromptTemplate): self._llm = llm self._prompt = prompt - self._runnable_solver = self._prompt | self._llm.bind_tools([CoderOutput]) - self._runnable_draft_solver = self._prompt | self._llm.bind_tools([CoderOutput]) + self._runnable_solver = self._prompt | self._llm.with_structured_output( + CoderOutput + ) + self._runnable_draft_solver = self._prompt | self._llm.with_structured_output( + CoderOutput + ) self._evaluator = CodeExecutor() # self._retriever = BM25Retriever.from_texts( # [self.format_example(row) for row in train_ds] @@ -200,12 +202,18 @@ def solve(self, state: AgentState) -> dict: inputs[constants.AGENT_STATE__KEY_EXAMPLES] = state[ constants.AGENT_STATE__KEY_EXAMPLES ] + coder = Coder(self._llm) + ic(inputs) response = ( # Use the draft solver only if the `draft` flag is set in the state - self._runnable_draft_solver.invoke(inputs) - if state[constants.AGENT_STATE__KEY_DRAFT] is True - else self._runnable_solver.invoke(inputs) + # self._runnable_draft_solver.invoke(inputs) + # if state[constants.AGENT_STATE__KEY_DRAFT] is True + # else self._runnable_solver.invoke(inputs) + self.pydantic_to_ai_message( + coder.solve(inputs[constants.AGENT_STATE__KEY_MESSAGES][-1].content) + ) ) + ic(response) # FIXME: Why do we need this? `OllamaFunctions`, for example, does not output `content`. # if not response.content: # return { @@ -226,6 +234,11 @@ def draft_solve(self, state: AgentState) -> dict: state[constants.AGENT_STATE__KEY_DRAFT] = True return self.solve(state) + def pydantic_to_ai_message( + self, structured_message: BaseModel, id: str = None + ) -> AIMessage: + return AIMessage(content=[structured_message.dict()], id=id) + def format_tool_message(self, response: str, ai_message: AIMessage) -> ToolMessage: """ Format the response as a tool message specifying the tool call ID. @@ -258,7 +271,8 @@ def evaluate(self, state: AgentState) -> dict: test_cases = state[constants.AGENT_STATE__KEY_TEST_CASES] # Extract the `AIMessage` that is expected to contain the code from the last tool call. ai_message: AIMessage = state[constants.AGENT_STATE__KEY_MESSAGES][-1] - if not ai_message.tool_calls: + json_dict = ai_message.content[0] + if not json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE]: # If there was no tool call, add a `HumanMessage` to prompt the agent to generate code. return { constants.AGENT_STATE__KEY_MESSAGES: [ @@ -269,9 +283,7 @@ def evaluate(self, state: AgentState) -> dict: } try: # Extract the code from the tool call. - code: str = ai_message.tool_calls[0][constants.AGENT_TOOL_CALL__ARGS][ - constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE - ] + code: str = json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE] # Use PythonREPL to sanitise the code # See: https://api.python.langchain.com/en/latest/utilities/langchain_experimental.utilities.python.PythonREPL.html code = CodeExecutor.sanitise_code(code) diff --git a/src/constants.py b/src/constants.py index 05413d2..be997be 100644 --- a/src/constants.py +++ b/src/constants.py @@ -121,12 +121,13 @@ """ ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT = """ -You are a world-class Python programmer. You write concise and well-documented Python-only code following the PEP8 style guide. +You are a world-class Python programmer. You write concise and well-documented code following the PEP8 style guide. Please respond with a Python 3 solution to the problem below. -First, output a reasoning through the problem and conceptualise a solution. -Then, output a pseudocode to implement your concept solution. If relevant, add a time and space complexity analysis for your pseudocode. -Finally, output the working Python code for your solution. Do not use external libraries. +First, output a reasoning through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution. +Then, output a pseudocode in Pascal to implement your concept solution. +Then, output the working Python 3 code for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors). +Finally, output a one sentence summary describing what your solution does, as if you are explaining your solution to the human user. Optional examples of similar problems and solutions (may not be in Python): {examples} diff --git a/src/webapp.py b/src/webapp.py index d54fdff..02f0d4d 100644 --- a/src/webapp.py +++ b/src/webapp.py @@ -32,6 +32,7 @@ from langchain_core.messages import HumanMessage, AIMessage from langgraph.graph.message import AnyMessage + # from langchain_groq.chat_models import ChatGroq # from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ( @@ -239,17 +240,16 @@ def find_solution( # stream_mode="values", ) for result in result_iterator: - ic(result) if "solve" in result: - ai_message: AIMessage = result["solve"][ + coder_output: AIMessage = result["solve"][ constants.AGENT_STATE__KEY_MESSAGES ][-1] - if ai_message.tool_calls: - # raise ValueError("Coding agent did not produce a valid code block") + if coder_output: + json_dict = coder_output.content[0] yield [ - ai_message.tool_calls[0]["args"]["reasoning"], - ai_message.tool_calls[0]["args"]["pseudocode"], - ai_message.tool_calls[0]["args"]["code"], + json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__REASONING], + json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__PSEUDOCODE], + json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE], ] def add_test_case( @@ -367,13 +367,9 @@ def construct_interface(self): show_label=True, line_breaks=True, ) - output_pseudocode = gr.Markdown( - label="Pseudocode", - show_label=True, - line_breaks=True, - ) + output_pseudocode = gr.Code(label="Pseudocode", show_label=True) output_code = gr.Code( - label="Code", + label="Python code", show_label=True, language="python", )