diff --git a/easycompletion/__init__.py b/easycompletion/__init__.py
index c7e01f8..e496aca 100644
--- a/easycompletion/__init__.py
+++ b/easycompletion/__init__.py
@@ -4,7 +4,8 @@
     text_completion,
     text_completion_async,
     chat_completion,
-    chat_completion_async
+    chat_completion_async,
+    build_model_info,
 )
 
 openai_function_call = function_completion
@@ -21,6 +22,7 @@
 
 from .constants import (
     TEXT_MODEL,
+    LONG_TEXT_MODEL,
     DEFAULT_CHUNK_LENGTH,
 )
@@ -28,6 +30,9 @@
     "function_completion",
     "text_completion",
     "chat_completion",
+    "function_completion_async",
+    "text_completion_async",
+    "chat_completion_async",
     "openai_function_call",
     "openai_text_call",
     "compose_prompt",
@@ -36,6 +41,8 @@
     "chunk_prompt",
     "count_tokens",
     "get_tokens",
+    "build_model_info",
     "TEXT_MODEL",
+    "LONG_TEXT_MODEL",
     "DEFAULT_CHUNK_LENGTH",
 ]
diff --git a/easycompletion/constants.py b/easycompletion/constants.py
index c797af3..8ac9f18 100644
--- a/easycompletion/constants.py
+++ b/easycompletion/constants.py
@@ -3,19 +3,22 @@
 
 load_dotenv()  # take environment variables from .env.
 
-TEXT_MODEL = os.getenv("EASYCOMPLETION_TEXT_MODEL")
-if TEXT_MODEL == None or TEXT_MODEL == "":
-    TEXT_MODEL = "gpt-3.5-turbo-0613"
-LONG_TEXT_MODEL = os.getenv("EASYCOMPLETION_LONG_TEXT_MODEL")
-if LONG_TEXT_MODEL == None or LONG_TEXT_MODEL == "":
-    LONG_TEXT_MODEL = "gpt-3.5-turbo-16k"
+TEXT_MODEL = os.getenv("EASYCOMPLETION_TEXT_MODEL") or "gpt-3.5-turbo-0613"
+TEXT_MODEL_WINDOW = int(os.getenv("EASYCOMPLETION_TEXT_MODEL_WINDOW") or 4096)
+LONG_TEXT_MODEL = os.getenv("EASYCOMPLETION_LONG_TEXT_MODEL") or "gpt-3.5-turbo-16k"
+LONG_TEXT_MODEL_WINDOW = int(os.getenv("EASYCOMPLETION_LONG_TEXT_MODEL_WINDOW") or 16 * 1024)
 
-EASYCOMPLETION_API_KEY = os.getenv("OPENAI_API_KEY")
-if EASYCOMPLETION_API_KEY is None:
-    EASYCOMPLETION_API_KEY = os.getenv("EASYCOMPLETION_API_KEY")
+EASYCOMPLETION_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("EASYCOMPLETION_API_KEY")
 
 EASYCOMPLETION_API_ENDPOINT = os.getenv("EASYCOMPLETION_API_ENDPOINT") or "https://api.openai.com/v1"
 
-DEBUG = os.environ.get("EASYCOMPLETION_DEBUG") == "true" or os.environ.get("EASYCOMPLETION_DEBUG") == "True"
+DEBUG = (os.environ.get("EASYCOMPLETION_DEBUG") or '').lower() == "true"
+SUPPRESS_WARNINGS = (os.environ.get("SUPPRESS_WARNINGS") or '').lower() == "true"
 
-DEFAULT_CHUNK_LENGTH = 4096 * 3 / 4  # 3/4ths of the context window size
+DEFAULT_CHUNK_LENGTH = int(os.getenv("DEFAULT_CHUNK_LENGTH") or TEXT_MODEL_WINDOW * 3 // 4)  # 3/4ths of the context window size
+
+DEFAULT_MODEL_INFO = [
+    (TEXT_MODEL, DEFAULT_CHUNK_LENGTH),
+    (LONG_TEXT_MODEL, LONG_TEXT_MODEL_WINDOW - DEFAULT_CHUNK_LENGTH),
+    # In the second case, DEFAULT_CHUNK_LENGTH is used as a buffer
+]
diff --git a/easycompletion/model.py b/easycompletion/model.py
index 41be67d..f143b5e 100644
--- a/easycompletion/model.py
+++ b/easycompletion/model.py
@@ -7,6 +7,7 @@
 import asyncio
 
 from dotenv import load_dotenv
+import tiktoken
 
 # Load environment variables from .env file
 load_dotenv()
@@ -14,10 +15,14 @@
 from .constants import (
     EASYCOMPLETION_API_ENDPOINT,
     TEXT_MODEL,
+    TEXT_MODEL_WINDOW,
     LONG_TEXT_MODEL,
+    LONG_TEXT_MODEL_WINDOW,
     EASYCOMPLETION_API_KEY,
     DEFAULT_CHUNK_LENGTH,
+    DEFAULT_MODEL_INFO,
     DEBUG,
+    SUPPRESS_WARNINGS,
 )
 
 from .logger import log
@@ -154,47 +159,87 @@ def validate_functions(response, functions, function_call, debug=DEBUG):
     log("Function call is valid", type="success", log=debug)
     return True
 
-def sanity_check(prompt, model=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=EASYCOMPLETION_API_KEY, debug=DEBUG):
+
+def is_long_model(model_name):
+    return "16k" in model_name
+
+def build_model_info(model_names, factor=0.75):
+    return [
+        (model_name,
+         int(factor * (LONG_TEXT_MODEL_WINDOW if is_long_model(model_name) else TEXT_MODEL_WINDOW)))
+        for model_name in model_names
+    ]
+
+
+def sanity_check(prompt, model=None, model_info=None, chunk_length=DEFAULT_CHUNK_LENGTH, api_key=EASYCOMPLETION_API_KEY, debug=DEBUG):
     # Validate the API key
     if not api_key.strip():
-        return model, {"error": "Invalid OpenAI API key"}
+        return [], {"error": "Invalid OpenAI API key"}
 
     openai.api_key = api_key
-
-    # Count tokens in the input text
-    total_tokens = count_tokens(prompt, model=model)
-
-    # If text is longer than chunk_length and model is not for long texts, switch to the long text model
-    if total_tokens > chunk_length and "16k" not in model:
-        model = LONG_TEXT_MODEL
-        if not os.environ.get("SUPPRESS_WARNINGS"):
-            print(
-                "Warning: Message is long. Using 16k model (to hide this message, set SUPPRESS_WARNINGS=1)"
-            )
+    # Construct a model_info from the legacy parameters
+    chunk_length = chunk_length or DEFAULT_CHUNK_LENGTH
+    if model is not None:
+        if model == TEXT_MODEL and chunk_length == DEFAULT_CHUNK_LENGTH:
+            log("Warning: deprecated use of model. Please use model_info.",
+                type="warning", log=not SUPPRESS_WARNINGS)
+            model_info = DEFAULT_MODEL_INFO
+        else:
+            log("Warning: deprecated use of model. Assuming the long model is allowed; pass model_info otherwise.",
+                type="warning", log=not SUPPRESS_WARNINGS)
+            model_info = ((model, chunk_length), (LONG_TEXT_MODEL, LONG_TEXT_MODEL_WINDOW - DEFAULT_CHUNK_LENGTH))
+    elif chunk_length != DEFAULT_CHUNK_LENGTH:
+        log("Warning: deprecated use of chunk_length. Please use model_info.",
+            type="warning", log=not SUPPRESS_WARNINGS)
+        model_info = ((TEXT_MODEL, chunk_length), (LONG_TEXT_MODEL, LONG_TEXT_MODEL_WINDOW - chunk_length))
+    else:
+        model_info = model_info or DEFAULT_MODEL_INFO
+
+    model_info = sorted(model_info, key=lambda i: i[1])
+
+    # Names of the models whose chunk length can accommodate the prompt
+    models = []
+    len_by_encoding = {}
+    len_by_model = {}
+    for model, chunk_length in model_info:
+        encoding = tiktoken.encoding_for_model(model)
+        if encoding not in len_by_encoding:
+            # Count of tokens in the input text
+            len_by_encoding[encoding] = count_tokens(prompt, model=model)
+        len_by_model[model] = (token_len := len_by_encoding[encoding])
+        if token_len <= chunk_length:
+            models.append(model)
 
     # If text is too long even for long text model, return None
-    if total_tokens > (16384 - chunk_length):
+    if not models:
         print("Error: Message too long")
-        return model, {
+        return models, {
             "text": None,
             "usage": None,
             "finish_reason": None,
             "error": "Message too long",
         }
 
+    if models[0] != model_info[0][0]:
+        log("Warning: Message is long. Using a larger model (to hide this message, set SUPPRESS_WARNINGS=true)",
+            type="warning", log=not SUPPRESS_WARNINGS)
+
+    total_tokens = len_by_model[models[0]]  # First appropriate model
+
     if isinstance(prompt, dict):
         for key, value in prompt.items():
             if value:
-                log(f"Prompt {key} ({count_tokens(value)} tokens):\n{str(value)}", type="prompt", log=debug)
+                log(f"Prompt {key} ({count_tokens(value, model=models[0])} tokens):\n{str(value)}", type="prompt", log=debug)
     else:
         log(f"Prompt ({total_tokens} tokens):\n{str(prompt)}", type="prompt", log=debug)
 
-    return model, None
+    return models, None
 
 def do_chat_completion(
-        messages, model=TEXT_MODEL, temperature=0.8, functions=None, function_call=None, model_failure_retries=5, debug=DEBUG):
+        messages, models, temperature=0.8, functions=None, function_call=None, model_failure_retries=5, debug=DEBUG):
     # Try to make a request for a specified number of times
     response = None
+    model = models[0]
     for i in range(model_failure_retries):
         try:
             if functions is not None:
@@ -206,17 +255,19 @@
             response = openai.ChatCompletion.create(
                 model=model, messages=messages, temperature=temperature
             )
-            print('response')
-            print(response)
+            log('response', log=debug)
+            log(response, log=debug)
             break
         except Exception as e:
             log(f"OpenAI Error: {e}", type="error", log=debug)
 
+    # TODO: Are there other reasons to try fallback models?
+
     # If response is not valid, print an error message and return None
     if (
-        response is None
-        or response["choices"] is None
-        or response["choices"][0] is None
+        not response
+        or not response.get("choices")
+        or not response["choices"][0]
     ):
         return None, {
             "text": None,
@@ -224,12 +275,28 @@
             "finish_reason": None,
             "error": "Error: Could not get a successful response from OpenAI API",
         }
+
+    # Check whether the completion failed for length reasons.
+    choices = response.get("choices", [])
+    if choices and all(choice.get("finish_reason", None) == 'length' for choice in choices):
+        models.pop(0)  # Side effect: this model is never retried for this prompt
+        if models:
+            log("Failed because of length, trying the next model", log=debug)
+            return do_chat_completion(messages, models, temperature, functions, function_call, model_failure_retries, debug)
+        return None, {
+            "text": None,
+            "usage": None,
+            "finish_reason": 'length',
+            "error": "Error: The prompt elicits too-long responses",
+        }
+
     return response, None
 
 def chat_completion(
     messages,
     model_failure_retries=5,
     model=None,
+    model_info=None,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
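To make the fallback behavior concrete, here is a minimal usage sketch of the new API (not part of the patch; the model names and window sizes are the library defaults, and the prompt is illustrative):

```python
from easycompletion import build_model_info, chat_completion

# With the default windows (4096 and 16384 tokens) and factor=0.75 this yields
# [("gpt-3.5-turbo-0613", 3072), ("gpt-3.5-turbo-16k", 12288)]: sanity_check
# keeps every model whose chunk length fits the prompt, smallest first, and
# do_chat_completion falls through to the next one on 'length' failures.
model_info = build_model_info(["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k"])

response = chat_completion(
    messages=[{"role": "user", "content": "Summarize the plot of Hamlet."}],
    model_info=model_info,
)
print(response["text"])
```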
@@ -241,8 +308,9 @@
     Parameters:
         messages (str): Messages to send to the model. In the form {<role>: string, <content>: string} - roles are "user" and "assistant"
         model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
-        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
-        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
 
     Returns:
@@ -254,14 +322,12 @@
     openai.api_key = api_key
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-    model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    models, error = sanity_check(messages, model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
     if error:
         return error
 
     # Try to make a request for a specified number of times
     response, error = do_chat_completion(
-        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
+        messages, models, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
     if error:
         return error
 
@@ -283,6 +351,7 @@
     messages,
     model_failure_retries=5,
     model=None,
+    model_info=None,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
@@ -294,8 +363,9 @@
     Parameters:
         messages (str): Messages to send to the model. In the form {<role>: string, <content>: string} - roles are "user" and "assistant"
         model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
-        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
-        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
 
     Returns:
@@ -307,13 +377,11 @@
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-    model, error = sanity_check(messages, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    models, error = sanity_check(messages, model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
     if error:
         return error
 
     # Try to make a request for a specified number of times
     response, error = await asyncio.to_thread(lambda: do_chat_completion(
-        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
+        messages, models, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
     if error:
         return error
 
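The async variant takes the same `model_info`; a minimal sketch, assuming it is driven by `asyncio.run` (again not part of the patch):

```python
import asyncio

from easycompletion import build_model_info, chat_completion_async

async def main():
    model_info = build_model_info(["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k"])
    response = await chat_completion_async(
        messages=[{"role": "user", "content": "Name three prime numbers."}],
        model_info=model_info,
    )
    print(response["text"])

asyncio.run(main())
```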
@@ -335,6 +405,7 @@
     text,
     model_failure_retries=5,
     model=None,
+    model_info=None,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
@@ -346,8 +417,9 @@
     Parameters:
         text (str): Text to send to the model.
         model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
-        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
-        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
 
     Returns:
@@ -358,8 +430,6 @@
     """
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-    model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    models, error = sanity_check(text, model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
     if error:
         return error
@@ -368,7 +439,7 @@
 
     # Try to make a request for a specified number of times
     response, error = do_chat_completion(
-        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
+        messages, models, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug)
     if error:
         return error
 
@@ -388,6 +459,7 @@
     text,
     model_failure_retries=5,
     model=None,
+    model_info=None,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
@@ -399,8 +471,9 @@
     Parameters:
         text (str): Text to send to the model.
         model_failure_retries (int, optional): Number of retries if the request fails. Default is 5.
-        model (str, optional): The model to use. Default is the TEXT_MODEL defined in constants.py.
-        chunk_length (int, optional): Maximum length of text chunk to process. Default is defined in constants.py.
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str, optional): OpenAI API key. If not provided, it uses the one defined in constants.py.
 
     Returns:
@@ -411,8 +484,6 @@
     """
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-    model, error = sanity_check(text, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
+    models, error = sanity_check(text, model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key, debug=debug)
     if error:
         return error
@@ -421,7 +493,7 @@
 
     # Try to make a request for a specified number of times
     response, error = await asyncio.to_thread(lambda: do_chat_completion(
-        model=model, messages=messages, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
+        messages, models, temperature=temperature, model_failure_retries=model_failure_retries, debug=debug))
     if error:
         return error
 
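`model_info` can also be written out by hand instead of using `build_model_info`, which gives direct control over each model's chunk length; a sketch with hand-picked budgets (the 3/4 split mirrors DEFAULT_MODEL_INFO, and the values are illustrative):

```python
from easycompletion import text_completion

# Try the small model first; fall back to the 16k model with the rest of
# its window left as a response buffer, as DEFAULT_MODEL_INFO does.
model_info = [
    ("gpt-3.5-turbo-0613", 3072),
    ("gpt-3.5-turbo-16k", 16384 - 3072),
]

response = text_completion("Write a haiku about token limits.", model_info=model_info)
if not response.get("error"):
    print(response["text"])
```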
@@ -449,6 +521,7 @@
     function_failure_retries=10,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     model=None,
+    model_info=None,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
     temperature=0.0,
@@ -464,8 +537,9 @@
         model_failure_retries (int): Number of times to retry the request if it fails (default is 5).
         function_call (str | dict | None): 'auto' to let the model decide, or a function name or a dictionary containing the function name (default is "auto").
         function_failure_retries (int): Number of times to retry the request if the function call is invalid (default is 10).
-        chunk_length (int): The length of each chunk to be processed.
-        model (str | None): The model to use (default is the TEXT_MODEL, i.e. gpt-3.5-turbo).
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str | None): If you'd like to pass in a key to override the environment variable EASYCOMPLETION_API_KEY.
 
     Returns:
@@ -477,9 +551,6 @@
         >>> function_completion("Call the function.", function)
     """
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-
     # Ensure that functions are provided
     if functions is None:
         return {"error": "functions is required"}
@@ -534,17 +605,17 @@
             "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
         }
 
-    model, error = sanity_check(dict(
+    models, error = sanity_check(dict(
         text=text, functions=functions, messages=messages, system_message=system_message
-    ), model=model, chunk_length=chunk_length, api_key=api_key)
+    ), model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key)
     if error:
         return error
 
     # Count the number of tokens in the message
-    message_tokens = count_tokens(text, model=model)
+    message_tokens = count_tokens(text, model=models[0])
     total_tokens = message_tokens
-    function_call_tokens = count_tokens(functions, model=model)
+    function_call_tokens = count_tokens(functions, model=models[0])
     total_tokens += function_call_tokens + 3  # Additional tokens for the user
 
     all_messages = []
@@ -564,7 +635,7 @@
     for _ in range(function_failure_retries):
         # Try to make a request for a specified number of times
         response, error = do_chat_completion(
-            model=model, messages=all_messages, temperature=temperature, function_call=function_call,
+            all_messages, models, temperature=temperature, function_call=function_call,
             functions=functions, model_failure_retries=model_failure_retries, debug=debug)
         if error:
             time.sleep(1)
@@ -573,7 +644,14 @@
             continue
-        print(response)
+        log(response, log=debug)
         if validate_functions(response, functions, function_call):
             break
+        if response.get("choices", [{}])[0].get("finish_reason", None) == 'length':
+            return {
+                "text": None,
+                "usage": response.get("usage"),
+                "finish_reason": 'length',
+                "error": "Message too long",
+            }
         time.sleep(1)
 
     # Check if we have a valid response from the model
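For `function_completion`, the new handling surfaces a truncated function call as an explicit 'length' error; a sketch, using a function schema in the style of the project README's write_song example:

```python
from easycompletion import function_completion

function = {
    "name": "write_song",
    "description": "Write a song about AI",
    "parameters": {
        "type": "object",
        "properties": {
            "lyrics": {"type": "string", "description": "The lyrics of the song"},
        },
        "required": ["lyrics"],
    },
}

response = function_completion(
    text="Write a song about AI",
    functions=[function],
    function_call="write_song",
    model_info=[("gpt-3.5-turbo-0613", 3072), ("gpt-3.5-turbo-16k", 12288)],
)
if response.get("finish_reason") == "length":
    print("Truncated:", response["error"])  # new failure mode added by this patch
else:
    print(response["arguments"]["lyrics"])
```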
@@ -618,6 +696,7 @@
     function_failure_retries=10,
     chunk_length=DEFAULT_CHUNK_LENGTH,
     model=None,
+    model_info=None,
     api_key=EASYCOMPLETION_API_KEY,
     debug=DEBUG,
     temperature=0.0,
@@ -633,8 +712,9 @@
         model_failure_retries (int): Number of times to retry the request if it fails (default is 5).
         function_call (str | dict | None): 'auto' to let the model decide, or a function name or a dictionary containing the function name (default is "auto").
         function_failure_retries (int): Number of times to retry the request if the function call is invalid (default is 10).
-        chunk_length (int): The length of each chunk to be processed.
-        model (str | None): The model to use (default is the TEXT_MODEL, i.e. gpt-3.5-turbo).
+        model (str, optional): The model to use. Deprecated.
+        chunk_length (int, optional): Maximum length of text chunk to process. Deprecated.
+        model_info (List[Tuple[str, int]], optional): The list of models to use and their respective chunk lengths. Default is the DEFAULT_MODEL_INFO defined in constants.py.
         api_key (str | None): If you'd like to pass in a key to override the environment variable EASYCOMPLETION_API_KEY.
 
     Returns:
@@ -646,9 +726,6 @@
         >>> function_completion("Call the function.", function)
     """
 
-    # Use the default model if no model is specified
-    model = model or TEXT_MODEL
-
     # Ensure that functions are provided
     if functions is None:
         return {"error": "functions is required"}
@@ -703,17 +780,17 @@
             "error": "function_call had an invalid name. Should be a string of the function name or an object with a name property"
         }
 
-    model, error = sanity_check(dict(
+    models, error = sanity_check(dict(
         text=text, functions=functions, messages=messages, system_message=system_message
-    ), model=model, chunk_length=chunk_length, api_key=api_key)
+    ), model_info=model_info, model=model, chunk_length=chunk_length, api_key=api_key)
     if error:
         return error
 
     # Count the number of tokens in the message
-    message_tokens = count_tokens(text, model=model)
+    message_tokens = count_tokens(text, model=models[0])
     total_tokens = message_tokens
-    function_call_tokens = count_tokens(functions, model=model)
+    function_call_tokens = count_tokens(functions, model=models[0])
     total_tokens += function_call_tokens + 3  # Additional tokens for the user
 
     all_messages = []
@@ -733,7 +810,7 @@
     for _ in range(function_failure_retries):
         # Try to make a request for a specified number of times
         response, error = await asyncio.to_thread(lambda: do_chat_completion(
-            model=model, messages=all_messages, temperature=temperature, function_call=function_call,
+            all_messages, models, temperature=temperature, function_call=function_call,
            functions=functions, model_failure_retries=model_failure_retries, debug=debug))
         if error:
             time.sleep(1)
@@ -742,7 +819,14 @@
             continue
-        print(response)
+        log(response, log=debug)
         if validate_functions(response, functions, function_call):
             break
+        if response.get("choices", [{}])[0].get("finish_reason", None) == 'length':
+            return {
+                "text": None,
+                "usage": response.get("usage"),
+                "finish_reason": 'length',
+                "error": "Message too long",
+            }
         time.sleep(1)
 
     # Check if we have a valid response from the model
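Finally, the deprecated keyword arguments still work but now route through the warning path in sanity_check; a sketch contrasting the two call styles (values are illustrative):

```python
from easycompletion import chat_completion

messages = [{"role": "user", "content": "Hello!"}]

# Legacy style: accepted, but sanity_check logs a deprecation warning and
# converts the arguments into a model_info list internally.
legacy = chat_completion(messages, model="gpt-3.5-turbo-0613", chunk_length=3072)

# Preferred style after this patch:
preferred = chat_completion(
    messages,
    model_info=[("gpt-3.5-turbo-0613", 3072), ("gpt-3.5-turbo-16k", 12288)],
)
```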