diff --git a/src/coder_agent.py b/src/coder_agent.py index 7c5df94..823e1b7 100644 --- a/src/coder_agent.py +++ b/src/coder_agent.py @@ -23,8 +23,10 @@ from llama_index.core.tools import ToolSelection, ToolOutput from llama_index.core.workflow import Event +from pydantic import BaseModel from typing_extensions import TypedDict +from llama_index.core.program import FunctionCallingProgram from code_executor import CodeExecutor @@ -41,6 +43,12 @@ class TestCase(TypedDict): outputs: str +class CoderAgentOutput(BaseModel): + reasoning: str + pseudocode: str + code: str + + class InputEvent(Event): input: list[ChatMessage] test_cases: list[TestCase] = None @@ -68,6 +76,15 @@ def __init__( self.tools = [FunctionTool.from_defaults(self.evaluate)] + self.pydantic_code_generator = FunctionCallingProgram.from_defaults( + output_cls=CoderAgentOutput, + llm=self.llm, + verbose=True, + prompt_template_str="{input}", + tool_choice=self.tools[0], + allow_parallel_tool_calls=True, + ) + self.memory = ChatMemoryBuffer.from_defaults(llm=llm) self.sources = [] @@ -90,6 +107,9 @@ async def handle_llm_input(self, ev: InputEvent) -> ToolCallEvent | StopEvent: ic(ev) chat_history = ev.input + ic(chat_history[-1].content) + pydantic_call = self.pydantic_code_generator(input=chat_history[-1].content) + ic(pydantic_call) response = await self.llm.achat_with_tools( self.tools, chat_history=chat_history )