Skip to content

Commit

Permalink
temporary: Separate coder works.
Browse files Browse the repository at this point in the history
  • Loading branch information
anirbanbasu committed Aug 15, 2024
1 parent 71dedd0 commit d33b82a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 42 deletions.
62 changes: 37 additions & 25 deletions src/coder_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain.schema.output_parser import StrOutputParser

from langchain_core.tools import BaseTool

from code_executor import CodeExecutor
from langgraph.checkpoint.sqlite import SqliteSaver
Expand Down Expand Up @@ -64,44 +62,48 @@ class CoderInput(BaseModel):
examples: str = Field(
...,
description="Examples of similar challenges and their solutions.",
default=constants.EMPTY_STRING,
)


class CoderOutput(BaseModel):
    """Structured output that the coding LLM must produce for a challenge.

    Downstream code reads these fields by name from the serialised model
    (see the `evaluate` step, which extracts the `code` field), so the
    field names double as a stable contract.
    """

    # Step-by-step reasoning that leads to the conceptual solution.
    reasoning: str = Field(..., description="Reasoning for the conceptual solution.")
    # Pseudocode derived from the reasoning.
    pseudocode: str = Field(..., description="Pseudocode for the solution.")
    # The Python implementation to be sanitised and executed.
    code: str = Field(..., description="Python code implementation for the solution.")
    # One-sentence, human-readable summary of the solution.
    summary: str = Field(
        ..., description="A short one sentence summary of the solution."
    )


class CoderTool(BaseTool):
name = "coder_tool"
description = "Generate Python code to solve the given problem."
args_schema = CoderOutput

class Coder:
def __init__(self, llm: BaseChatModel):
self._llm = llm.with_structured_output(CoderOutput)
self._llm = llm

def _run(self, challenge: str, examples: str) -> CoderOutput:
def solve(
self, challenge: str, examples: str = constants.EMPTY_STRING
) -> CoderOutput:
"""Run the tool"""
messages = [
("system", constants.ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT),
("human", "{input}"),
]
chain = (
ChatPromptTemplate.from_messages(messages=messages)
| self._llm
| StrOutputParser()
)
return chain.invoke({"input": challenge, "examples": examples})
chain = ChatPromptTemplate.from_messages(
messages=messages
) | self._llm.with_structured_output(CoderOutput)
raw_result = chain.invoke({"input": challenge, "examples": examples})
ic(raw_result)
return raw_result


class MultiAgentOrchestrator:
def __init__(self, llm: BaseChatModel, prompt: ChatPromptTemplate):
self._llm = llm
self._prompt = prompt
self._runnable_solver = self._prompt | self._llm.bind_tools([CoderOutput])
self._runnable_draft_solver = self._prompt | self._llm.bind_tools([CoderOutput])
self._runnable_solver = self._prompt | self._llm.with_structured_output(
CoderOutput
)
self._runnable_draft_solver = self._prompt | self._llm.with_structured_output(
CoderOutput
)
self._evaluator = CodeExecutor()
# self._retriever = BM25Retriever.from_texts(
# [self.format_example(row) for row in train_ds]
Expand Down Expand Up @@ -200,12 +202,18 @@ def solve(self, state: AgentState) -> dict:
inputs[constants.AGENT_STATE__KEY_EXAMPLES] = state[
constants.AGENT_STATE__KEY_EXAMPLES
]
coder = Coder(self._llm)
ic(inputs)
response = (
# Use the draft solver only if the `draft` flag is set in the state
self._runnable_draft_solver.invoke(inputs)
if state[constants.AGENT_STATE__KEY_DRAFT] is True
else self._runnable_solver.invoke(inputs)
# self._runnable_draft_solver.invoke(inputs)
# if state[constants.AGENT_STATE__KEY_DRAFT] is True
# else self._runnable_solver.invoke(inputs)
self.pydantic_to_ai_message(
coder.solve(inputs[constants.AGENT_STATE__KEY_MESSAGES][-1].content)
)
)
ic(response)
# FIXME: Why do we need this? `OllamaFunctions`, for example, does not output `content`.
# if not response.content:
# return {
Expand All @@ -226,6 +234,11 @@ def draft_solve(self, state: AgentState) -> dict:
state[constants.AGENT_STATE__KEY_DRAFT] = True
return self.solve(state)

