Skip to content

Commit

Permalink
Remove LiteLLM usage, fall back to OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank committed Oct 2, 2024
1 parent d1e48f2 commit 25a9490
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 46 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies = [
"openai>=1.30.1",
"PyYAML>=6.0.1",
"rich>=13.7.1",
"litellm>=1.37.19",
"llm-utils>=0.2.8",
]
description = "Explains and proposes fixes for compile-time errors for many programming languages."
Expand Down
17 changes: 7 additions & 10 deletions src/cwhy/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
import json
import textwrap
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import litellm # type: ignore

import llm_utils
import openai

from . import utils
from .diff_functions import DiffFunctions
from .explain_functions import ExplainFunctions
from ..print_debug import dprint


def converse(args, diagnostic):
def converse(client: openai.OpenAI, args, diagnostic):
fns = ExplainFunctions(args)
available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()]
system_message = textwrap.dedent(
Expand All @@ -32,7 +28,7 @@ def converse(args, diagnostic):
]

while True:
completion = litellm.completion(
completion = client.chat.completions.create(
model=args.llm,
messages=conversation,
tools=fns.as_tools(),
Expand Down Expand Up @@ -63,7 +59,7 @@ def converse(args, diagnostic):
dprint(f"Not found: {choice.finish_reason}.")


def diff_converse(args, diagnostic):
def diff_converse(client: openai.OpenAI, args, diagnostic):
fns = DiffFunctions(args)
tools = fns.as_tools()
tool_names = [fn["function"]["name"] for fn in tools]
Expand Down Expand Up @@ -97,7 +93,7 @@ def diff_converse(args, diagnostic):

while True:
# 1. Pick an action.
completion = litellm.completion(
completion = client.chat.completions.create(
model=args.llm,
messages=conversation,
tools=[{"type": "function", "function": pick_action_schema}],
Expand All @@ -108,12 +104,13 @@ def diff_converse(args, diagnostic):
timeout=args.timeout,
)

print(completion)
fn = completion.choices[0].message.tool_calls[0].function
arguments = json.loads(fn.arguments)
action = arguments["action"]

tool = [t for t in tools if t["function"]["name"] == action][0]
completion = litellm.completion(
completion = client.chat.completions.create(
model=args.llm,
messages=conversation,
tools=[tool],
Expand Down
63 changes: 28 additions & 35 deletions src/cwhy/cwhy.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,45 @@
import argparse
import os
import subprocess
import sys
from typing import Any
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import litellm # type: ignore

litellm.suppress_debug_info = True

import llm_utils
from openai import (
NotFoundError,
RateLimitError,
APITimeoutError,
OpenAIError,
BadRequestError,
)
import openai

from . import conversation, prompts
from .print_debug import dprint, enable_debug_printing


def complete(args: argparse.Namespace, user_prompt: str, **kwargs: Any):
def complete(
client: openai.OpenAI, args: argparse.Namespace, user_prompt: str, **kwargs: Any
):
try:
completion = litellm.completion(
completion = client.chat.completions.create(
model=args.llm,
messages=[{"role": "user", "content": user_prompt}],
timeout=args.timeout,
**kwargs,
)
return completion
except NotFoundError as e:
except openai.NotFoundError as e:
dprint(f"'{args.llm}' either does not exist or you do not have access to it.")
raise e
except BadRequestError as e:
except openai.BadRequestError as e:
dprint("Something is wrong with your prompt.")
raise e
except RateLimitError as e:
except openai.RateLimitError as e:
dprint("You have exceeded a rate limit or have no remaining funds.")
raise e
except APITimeoutError as e:
except openai.APITimeoutError as e:
dprint("The API timed out.")
dprint("You can increase the timeout with the --timeout option.")
raise e


def evaluate_diff(args, stdin):
def evaluate_diff(client: openai.OpenAI, args, stdin):
prompt = prompts.diff_prompt(args, stdin)
completion = complete(
client,
args,
prompt,
tools=[
Expand Down Expand Up @@ -95,20 +84,18 @@ def evaluate_diff(args, stdin):
return completion


def evaluate(args, stdin):
def evaluate(client: openai.OpenAI, args, stdin):
if args.subcommand == "explain":
return evaluate_text_prompt(args, prompts.explain_prompt(args, stdin))
return evaluate_text_prompt(client, args, prompts.explain_prompt(args, stdin))
elif args.subcommand == "diff":
completion = evaluate_diff(args, stdin)
completion = evaluate_diff(client, args, stdin)
tool_calls = completion.choices[0].message.tool_calls
assert len(tool_calls) == 1
return tool_calls[0].function.arguments
elif args.subcommand == "converse":
assert litellm.supports_function_calling(model=args.llm)
return conversation.converse(args, stdin)
return conversation.converse(client, args, stdin)
elif args.subcommand == "diff-converse":
assert litellm.supports_function_calling(model=args.llm)
return conversation.diff_converse(args, stdin)
return conversation.diff_converse(client, args, stdin)
else:
raise Exception(f"unknown subcommand: {args.subcommand}")

Expand Down Expand Up @@ -142,19 +129,26 @@ def main(args: argparse.Namespace) -> None:
dprint("CWhy")
dprint("==================================================")
try:
result = evaluate(args, process.stderr if process.stderr else process.stdout)
client = openai.OpenAI()
result = evaluate(
client, args, process.stderr if process.stderr else process.stdout
)
dprint(result)
except OpenAIError as e:
except openai.OpenAIError as e:
dprint(str(e).strip())
dprint("==================================================")

sys.exit(process.returncode)


def evaluate_text_prompt(
args: argparse.Namespace, prompt: str, wrap: bool = True, **kwargs: Any
client: openai.OpenAI,
args: argparse.Namespace,
prompt: str,
wrap: bool = True,
**kwargs: Any,
) -> str:
completion = complete(args, prompt, **kwargs)
completion = complete(client, args, prompt, **kwargs)

msg = f"Analysis from {args.llm}:"
dprint(msg)
Expand All @@ -164,8 +158,7 @@ def evaluate_text_prompt(
if wrap:
text = llm_utils.word_wrap_except_code_blocks(text)

cost = litellm.completion_cost(completion_response=completion)
text += "\n\n"
text += f"(Total cost: approximately ${cost:.2f} USD.)"
text += f"(TODO seconds, $TODO USD.)"

return text

0 comments on commit 25a9490

Please sign in to comment.