From 81ed704490ee904fb517ce840da9da69463367a2 Mon Sep 17 00:00:00 2001
From: Juan Altmayer Pizzorno
Date: Wed, 31 Jul 2024 16:39:37 -0400
Subject: [PATCH] - refactored LLM communication to facilitate testing and upcoming changes;

---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy.  Hunk line counts were checked against the
hunk headers and the diffstat, but the indentation of context lines
(especially inside main()) is inferred -- verify against the target tree
before applying.

 src/coverup/coverup.py | 108 ++++++-----------------------------------
 src/coverup/llm.py     | 108 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 121 insertions(+), 95 deletions(-)

diff --git a/src/coverup/coverup.py b/src/coverup/coverup.py
index 5d4cfd9..ccf702f 100644
--- a/src/coverup/coverup.py
+++ b/src/coverup/coverup.py
@@ -1,14 +1,6 @@
 import asyncio
 import json
-import warnings
 
-with warnings.catch_warnings():
-    # ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
-    warnings.simplefilter('ignore')
-    import litellm # type: ignore
-
-import logging
-import openai
 import subprocess
 import re
 import sys
@@ -17,7 +9,7 @@
 from pathlib import Path
 from datetime import datetime
 
-from .llm import *
+from . import llm
 from .segment import *
 from .testrunner import *
 from .version import __version__
@@ -25,11 +17,6 @@
 from . import prompt
 
 
-# Turn off most logging
-litellm.set_verbose = False
-logging.getLogger().setLevel(logging.ERROR)
-# Ignore unavailable parameters
-litellm.drop_params=True
 
 def parse_args(args=None):
     import argparse
@@ -344,7 +331,7 @@ def __init__(self, total, initial):
 
     def update_usage(self, usage: dict):
         """Updates the usage display."""
-        if (cost := compute_cost(usage, args.model)) is not None:
+        if (cost := llm.compute_cost(usage, args.model)) is not None:
             self.postfix['usage'] = f'~${cost:.02f}'
         else:
             self.postfix['usage'] = f'{usage["prompt_tokens"]}+{usage["completion_tokens"]}'
@@ -465,64 +452,6 @@ def save_checkpoint(self, ckpt_file: Path):
             json.dump(ckpt, f)
 
 
-state = None
-token_rate_limit = None
-async def do_chat(seg: CodeSegment, completion: dict) -> str:
-    """Sends a GPT chat request, handling common failures and returning the response."""
-
-    global token_rate_limit
-
-    sleep = 1
-    while True:
-        try:
-            if token_rate_limit:
-                try:
-                    await token_rate_limit.acquire(count_tokens(args.model, completion))
-                except ValueError as e:
-                    log_write(seg, f"Error: too many tokens for rate limit ({e})")
-                    return None # gives up this segment
-
-            return await litellm.acreate(**completion)
-
-        except (litellm.exceptions.ServiceUnavailableError,
-                openai.RateLimitError,
-                openai.APITimeoutError) as e:
-
-            # This message usually indicates out of money in account
-            if 'You exceeded your current quota' in str(e):
-                log_write(seg, f"Failed: {type(e)} {e}")
-                raise
-
-            log_write(seg, f"Error: {type(e)} {e}")
-
-            import random
-            sleep = min(sleep*2, args.max_backoff)
-            sleep_time = random.uniform(sleep/2, sleep)
-            state.inc_counter('R')
-            await asyncio.sleep(sleep_time)
-
-        except openai.BadRequestError as e:
-            # usually "maximum context length" XXX check for this?
-            log_write(seg, f"Error: {type(e)} {e}")
-            return None # gives up this segment
-
-        except openai.AuthenticationError as e:
-            log_write(seg, f"Failed: {type(e)} {e}")
-            raise
-
-        except openai.APIConnectionError as e:
-            log_write(seg, f"Error: {type(e)} {e}")
-            # usually a server-side error... just retry right away
-            state.inc_counter('R')
-
-        except openai.APIError as e:
-            # APIError is the base class for all API errors;
-            # we may be missing a more specific handler.
-            print(f"Error: {type(e)} {e}; missing handler?")
-            log_write(seg, f"Error: {type(e)} {e}")
-            return None # gives up this segment
-
-
 def extract_python(response: str) -> str:
     # This regex accepts a truncated code block... this seems fine since we'll try it anyway
     m = re.search(r'```python\n(.*?)(?:```|\Z)', response, re.DOTALL)
@@ -530,9 +459,11 @@ def extract_python(response: str) -> str:
     return m.group(1)
 
 
-async def improve_coverage(seg: CodeSegment) -> bool:
+state = None
+
+async def improve_coverage(chatter: llm.Chatter, seg: CodeSegment) -> bool:
     """Works to improve coverage for a code segment."""
-    global args, progress
+    global args
 
     def log_prompts(prompts: T.List[dict]):
         for p in prompts:
@@ -553,14 +484,7 @@ def log_prompts(prompts: T.List[dict]):
             log_write(seg, "Too many attempts, giving up")
             break
 
-        completion = {
-            'model': args.model,
-            'temperature': args.model_temperature,
-            'messages': messages,
-            **({'api_base': "http://localhost:11434"} if "ollama" in args.model else {})
-        }
-
-        if not (response := await do_chat(seg, completion)):
+        if not (response := await chatter.chat(seg, messages)):
             log_write(seg, "giving up")
             break
 
@@ -658,7 +582,7 @@ def main():
     from collections import defaultdict
     import os
 
-    global args, token_rate_limit, state
+    global args, state
     args = parse_args()
 
     if not args.tests_dir.exists():
