diff --git a/src/cwhy/__main__.py b/src/cwhy/__main__.py index 570fec3..bb63a0f 100755 --- a/src/cwhy/__main__.py +++ b/src/cwhy/__main__.py @@ -113,13 +113,11 @@ def main() -> None: "subcommand", nargs="?", default="explain", - choices=["explain", "diff", "converse", "diff-converse"], + choices=["explain", "diff-converse"], metavar="subcommand", help=textwrap.dedent( r""" explain: explain the diagnostic (default) - diff: \[experimental] generate a diff to fix the diagnostic - converse: \[experimental] interactively converse with CWhy diff-converse: \[experimental] interactively fix errors with CWhy """ ).strip(), @@ -129,7 +127,7 @@ def main() -> None: "--llm", type=str, default="openai/gpt-4o-mini", - help="the language model to use" + help="the language model to use", ) parser.add_argument( "--timeout", diff --git a/src/cwhy/conversation/__init__.py b/src/cwhy/conversation/__init__.py index 66de123..2ee23c4 100644 --- a/src/cwhy/conversation/__init__.py +++ b/src/cwhy/conversation/__init__.py @@ -1,64 +1,13 @@ import json import textwrap -import llm_utils -import openai +import openai # type: ignore from . import utils from .diff_functions import DiffFunctions -from .explain_functions import ExplainFunctions from ..print_debug import dprint -def converse(client: openai.OpenAI, args, diagnostic): - fns = ExplainFunctions(args) - available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()] - system_message = textwrap.dedent( - f""" - You are an assistant debugger. The user is having an issue with their code, and you are trying to help them. - A few functions exist to help with this process, namely: {", ".join(available_functions_names)}. - Don't hesitate to call as many functions as needed to give the best possible answer. - Once you have identified the problem, explain the diagnostic and provide a way to fix the issue if you can. - """ - ).strip() - user_message = f"Here is my error message:\n\n```\n{utils.get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?" - conversation = [ - {"role": "system", "content": system_message}, - {"role": "user", "content": user_message}, - ] - - while True: - completion = client.chat.completions.create( - model=args.llm, - messages=conversation, - tools=fns.as_tools(), - timeout=args.timeout, - ) - - choice = completion.choices[0] - if choice.finish_reason == "tool_calls": - responses = [] - for tool_call in choice.message.tool_calls: - function_response = fns.dispatch(tool_call.function) - if function_response: - responses.append( - { - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_call.function.name, - "content": function_response, - } - ) - conversation.append(choice.message) - conversation.extend(responses) - dprint() - elif choice.finish_reason == "stop": - text = completion.choices[0].message.content - return llm_utils.word_wrap_except_code_blocks(text) - else: - dprint(f"Not found: {choice.finish_reason}.") - - def diff_converse(client: openai.OpenAI, args, diagnostic): fns = DiffFunctions(args) tools = fns.as_tools() @@ -76,41 +25,43 @@ def diff_converse(client: openai.OpenAI, args, diagnostic): {"role": "user", "content": user_message}, ] - pick_action_schema = { - "name": "pick_action", - "description": "Picks an action to take to get more information about or fix the code.", - "parameters": { - "type": "object", - "properties": { - "action": { - "type": "string", - "enum": tool_names, - }, - }, - "required": ["action"], - }, - } - while True: # 1. Pick an action. - completion = client.chat.completions.create( + completion = client.chat.completions.create( # type: ignore model=args.llm, messages=conversation, - tools=[{"type": "function", "function": pick_action_schema}], - tool_choice={ - "type": "function", - "function": {"name": "pick_action"}, - }, + tools=[ + { + "type": "function", + "function": { + "name": "pick_action", + "description": "Picks an action to get more information about the code or fix it.", + "parameters": { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": tool_names, + }, + }, + "required": ["action"], + }, + }, + } + ], + tool_choice={"type": "function", "function": {"name": "pick_action"}}, timeout=args.timeout, ) - print(completion) - fn = completion.choices[0].message.tool_calls[0].function + assert completion.choices and len(completion.choices) == 1 + choice = completion.choices[0] + assert choice.message.tool_calls and len(choice.message.tool_calls) == 1 + fn = choice.message.tool_calls[0].function arguments = json.loads(fn.arguments) action = arguments["action"] tool = [t for t in tools if t["function"]["name"] == action][0] - completion = client.chat.completions.create( + completion = client.chat.completions.create( # type: ignore model=args.llm, messages=conversation, tools=[tool], @@ -121,7 +72,9 @@ def diff_converse(client: openai.OpenAI, args, diagnostic): timeout=args.timeout, ) + assert completion.choices and len(completion.choices) == 1 choice = completion.choices[0] + assert choice.message.tool_calls and len(choice.message.tool_calls) == 1 tool_call = choice.message.tool_calls[0] function_response = fns.dispatch(tool_call.function) if function_response: @@ -130,7 +83,6 @@ def diff_converse(client: openai.OpenAI, args, diagnostic): { "tool_call_id": tool_call.id, "role": "tool", - "name": tool_call.function.name, "content": function_response, } ) diff --git a/src/cwhy/cwhy.py b/src/cwhy/cwhy.py index 1c10b1b..a13ffcf 100755 --- a/src/cwhy/cwhy.py +++ b/src/cwhy/cwhy.py @@ -92,8 +92,6 @@ def evaluate(client: openai.OpenAI, args, stdin): tool_calls = completion.choices[0].message.tool_calls assert len(tool_calls) == 1 return tool_calls[0].function.arguments - elif args.subcommand == "converse": - return conversation.converse(client, args, stdin) elif args.subcommand == "diff-converse": return conversation.diff_converse(client, args, stdin) else: diff --git a/src/cwhy/print_debug.py b/src/cwhy/print_debug.py index 7056437..3443c50 100644 --- a/src/cwhy/print_debug.py +++ b/src/cwhy/print_debug.py @@ -6,7 +6,7 @@ def dprint( - *objects, sep: str = " ", end: str = "\n", file=None, flush: bool = False + *objects: object, sep: str = " ", end: str = "\n", file=None, flush: bool = False ) -> None: global _debug if not _debug: