diff --git a/src/chatdbg/conversation/__init__.py b/src/chatdbg/conversation/__init__.py index 9584e3b..ba313a3 100644 --- a/src/chatdbg/conversation/__init__.py +++ b/src/chatdbg/conversation/__init__.py @@ -1,8 +1,9 @@ +import json import textwrap import llm_utils -from . import functions +from .functions_lldb import LldbFunctions def get_truncated_error_message(args, diagnostic) -> str: @@ -36,7 +37,7 @@ def build_diagnostic_string(): def converse(client, args, diagnostic): - fns = functions.Functions(args) + fns = LldbFunctions(args) available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()] system_message = textwrap.dedent( f""" @@ -65,7 +66,9 @@ def converse(client, args, diagnostic): choice = completion.choices[0] if choice.finish_reason == "tool_calls": for tool_call in choice.message.tool_calls: - function_response = fns.dispatch(tool_call.function) + name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + function_response = fns.dispatch(name, arguments) if function_response: conversation.append(choice.message) conversation.append( diff --git a/src/chatdbg/conversation/functions.py b/src/chatdbg/conversation/functions_interface.py similarity index 68% rename from src/chatdbg/conversation/functions.py rename to src/chatdbg/conversation/functions_interface.py index db2100c..c74e74c 100644 --- a/src/chatdbg/conversation/functions.py +++ b/src/chatdbg/conversation/functions_interface.py @@ -1,11 +1,9 @@ -import json -import os from typing import Optional import llm_utils -class Functions: +class BaseFunctions: def __init__(self, args): self.args = args @@ -15,20 +13,17 @@ def as_tools(self): for schema in [self.get_code_surrounding_schema()] ] - def dispatch(self, function_call) -> Optional[str]: - arguments = json.loads(function_call.arguments) - print( - f"Calling: {function_call.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})" - ) - try: - if function_call.name == "get_code_surrounding": - return self.get_code_surrounding( - arguments["filename"], arguments["lineno"] - ) - else: - raise ValueError("No such function.") - except Exception as e: - print(e) + def dispatch(self, name, arguments) -> Optional[str]: + if name == "get_code_surrounding": + filename = arguments["filename"] + lineno = arguments["lineno"] + result = self.get_code_surrounding(filename, lineno) + + print(f"Calling get_code_surrounding({filename}, {lineno})...") + print(result) + print() + + return result return None def get_code_surrounding_schema(self): diff --git a/src/chatdbg/conversation/functions_lldb.py b/src/chatdbg/conversation/functions_lldb.py new file mode 100644 index 0000000..956732e --- /dev/null +++ b/src/chatdbg/conversation/functions_lldb.py @@ -0,0 +1,19 @@ +import json +from typing import Optional + +import llm_utils + +from .functions_interface import BaseFunctions + + +class LldbFunctions(BaseFunctions): + def __init__(self, args): + super().__init__(args) + + def as_tools(self): + return super().as_tools() + [ + {"type": "function", "function": schema} for schema in [] + ] + + def dispatch(self, name, arguments) -> Optional[str]: + return super().dispatch(name, arguments)