@@ -670,12 +594,12 @@ def main():
     add_to_pythonpath(args.module_dir)
 
     if args.prompt_for_tests:
-        if args.rate_limit or token_rate_limit_for_model(args.model):
-            limit = (args.rate_limit, 60) if args.rate_limit else token_rate_limit_for_model(args.model)
-            from aiolimiter import AsyncLimiter
-            token_rate_limit = AsyncLimiter(*limit)
-            # TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both
+        chatter = llm.Chatter(model=args.model, model_temperature=args.model_temperature,
+                              log_write=log_write, signal_retry=lambda: state.inc_counter('R'))
+        chatter.set_max_backoff(args.max_backoff)
 
+        if args.rate_limit:
+            chatter.set_token_rate_limit((args.rate_limit, 60))
 
         # Check for an API key for OpenAI or Amazon Bedrock.
         if 'OPENAI_API_KEY' not in os.environ and 'ANTHROPIC_API_KEY' not in os.environ:
@@ -706,10 +630,6 @@ def main():
                 print()
                 return 1
 
-        if not args.model:
-            print("Please specify model to use with --model")
-            return 1
-
     log_write('startup', f"Command: {' '.join(sys.argv)}")
 
     # --- (1) load or measure initial coverage, figure out segmentation ---
@@ -751,7 +671,7 @@ def main():
         print("(in the following, G=good, F=failed, U=useless and R=retry)")
 
         async def work_segment(seg: CodeSegment) -> None:
-            if await improve_coverage(seg):
+            if await improve_coverage(chatter, seg):
                 # Only mark done if was able to complete (True return),
                 # so that it can be retried after installing any missing modules
                 state.mark_done(seg)
diff --git a/src/coverup/llm.py b/src/coverup/llm.py
index 749316f..827d78f 100644
--- a/src/coverup/llm.py
+++ b/src/coverup/llm.py
@@ -1,5 +1,22 @@
 import typing as T
-import litellm
+import openai
+import logging
+import asyncio
+import warnings
+from .segment import CodeSegment
+
+with warnings.catch_warnings():
+    # ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
+    warnings.simplefilter('ignore')
+    import litellm # type: ignore
+
+
+# Turn off most logging
+litellm.set_verbose = False
+logging.getLogger().setLevel(logging.ERROR)
+
+# Ignore unavailable parameters
+litellm.drop_params=True
 
 
 # Tier 5 rate limits for models; tuples indicate limit and interval in seconds
@@ -85,3 +102,92 @@ def count_tokens(model_name: str, completion: dict):
         count += len(encoding.encode(m['content']))
 
     return count
+
+
+class Chatter:
+    def __init__(self, model, model_temperature, log_write, signal_retry):
+        self.model = model
+        self.model_temperature = model_temperature
+        self.max_backoff = 64 # seconds
+        self.set_token_rate_limit(token_rate_limit_for_model(model))
+
+        self.log_write = log_write
+        self.signal_retry = signal_retry
+
+
+    def set_token_rate_limit(self, limit):
+        if limit:
+            from aiolimiter import AsyncLimiter
+            self.token_rate_limit = AsyncLimiter(*limit)
+        else:
+            self.token_rate_limit = None
+
+
+    def set_max_backoff(self, max_backoff):
+        self.max_backoff = max_backoff
+
+
+    def _completion(self, messages: list) -> dict:
+        return {
+            'model': self.model,
+            'temperature': self.model_temperature,
+            'messages': messages,
+            **({'api_base': "http://localhost:11434"} if "ollama" in self.model else {})
+        }
+
+
+    async def chat(self, seg: CodeSegment, messages: list) -> dict:
+        """Sends a GPT chat request, handling common failures and returning the response."""
+
+        sleep = 1
+        while True:
+            try:
+                completion = self._completion(messages)
+                # TODO also add request limit; could use 'await asyncio.gather(t.acquire(tokens), r.acquire())'
+                # to acquire both
+                if self.token_rate_limit:
+                    try:
+                        await self.token_rate_limit.acquire(count_tokens(self.model, completion))
+                    except ValueError as e:
+                        self.log_write(seg, f"Error: too many tokens for rate limit ({e})")
+                        return None # gives up this segment
+
+                return await litellm.acreate(**completion)
+
+            except (litellm.exceptions.ServiceUnavailableError,
+                    openai.RateLimitError,
+                    openai.APITimeoutError) as e:
+
+                # This message usually indicates out of money in account
+                if 'You exceeded your current quota' in str(e):
+                    self.log_write(seg, f"Failed: {type(e)} {e}")
+                    raise
+
+                self.log_write(seg, f"Error: {type(e)} {e}")
+
+                import random
+                sleep = min(sleep*2, self.max_backoff)
+                sleep_time = random.uniform(sleep/2, sleep)
+                self.signal_retry()
+                await asyncio.sleep(sleep_time)
+
+            except openai.BadRequestError as e:
+                # usually "maximum context length" XXX check for this?
+                self.log_write(seg, f"Error: {type(e)} {e}")
+                return None # gives up this segment
+
+            except openai.AuthenticationError as e:
+                self.log_write(seg, f"Failed: {type(e)} {e}")
+                raise
+
+            except openai.APIConnectionError as e:
+                self.log_write(seg, f"Error: {type(e)} {e}")
+                # usually a server-side error... just retry right away
+                self.signal_retry()
+
+            except openai.APIError as e:
+                # APIError is the base class for all API errors;
+                # we may be missing a more specific handler.
+                print(f"Error: {type(e)} {e}; missing handler?")
+                self.log_write(seg, f"Error: {type(e)} {e}")
+                return None # gives up this segment