From defa6f44ab18e430e659481e7bc39cc2fc8e6da8 Mon Sep 17 00:00:00 2001 From: Philip May Date: Sun, 7 Jan 2024 14:26:04 +0100 Subject: [PATCH] Add support for new OpenAI API. (#144) * Update openai version in pyproject.toml * add OpenAiChatCompletion * Update OpenAI client and add AzureOpenAI support * Add OpenAiChatResult and OpenAiChat classes * Refactor OpenAiChat.create_model_response method signature and add docstring * Add docstring to OpenAiChat class * improve doc * Add OpenAiChat tests for error handling * Update OpenAiAzureChat class constructor * improve typing * Fix method signature in OpenAiChat class --- mltb2/openai.py | 275 +++++++++++++++++++++---------------------- pyproject.toml | 2 +- tests/test_openai.py | 35 +++++- 3 files changed, 166 insertions(+), 146 deletions(-) diff --git a/mltb2/openai.py b/mltb2/openai.py index 7ff38f2..264bc79 100644 --- a/mltb2/openai.py +++ b/mltb2/openai.py @@ -10,14 +10,13 @@ """ -from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Mapping, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import tiktoken import yaml -from openai import ChatCompletion, Completion -from openai.openai_object import OpenAIObject +from openai import AzureOpenAI, OpenAI +from openai.types.chat import ChatCompletion from tiktoken.core import Encoding from tqdm import tqdm @@ -67,206 +66,194 @@ def __call__(self, text: Union[str, Iterable]) -> Union[int, List[int]]: @dataclass -class OpenAiCompletionAnswer: - """Answer of an OpenAI completion. +class OpenAiChatResult: + """Result of an OpenAI chat completion. + + If you want to convert this to a ``dict`` use ``asdict(open_ai_chat_result)`` + from the ``dataclasses`` module. 
+
+    See Also:
+        OpenAI API reference: `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_
 
     Args:
-        text: the result of the OpenAI completion
+        content: the result of the OpenAI completion
         model: model name which has been used
         prompt_tokens: number of tokens of the prompt
-        completion_tokens: number of tokens of the completion (``text``)
-        total_tokens: number of total tokens (``prompt_tokens + completion_tokens``)
+        completion_tokens: number of tokens of the completion (``content``)
+        total_tokens: number of total tokens (``prompt_tokens + completion_tokens``)
         finish_reason: The reason why the completion stopped.
 
             * ``stop``: Means the API returned the full completion without running into any token limit.
             * ``length``: Means the API stopped the completion because of running into a token limit.
-            * ``function_call``: When the model called a function.
+            * ``content_filter``: When content was omitted due to a flag from the OpenAI content filters.
+            * ``tool_calls``: When the model called a tool.
+            * ``function_call`` (deprecated): When the model called a function.
 
-    See Also:
-        * `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_
-        * `The completion object <https://platform.openai.com/docs/api-reference/completions/object>`_
+        completion_args: The arguments which have been used for the completion.
+            Examples:
+
+                * ``model``: always set
+                * ``max_tokens``: only set if ``completion_kwargs`` contained ``max_tokens``
+                * ``temperature``: only set if ``completion_kwargs`` contained ``temperature``
+                * ``top_p``: only set if ``completion_kwargs`` contained ``top_p``
     """
 
-    text: Optional[str] = None
+    content: Optional[str] = None
     model: Optional[str] = None
     prompt_tokens: Optional[int] = None
     completion_tokens: Optional[int] = None
     total_tokens: Optional[int] = None
     finish_reason: Optional[str] = None
+    completion_args: Optional[Dict[str, Any]] = None
 
     @classmethod
-    def from_open_ai_object(cls, open_ai_object: OpenAIObject):
-        """Construct this class from ``OpenAIObject``."""
+    def from_chat_completion(
+        cls,
+        chat_completion: ChatCompletion,
+        completion_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """Construct this class from an OpenAI ``ChatCompletion`` object.
+
+        Args:
+            chat_completion: The OpenAI ``ChatCompletion`` object.
+            completion_kwargs: The arguments which have been used for the completion.
+
+        Returns:
+            The constructed class.
+ """ result = {} - result["model"] = open_ai_object.get("model") - usage = open_ai_object.get("usage") + result["completion_args"] = completion_kwargs + chat_completion_dict = chat_completion.model_dump() + result["model"] = chat_completion_dict.get("model") + usage = chat_completion_dict.get("usage") if usage is not None: result["prompt_tokens"] = usage.get("prompt_tokens") result["completion_tokens"] = usage.get("completion_tokens") result["total_tokens"] = usage.get("total_tokens") - choices = open_ai_object.get("choices") + choices = chat_completion_dict.get("choices") if choices is not None and len(choices) > 0: choice = choices[0] result["finish_reason"] = choice.get("finish_reason") - if "text" in choice: # non chat models - result["text"] = choice.get("text") - elif "message" in choice: # chat models - message = choice.get("message") - if message is not None: - result["text"] = message.get("content") - return cls(**result) + message = choice.get("message") + if message is not None: + result["content"] = message.get("content") + return cls(**result) # type: ignore[arg-type] @dataclass -class OpenAiBaseCompletion(ABC): - """Abstract base class for OpenAI completion. +class OpenAiChat: + """Tool to interact with OpenAI chat models. - Args: - completion_kwargs: kwargs for the OpenAI completion function + This also be constructed with :meth:`~OpenAiChat.from_yaml`. See Also: - * `Create chat completion `_ - * `Create completion `_ + OpenAI API reference: `Create chat completion `_ + + Args: + api_key: The OpenAI API key. + model: The OpenAI model name. """ - completion_kwargs: Dict[str, Any] + api_key: str + model: str + client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False) + + def __post_init__(self) -> None: + """Do post init.""" + self.client = OpenAI(api_key=self.api_key) @classmethod def from_yaml(cls, yaml_file): - """Construct this class from a yaml file.""" + """Construct this class from a yaml file. + + Args: + yaml_file: The yaml file. 
+ Returns: + The constructed class. + """ with open(yaml_file, "r") as file: completion_kwargs = yaml.safe_load(file) - return cls(completion_kwargs) - - @abstractmethod - def _completion( - self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any] - ) -> OpenAIObject: - """Abstract method to call the OpenAI completion.""" + return cls(**completion_kwargs) def __call__( - self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs: Optional[Mapping[str, Any]] = None - ) -> OpenAiCompletionAnswer: - """Call the OpenAI prompt completion. + self, + prompt: Union[str, List[Dict[str, str]]], + completion_kwargs: Optional[Dict[str, Any]] = None, + ) -> OpenAiChatResult: + """Create a model response for the given prompt (chat conversation). Args: - prompt: The prompt to be completed by the LLM. - In case of chat models this can be a string or a list. - In case of "non chat" models only a string is allowed. - completion_kwargs: Overwrite the ``completion_kwargs`` for this call. - This allows you, for example, to change the temperature for this call only. - """ - completion_kwargs_for_this_call = self.completion_kwargs.copy() - if completion_kwargs is not None: - completion_kwargs_for_this_call.update(completion_kwargs) - open_ai_object: OpenAIObject = self._completion(prompt, completion_kwargs_for_this_call) - open_ai_completion_answer = OpenAiCompletionAnswer.from_open_ai_object(open_ai_object) - return open_ai_completion_answer + prompt: The prompt for the model. + completion_kwargs: Keyword arguments for the OpenAI completion. + - ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer. + - ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument. -@dataclass -class OpenAiChatCompletion(OpenAiBaseCompletion): - """OpenAI chat completion. + Also see: - This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`. 
+ - ``openai.resources.chat.completions.Completions.create()`` + - OpenAI API reference: `Create chat completion `_ - Args: - completion_kwargs: The kwargs for the OpenAI completion function. - - See Also: - `Create chat completion `_ - """ + Returns: + The result of the OpenAI completion. + """ + if isinstance(prompt, list): + for message in prompt: + if "role" not in message or "content" not in message: + raise ValueError( + "If prompt is a list of messages, each message must have a 'role' and 'content' key!" + ) + if message["role"] not in ["system", "user", "assistant", "tool"]: + raise ValueError( + "If prompt is a list of messages, each message must have a 'role' key with one of the values " + "'system', 'user', 'assistant' or 'tool'!" + ) - def _completion( - self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any] - ) -> OpenAIObject: - """Call to the OpenAI chat completion.""" + if completion_kwargs is not None: + # check keys of completion_kwargs + if "model" in completion_kwargs: + raise ValueError( + "'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer." + ) + if "messages" in completion_kwargs: + raise ValueError( + "'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument." 
+ ) + else: + completion_kwargs = {} # set default value + completion_kwargs["model"] = self.model messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt - open_ai_object: OpenAIObject = ChatCompletion.create( - messages=messages, - **completion_kwargs_for_this_call, + chat_completion = self.client.chat.completions.create( + messages=messages, # type: ignore[arg-type] + **completion_kwargs, ) - return open_ai_object - - -def _check_mandatory_azure_completion_kwargs(completion_kwargs: Mapping[str, Any]) -> None: - """Check mandatory Azure ``completion_kwargs``.""" - for mandatory_azure_completion_kwarg in ("api_base", "engine", "api_type", "api_version"): - if mandatory_azure_completion_kwarg not in completion_kwargs: - raise ValueError(f"You must set '{mandatory_azure_completion_kwarg}' for Azure completion!") - if completion_kwargs["api_type"] != "azure": - raise ValueError("You must set 'api_type' to 'azure' for Azure completion!") + result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs) + return result @dataclass -class OpenAiAzureChatCompletion(OpenAiChatCompletion): - """OpenAI Azure chat completion. - - This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`. +class OpenAiAzureChat(OpenAiChat): + """Tool to interact with Azure OpenAI chat models. - Args: - completion_kwargs: The kwargs for the OpenAI completion function. - The following Azure specific properties must be specified: - - * ``api_type`` - * ``api_version`` - * ``api_base`` - * ``engine`` + This can also be constructed with :meth:`~OpenAiChat.from_yaml`. See Also: - * `Create chat completion `_ - * `Quickstart: Get started using GPT-35-Turbo and GPT-4 with Azure OpenAI Service `_ - """ - - def __post_init__(self) -> None: - """Do post init.""" - _check_mandatory_azure_completion_kwargs(self.completion_kwargs) - - -@dataclass -class OpenAiCompletion(OpenAiBaseCompletion): - """OpenAI (non chat) completion. 
- - This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`. + * OpenAI API reference: `Create chat completion `_ + * `Quickstart: Get started generating text using Azure OpenAI Service `_ Args: - completion_kwargs: The kwargs for the OpenAI completion function. - - See Also: - `Create completion `_ + api_key: The OpenAI API key. + model: The OpenAI model name. + api_version: The OpenAI API version. + A common value for this is ``2023-05-15``. + azure_endpoint: The Azure endpoint. """ - def _completion( - self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any] - ) -> OpenAIObject: - """Call to the OpenAI (not chat) completion.""" - open_ai_object: OpenAIObject = Completion.create( - prompt=prompt, - **completion_kwargs_for_this_call, - ) - return open_ai_object - - -@dataclass -class OpenAiAzureCompletion(OpenAiCompletion): - """OpenAI Azure (non chat) completion. - - This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`. - - Args: - completion_kwargs: The kwargs for the OpenAI completion function. 
- The following Azure specific properties must be specified: - - * ``api_type`` - * ``api_version`` - * ``api_base`` - * ``engine`` - - See Also: - * `Create completion `_ - * `Quickstart: Get started generating text using Azure OpenAI Service `_ - """ + api_version: str + azure_endpoint: str def __post_init__(self) -> None: """Do post init.""" - _check_mandatory_azure_completion_kwargs(self.completion_kwargs) + self.client = AzureOpenAI( + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + ) diff --git a/pyproject.toml b/pyproject.toml index 08230b5..3df5f43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ torch = {version = "!=2.0.1,!=2.1.0", optional = true} # some versions have poe transformers = {version = "*", optional = true} tiktoken = {version = "*", optional = true} safetensors = {version = "!=0.3.2", optional = true} # version 0.3.2 has poetry issues -openai = {version = "^0", optional = true} +openai = {version = "^1", optional = true} pyyaml = {version = "*", optional = true} pandas = {version = "*", optional = true} beautifulsoup4 = {version = "*", optional = true} diff --git a/tests/test_openai.py b/tests/test_openai.py index 66a3a38..2e59d8c 100644 --- a/tests/test_openai.py +++ b/tests/test_openai.py @@ -8,7 +8,7 @@ from hypothesis import given, settings from hypothesis.strategies import lists, text -from mltb2.openai import OpenAiTokenCounter +from mltb2.openai import OpenAiChat, OpenAiTokenCounter @pytest.fixture(scope="module") @@ -48,3 +48,36 @@ def test_OpenAiTokenCounter_call_list(): # noqa: N802 assert len(token_count) == 2 assert token_count[0] == 5 assert token_count[1] == 7 + + +def test_OpenAiChat__missing_role_message_key(): # noqa: N802 + open_ai_chat = OpenAiChat(api_key="secret", model="apt-4") + invalid_prompt_as_list = [{"x": "user", "content": "prompt"}] + with pytest.raises(ValueError): + open_ai_chat(invalid_prompt_as_list) + + +def 
test_OpenAiChat__missing_content_message_key(): # noqa: N802 + open_ai_chat = OpenAiChat(api_key="secret", model="apt-4") + invalid_prompt_as_list = [{"role": "user", "x": "prompt"}] + with pytest.raises(ValueError): + open_ai_chat(invalid_prompt_as_list) + + +def test_OpenAiChat__invalid_role_in_message_key(): # noqa: N802 + open_ai_chat = OpenAiChat(api_key="secret", model="apt-4") + invalid_prompt_as_list = [{"role": "x", "content": "prompt"}] + with pytest.raises(ValueError): + open_ai_chat(invalid_prompt_as_list) + + +def test_OpenAiChat__model_in_completion_kwargs(): # noqa: N802 + open_ai_chat = OpenAiChat(api_key="secret", model="apt-4") + with pytest.raises(ValueError): + open_ai_chat("Hello!", completion_kwargs={"model": "gpt-4"}) + + +def test_OpenAiChat__messages_in_completion_kwargs(): # noqa: N802 + open_ai_chat = OpenAiChat(api_key="secret", model="apt-4") + with pytest.raises(ValueError): + open_ai_chat("Hello!", completion_kwargs={"messages": "World!"})