def pydantic_to_ai_message(
    self, structured_message: BaseModel, id: str = None
) -> AIMessage:
    """Wrap a Pydantic model instance in an `AIMessage`.

    The model is serialised with `.dict()` and stored as the single
    element of the message's `content` list, so consumers can read the
    structured fields from `content[0]` (as `evaluate` does).

    :param structured_message: The Pydantic model to convert.
    :param id: Optional message identifier forwarded to `AIMessage`.
        NOTE(review): the default of None contradicts the `str`
        annotation — this should be `Optional[str]`; confirm and fix.
        The name also shadows the builtin `id()`.
    :return: An `AIMessage` whose content is `[structured_message.dict()]`.
    """
    # NOTE(review): `.dict()` is the Pydantic v1 API; v2 renamed it to
    # `.model_dump()` — verify against the installed pydantic version.
    return AIMessage(content=[structured_message.dict()], id=id)

def format_tool_message(self, response: str, ai_message: AIMessage) -> ToolMessage:
"""
Format the response as a tool message specifying the tool call ID.
Expand Down Expand Up @@ -258,7 +271,8 @@ def evaluate(self, state: AgentState) -> dict:
test_cases = state[constants.AGENT_STATE__KEY_TEST_CASES]
# Extract the `AIMessage` that is expected to contain the code from the last tool call.
ai_message: AIMessage = state[constants.AGENT_STATE__KEY_MESSAGES][-1]
if not ai_message.tool_calls:
json_dict = ai_message.content[0]
if not json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE]:
# If there was no tool call, add a `HumanMessage` to prompt the agent to generate code.
return {
constants.AGENT_STATE__KEY_MESSAGES: [
Expand All @@ -269,9 +283,7 @@ def evaluate(self, state: AgentState) -> dict:
}
try:
# Extract the code from the tool call.
code: str = ai_message.tool_calls[0][constants.AGENT_TOOL_CALL__ARGS][
constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE
]
code: str = json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE]
# Use PythonREPL to sanitise the code
# See: https://api.python.langchain.com/en/latest/utilities/langchain_experimental.utilities.python.PythonREPL.html
code = CodeExecutor.sanitise_code(code)
Expand Down
9 changes: 5 additions & 4 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,13 @@
"""

ENV_VAR_VALUE__LLM_CODER_SYSTEM_PROMPT = """
You are a world-class Python programmer. You write concise and well-documented Python-only code following the PEP8 style guide.
You are a world-class Python programmer. You write concise and well-documented code following the PEP8 style guide.
Please respond with a Python 3 solution to the problem below.
First, output a reasoning through the problem and conceptualise a solution.
Then, output a pseudocode to implement your concept solution. If relevant, add a time and space complexity analysis for your pseudocode.
Finally, output the working Python code for your solution. Do not use external libraries.
First, output a reasoning through the problem and conceptualise a solution. Whenever possible, add a time and a space complexity analysis for your solution.
Then, output a pseudocode in Pascal to implement your concept solution.
Then, output the working Python 3 code for your solution. Do not use external libraries. Your code must be able to accept inputs from `sys.stdin` and write the final output to `sys.stdout` (or, to `sys.stderr` in case of errors).
Finally, output a one sentence summary describing what your solution does, as if you are explaining your solution to the human user.
Optional examples of similar problems and solutions (may not be in Python):
{examples}
Expand Down
22 changes: 9 additions & 13 deletions src/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph.message import AnyMessage


# from langchain_groq.chat_models import ChatGroq
# from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
Expand Down Expand Up @@ -239,17 +240,16 @@ def find_solution(
# stream_mode="values",
)
for result in result_iterator:
ic(result)
if "solve" in result:
ai_message: AIMessage = result["solve"][
coder_output: AIMessage = result["solve"][
constants.AGENT_STATE__KEY_MESSAGES
][-1]
if ai_message.tool_calls:
# raise ValueError("Coding agent did not produce a valid code block")
if coder_output:
json_dict = coder_output.content[0]
yield [
ai_message.tool_calls[0]["args"]["reasoning"],
ai_message.tool_calls[0]["args"]["pseudocode"],
ai_message.tool_calls[0]["args"]["code"],
json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__REASONING],
json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__PSEUDOCODE],
json_dict[constants.PYDANTIC_MODEL__CODE_OUTPUT__CODE],
]

def add_test_case(
Expand Down Expand Up @@ -367,13 +367,9 @@ def construct_interface(self):
show_label=True,
line_breaks=True,
)
output_pseudocode = gr.Markdown(
label="Pseudocode",
show_label=True,
line_breaks=True,
)
output_pseudocode = gr.Code(label="Pseudocode", show_label=True)
output_code = gr.Code(
label="Code",
label="Python code",
show_label=True,
language="python",
)
Expand Down

0 comments on commit d33b82a

Please sign in to comment.