
Commit

refactored LLM communication to facilitate testing and upcoming changes;
jaltmayerpizzorno committed Jul 31, 2024
1 parent 9959da7 commit 81ed704
Showing 2 changed files with 121 additions and 95 deletions.
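
In this commit, the do_chat() helper and the module-level token_rate_limit and retry state move out of coverup.py into a new llm.Chatter class, which receives its logging and retry-counting hooks as callbacks. A minimal sketch of the resulting call sequence, mirroring the calls visible in the diff below (args, log_write, state, seg, and messages are names taken from the surrounding coverup.py code, shown here only for illustration):

    from coverup import llm

    chatter = llm.Chatter(model=args.model, model_temperature=args.model_temperature,
                          log_write=log_write,                          # per-segment logging callback
                          signal_retry=lambda: state.inc_counter('R'))  # counts retries in the progress display
    chatter.set_max_backoff(args.max_backoff)                # cap, in seconds, for exponential backoff
    if args.rate_limit:
        chatter.set_token_rate_limit((args.rate_limit, 60))  # tokens allowed per 60-second window

    # later, once per code segment:
    response = await chatter.chat(seg, messages)             # None means the segment is given up on
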
108 changes: 14 additions & 94 deletions src/coverup/coverup.py
@@ -1,14 +1,6 @@
import asyncio
import json
import warnings

with warnings.catch_warnings():
# ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
warnings.simplefilter('ignore')
import litellm # type: ignore

import logging
import openai
import subprocess
import re
import sys
@@ -17,19 +9,14 @@
from pathlib import Path
from datetime import datetime

from .llm import *
from . import llm
from .segment import *
from .testrunner import *
from .version import __version__
from .utils import summary_coverage
from . import prompt


# Turn off most logging
litellm.set_verbose = False
logging.getLogger().setLevel(logging.ERROR)
# Ignore unavailable parameters
litellm.drop_params=True

def parse_args(args=None):
import argparse
@@ -344,7 +331,7 @@ def __init__(self, total, initial):

def update_usage(self, usage: dict):
"""Updates the usage display."""
if (cost := compute_cost(usage, args.model)) is not None:
if (cost := llm.compute_cost(usage, args.model)) is not None:
self.postfix['usage'] = f'~${cost:.02f}'
else:
self.postfix['usage'] = f'{usage["prompt_tokens"]}+{usage["completion_tokens"]}'
@@ -465,74 +452,18 @@ def save_checkpoint(self, ckpt_file: Path):
json.dump(ckpt, f)


state = None
token_rate_limit = None
async def do_chat(seg: CodeSegment, completion: dict) -> str:
"""Sends a GPT chat request, handling common failures and returning the response."""

global token_rate_limit

sleep = 1
while True:
try:
if token_rate_limit:
try:
await token_rate_limit.acquire(count_tokens(args.model, completion))
except ValueError as e:
log_write(seg, f"Error: too many tokens for rate limit ({e})")
return None # gives up this segment

return await litellm.acreate(**completion)

except (litellm.exceptions.ServiceUnavailableError,
openai.RateLimitError,
openai.APITimeoutError) as e:

# This message usually indicates out of money in account
if 'You exceeded your current quota' in str(e):
log_write(seg, f"Failed: {type(e)} {e}")
raise

log_write(seg, f"Error: {type(e)} {e}")

import random
sleep = min(sleep*2, args.max_backoff)
sleep_time = random.uniform(sleep/2, sleep)
state.inc_counter('R')
await asyncio.sleep(sleep_time)

except openai.BadRequestError as e:
# usually "maximum context length" XXX check for this?
log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment

except openai.AuthenticationError as e:
log_write(seg, f"Failed: {type(e)} {e}")
raise

except openai.APIConnectionError as e:
log_write(seg, f"Error: {type(e)} {e}")
# usually a server-side error... just retry right away
state.inc_counter('R')

except openai.APIError as e:
# APIError is the base class for all API errors;
# we may be missing a more specific handler.
print(f"Error: {type(e)} {e}; missing handler?")
log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment


def extract_python(response: str) -> str:
# This regex accepts a truncated code block... this seems fine since we'll try it anyway
m = re.search(r'```python\n(.*?)(?:```|\Z)', response, re.DOTALL)
if not m: raise RuntimeError(f"Unable to extract Python code from response {response}")
return m.group(1)


async def improve_coverage(seg: CodeSegment) -> bool:
state = None

async def improve_coverage(chatter: llm.Chatter, seg: CodeSegment) -> bool:
"""Works to improve coverage for a code segment."""
global args, progress
global args

def log_prompts(prompts: T.List[dict]):
for p in prompts:
@@ -553,14 +484,7 @@ def log_prompts(prompts: T.List[dict]):
log_write(seg, "Too many attempts, giving up")
break

completion = {
'model': args.model,
'temperature': args.model_temperature,
'messages': messages,
**({'api_base': "http://localhost:11434"} if "ollama" in args.model else {})
}

if not (response := await do_chat(seg, completion)):
if not (response := await chatter.chat(seg, messages)):
log_write(seg, "giving up")
break

@@ -658,7 +582,7 @@ def main():
from collections import defaultdict
import os

global args, token_rate_limit, state
global args, state
args = parse_args()

if not args.tests_dir.exists():
@@ -670,12 +594,12 @@
add_to_pythonpath(args.module_dir)

if args.prompt_for_tests:
if args.rate_limit or token_rate_limit_for_model(args.model):
limit = (args.rate_limit, 60) if args.rate_limit else token_rate_limit_for_model(args.model)
from aiolimiter import AsyncLimiter
token_rate_limit = AsyncLimiter(*limit)
# TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both
chatter = llm.Chatter(model=args.model, model_temperature=args.model_temperature,
log_write=log_write, signal_retry=lambda: state.inc_counter('R'))
chatter.set_max_backoff(args.max_backoff)

if args.rate_limit:
chatter.set_token_rate_limit((args.rate_limit, 60))

# Check for an API key for OpenAI or Amazon Bedrock.
if 'OPENAI_API_KEY' not in os.environ and 'ANTHROPIC_API_KEY' not in os.environ:
@@ -706,10 +630,6 @@ def main():
print()
return 1

if not args.model:
print("Please specify model to use with --model")
return 1

log_write('startup', f"Command: {' '.join(sys.argv)}")

# --- (1) load or measure initial coverage, figure out segmentation ---
@@ -751,7 +671,7 @@ def main():
print("(in the following, G=good, F=failed, U=useless and R=retry)")

async def work_segment(seg: CodeSegment) -> None:
if await improve_coverage(seg):
if await improve_coverage(chatter, seg):
# Only mark done if was able to complete (True return),
# so that it can be retried after installing any missing modules
state.mark_done(seg)
108 changes: 107 additions & 1 deletion src/coverup/llm.py
@@ -1,5 +1,22 @@
import typing as T
import litellm
import openai
import logging
import asyncio
import warnings
from .segment import CodeSegment

with warnings.catch_warnings():
# ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
warnings.simplefilter('ignore')
import litellm # type: ignore


# Turn off most logging
litellm.set_verbose = False
logging.getLogger().setLevel(logging.ERROR)

# Ignore unavailable parameters
litellm.drop_params=True


# Tier 5 rate limits for models; tuples indicate limit and interval in seconds
@@ -85,3 +102,92 @@ def count_tokens(model_name: str, completion: dict):
count += len(encoding.encode(m['content']))

return count


class Chatter:
def __init__(self, model, model_temperature, log_write, signal_retry):
self.model = model
self.model_temperature = model_temperature
self.max_backoff = 64 # seconds
self.set_token_rate_limit(token_rate_limit_for_model(model))

self.log_write = log_write
self.signal_retry = signal_retry


def set_token_rate_limit(self, limit):
if limit:
from aiolimiter import AsyncLimiter
self.token_rate_limit = AsyncLimiter(*limit)
else:
self.token_rate_limit = None


def set_max_backoff(self, max_backoff):
self.max_backoff = max_backoff


def _completion(self, messages: list) -> dict:
return {
'model': self.model,
'temperature': self.model_temperature,
'messages': messages,
**({'api_base': "http://localhost:11434"} if "ollama" in self.model else {})
}


async def chat(self, seg: CodeSegment, messages: list) -> dict:
"""Sends a GPT chat request, handling common failures and returning the response."""

sleep = 1
while True:
try:
completion = self._completion(messages)
# TODO also add request limit; could use 'await asyncio.gather(t.acquire(tokens), r.acquire())'
# to acquire both
if self.token_rate_limit:
try:
await self.token_rate_limit.acquire(count_tokens(self.model, completion))
except ValueError as e:
self.log_write(seg, f"Error: too many tokens for rate limit ({e})")
return None # gives up this segment

return await litellm.acreate(**completion)

except (litellm.exceptions.ServiceUnavailableError,
openai.RateLimitError,
openai.APITimeoutError) as e:

# This message usually indicates out of money in account
if 'You exceeded your current quota' in str(e):
self.log_write(seg, f"Failed: {type(e)} {e}")
raise

self.log_write(seg, f"Error: {type(e)} {e}")

import random
sleep = min(sleep*2, self.max_backoff)
sleep_time = random.uniform(sleep/2, sleep)
self.signal_retry()
await asyncio.sleep(sleep_time)

except openai.BadRequestError as e:
# usually "maximum context length" XXX check for this?
self.log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment

except openai.AuthenticationError as e:
self.log_write(seg, f"Failed: {type(e)} {e}")
raise

except openai.APIConnectionError as e:
self.log_write(seg, f"Error: {type(e)} {e}")
# usually a server-side error... just retry right away
self.signal_retry()

except openai.APIError as e:
# APIError is the base class for all API errors;
# we may be missing a more specific handler.
print(f"Error: {type(e)} {e}; missing handler?")
self.log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment
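
Since the commit message cites facilitating testing, the injected log_write and signal_retry callbacks make Chatter straightforward to exercise in isolation. A minimal sketch of such a test, not part of this commit: DummySegment, the fake response dict, and the patching of litellm.acreate (the call used by chat() above) are assumptions for illustration only.

    import asyncio
    from unittest.mock import AsyncMock, patch

    from coverup import llm


    class DummySegment:
        """Stand-in for CodeSegment; chat() only forwards it to log_write."""


    def test_chat_returns_model_response():
        logged = []
        chatter = llm.Chatter(model="gpt-4", model_temperature=0,
                              log_write=lambda seg, msg: logged.append(msg),
                              signal_retry=lambda: None)
        chatter.set_token_rate_limit(None)   # keep the test independent of rate limiting

        fake_response = {"choices": [{"message": {"content": "```python\npass\n```"}}]}
        with patch("coverup.llm.litellm.acreate", AsyncMock(return_value=fake_response)):
            response = asyncio.run(chatter.chat(DummySegment(), [{"role": "user", "content": "hi"}]))

        assert response == fake_response
        assert not logged                    # no errors were logged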
