Skip to content

Commit

Permalink
Consolidate explain and diff-converse functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank committed Oct 3, 2024
1 parent 25a9490 commit 23acaea
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 84 deletions.
6 changes: 2 additions & 4 deletions src/cwhy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,11 @@ def main() -> None:
"subcommand",
nargs="?",
default="explain",
choices=["explain", "diff", "converse", "diff-converse"],
choices=["explain", "diff-converse"],
metavar="subcommand",
help=textwrap.dedent(
r"""
explain: explain the diagnostic (default)
diff: \[experimental] generate a diff to fix the diagnostic
converse: \[experimental] interactively converse with CWhy
diff-converse: \[experimental] interactively fix errors with CWhy
"""
).strip(),
Expand All @@ -129,7 +127,7 @@ def main() -> None:
"--llm",
type=str,
default="openai/gpt-4o-mini",
help="the language model to use"
help="the language model to use",
)
parser.add_argument(
"--timeout",
Expand Down
106 changes: 29 additions & 77 deletions src/cwhy/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,13 @@
import json
import textwrap

import llm_utils
import openai
import openai # type: ignore

from . import utils
from .diff_functions import DiffFunctions
from .explain_functions import ExplainFunctions
from ..print_debug import dprint


def converse(client: openai.OpenAI, args, diagnostic):
fns = ExplainFunctions(args)
available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()]
system_message = textwrap.dedent(
f"""
You are an assistant debugger. The user is having an issue with their code, and you are trying to help them.
A few functions exist to help with this process, namely: {", ".join(available_functions_names)}.
Don't hesitate to call as many functions as needed to give the best possible answer.
Once you have identified the problem, explain the diagnostic and provide a way to fix the issue if you can.
"""
).strip()
user_message = f"Here is my error message:\n\n```\n{utils.get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?"
conversation = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]

while True:
completion = client.chat.completions.create(
model=args.llm,
messages=conversation,
tools=fns.as_tools(),
timeout=args.timeout,
)

choice = completion.choices[0]
if choice.finish_reason == "tool_calls":
responses = []
for tool_call in choice.message.tool_calls:
function_response = fns.dispatch(tool_call.function)
if function_response:
responses.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": function_response,
}
)
conversation.append(choice.message)
conversation.extend(responses)
dprint()
elif choice.finish_reason == "stop":
text = completion.choices[0].message.content
return llm_utils.word_wrap_except_code_blocks(text)
else:
dprint(f"Not found: {choice.finish_reason}.")


def diff_converse(client: openai.OpenAI, args, diagnostic):
fns = DiffFunctions(args)
tools = fns.as_tools()
Expand All @@ -76,41 +25,43 @@ def diff_converse(client: openai.OpenAI, args, diagnostic):
{"role": "user", "content": user_message},
]

pick_action_schema = {
"name": "pick_action",
"description": "Picks an action to take to get more information about or fix the code.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": tool_names,
},
},
"required": ["action"],
},
}

while True:
# 1. Pick an action.
completion = client.chat.completions.create(
completion = client.chat.completions.create( # type: ignore
model=args.llm,
messages=conversation,
tools=[{"type": "function", "function": pick_action_schema}],
tool_choice={
"type": "function",
"function": {"name": "pick_action"},
},
tools=[
{
"type": "function",
"function": {
"name": "pick_action",
"description": "Picks an action to get more information about the code or fix it.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": tool_names,
},
},
"required": ["action"],
},
},
}
],
tool_choice={"type": "function", "function": {"name": "pick_action"}},
timeout=args.timeout,
)

print(completion)
fn = completion.choices[0].message.tool_calls[0].function
assert completion.choices and len(completion.choices) == 1
choice = completion.choices[0]
assert choice.message.tool_calls and len(choice.message.tool_calls) == 1
fn = choice.message.tool_calls[0].function
arguments = json.loads(fn.arguments)
action = arguments["action"]

tool = [t for t in tools if t["function"]["name"] == action][0]
completion = client.chat.completions.create(
completion = client.chat.completions.create( # type: ignore
model=args.llm,
messages=conversation,
tools=[tool],
Expand All @@ -121,7 +72,9 @@ def diff_converse(client: openai.OpenAI, args, diagnostic):
timeout=args.timeout,
)

assert completion.choices and len(completion.choices) == 1
choice = completion.choices[0]
assert choice.message.tool_calls and len(choice.message.tool_calls) == 1
tool_call = choice.message.tool_calls[0]
function_response = fns.dispatch(tool_call.function)
if function_response:
Expand All @@ -130,7 +83,6 @@ def diff_converse(client: openai.OpenAI, args, diagnostic):
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": function_response,
}
)
Expand Down
2 changes: 0 additions & 2 deletions src/cwhy/cwhy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def evaluate(client: openai.OpenAI, args, stdin):
tool_calls = completion.choices[0].message.tool_calls
assert len(tool_calls) == 1
return tool_calls[0].function.arguments
elif args.subcommand == "converse":
return conversation.converse(client, args, stdin)
elif args.subcommand == "diff-converse":
return conversation.diff_converse(client, args, stdin)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/cwhy/print_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def dprint(
*objects, sep: str = " ", end: str = "\n", file=None, flush: bool = False
*objects: object, sep: str = " ", end: str = "\n", file=None, flush: bool = False
) -> None:
global _debug
if not _debug:
Expand Down

0 comments on commit 23acaea

Please sign in to comment.