Commit: Add support for new OpenAI API. (#144)
* Update openai version in pyproject.toml

* add OpenAiChatCompletion

* Update OpenAI client and add AzureOpenAI support

* Add OpenAiChatResult and OpenAiChat classes

* Refactor OpenAiChat.create_model_response method signature and add docstring

* Add docstring to OpenAiChat class

* improve doc

* Add OpenAiChat tests for error handling

* Update OpenAiAzureChat class constructor

* improve typing

* Fix method signature in OpenAiChat class
PhilipMay authored Jan 7, 2024
1 parent c562d9b commit defa6f4
Showing 3 changed files with 166 additions and 146 deletions.
mltb2/openai.py: 275 changes (131 additions, 144 deletions)
@@ -10,14 +10,13 @@
 """
 
 
-from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import tiktoken
 import yaml
-from openai import ChatCompletion, Completion
-from openai.openai_object import OpenAIObject
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
 from tiktoken.core import Encoding
 from tqdm import tqdm
 
@@ -67,206 +66,194 @@ def __call__(self, text: Union[str, Iterable]) -> Union[int, List[int]]:
 
 
 @dataclass
-class OpenAiCompletionAnswer:
-    """Answer of an OpenAI completion.
+class OpenAiChatResult:
+    """Result of an OpenAI chat completion.
+    If you want to convert this to a ``dict`` use ``asdict(open_ai_chat_result)``
+    from the ``dataclasses`` module.
+    See Also:
+        OpenAI API reference: `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_
     Args:
-        text: the result of the OpenAI completion
+        content: the result of the OpenAI completion
         model: model name which has been used
         prompt_tokens: number of tokens of the prompt
-        completion_tokens: number of tokens of the completion (``text``)
-        total_tokens: number of total tokens (``prompt_tokens + completion_tokens``)
+        completion_tokens: number of tokens of the completion (``content``)
+        total_tokens: number of total tokens (``prompt_tokens + content_tokens``)
         finish_reason: The reason why the completion stopped.
             * ``stop``: Means the API returned the full completion without running into any token limit.
            * ``length``: Means the API stopped the completion because of running into a token limit.
-            * ``function_call``: When the model called a function.
             * ``content_filter``: When content was omitted due to a flag from the OpenAI content filters.
+            * ``tool_calls``: When the model called a tool.
+            * ``function_call`` (deprecated): When the model called a function.
-    See Also:
-        * `The chat completion object <https://platform.openai.com/docs/api-reference/chat/object>`_
-        * `The completion object <https://platform.openai.com/docs/api-reference/completions/object>`_
+        completion_args: The arguments which have been used for the completion. Examples:
+            * ``model``: always set
+            * ``max_tokens``: only set if ``completion_kwargs`` contained ``max_tokens``
+            * ``temperature``: only set if ``completion_kwargs`` contained ``temperature``
+            * ``top_p``: only set if ``completion_kwargs`` contained ``top_p``
     """
 
-    text: Optional[str] = None
+    content: Optional[str] = None
     model: Optional[str] = None
     prompt_tokens: Optional[int] = None
     completion_tokens: Optional[int] = None
     total_tokens: Optional[int] = None
     finish_reason: Optional[str] = None
+    completion_args: Optional[Dict[str, Any]] = None
 
     @classmethod
-    def from_open_ai_object(cls, open_ai_object: OpenAIObject):
-        """Construct this class from ``OpenAIObject``."""
+    def from_chat_completion(
+        cls,
+        chat_completion: ChatCompletion,
+        completion_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """Construct this class from an OpenAI ``ChatCompletion`` object.
+        Args:
+            chat_completion: The OpenAI ``ChatCompletion`` object.
+            completion_kwargs: The arguments which have been used for the completion.
+        Returns:
+            The constructed class.
+        """
         result = {}
-        result["model"] = open_ai_object.get("model")
-        usage = open_ai_object.get("usage")
+        result["completion_args"] = completion_kwargs
+        chat_completion_dict = chat_completion.model_dump()
+        result["model"] = chat_completion_dict.get("model")
+        usage = chat_completion_dict.get("usage")
         if usage is not None:
             result["prompt_tokens"] = usage.get("prompt_tokens")
             result["completion_tokens"] = usage.get("completion_tokens")
             result["total_tokens"] = usage.get("total_tokens")
-        choices = open_ai_object.get("choices")
+        choices = chat_completion_dict.get("choices")
         if choices is not None and len(choices) > 0:
             choice = choices[0]
             result["finish_reason"] = choice.get("finish_reason")
-            if "text" in choice:  # non chat models
-                result["text"] = choice.get("text")
-            elif "message" in choice:  # chat models
-                message = choice.get("message")
-                if message is not None:
-                    result["text"] = message.get("content")
-        return cls(**result)
+            message = choice.get("message")
+            if message is not None:
+                result["content"] = message.get("content")
+        return cls(**result)  # type: ignore[arg-type]
 
 
 @dataclass
-class OpenAiBaseCompletion(ABC):
-    """Abstract base class for OpenAI completion.
+class OpenAiChat:
+    """Tool to interact with OpenAI chat models.
-    Args:
-        completion_kwargs: kwargs for the OpenAI completion function
+    This also be constructed with :meth:`~OpenAiChat.from_yaml`.
     See Also:
-        * `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
-        * `Create completion <https://platform.openai.com/docs/api-reference/completions/create>`_
+        OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
+    Args:
+        api_key: The OpenAI API key.
+        model: The OpenAI model name.
     """
 
-    completion_kwargs: Dict[str, Any]
+    api_key: str
+    model: str
+    client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
 
+    def __post_init__(self) -> None:
+        """Do post init."""
+        self.client = OpenAI(api_key=self.api_key)
+
     @classmethod
     def from_yaml(cls, yaml_file):
-        """Construct this class from a yaml file."""
+        """Construct this class from a yaml file.
+        Args:
+            yaml_file: The yaml file.
+        Returns:
+            The constructed class.
+        """
         with open(yaml_file, "r") as file:
             completion_kwargs = yaml.safe_load(file)
-        return cls(completion_kwargs)
-
-    @abstractmethod
-    def _completion(
-        self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any]
-    ) -> OpenAIObject:
-        """Abstract method to call the OpenAI completion."""
+        return cls(**completion_kwargs)
 
     def __call__(
-        self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs: Optional[Mapping[str, Any]] = None
-    ) -> OpenAiCompletionAnswer:
-        """Call the OpenAI prompt completion.
+        self,
+        prompt: Union[str, List[Dict[str, str]]],
+        completion_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> OpenAiChatResult:
+        """Create a model response for the given prompt (chat conversation).
         Args:
-            prompt: The prompt to be completed by the LLM.
-                In case of chat models this can be a string or a list.
-                In case of "non chat" models only a string is allowed.
-            completion_kwargs: Overwrite the ``completion_kwargs`` for this call.
-                This allows you, for example, to change the temperature for this call only.
-        """
-        completion_kwargs_for_this_call = self.completion_kwargs.copy()
-        if completion_kwargs is not None:
-            completion_kwargs_for_this_call.update(completion_kwargs)
-        open_ai_object: OpenAIObject = self._completion(prompt, completion_kwargs_for_this_call)
-        open_ai_completion_answer = OpenAiCompletionAnswer.from_open_ai_object(open_ai_object)
-        return open_ai_completion_answer
+            prompt: The prompt for the model.
+            completion_kwargs: Keyword arguments for the OpenAI completion.
+                - ``model`` can not be set via ``completion_kwargs``! Please set the ``model`` in the initializer.
+                - ``messages`` can not be set via ``completion_kwargs``! Please set the ``prompt`` argument.
-@dataclass
-class OpenAiChatCompletion(OpenAiBaseCompletion):
-    """OpenAI chat completion.
+        Also see:
-    This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`.
+            - ``openai.resources.chat.completions.Completions.create()``
+            - OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
-    Args:
-        completion_kwargs: The kwargs for the OpenAI completion function.
-    See Also:
-        `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
-    """
+        Returns:
+            The result of the OpenAI completion.
+        """
+        if isinstance(prompt, list):
+            for message in prompt:
+                if "role" not in message or "content" not in message:
+                    raise ValueError(
+                        "If prompt is a list of messages, each message must have a 'role' and 'content' key!"
+                    )
+                if message["role"] not in ["system", "user", "assistant", "tool"]:
+                    raise ValueError(
+                        "If prompt is a list of messages, each message must have a 'role' key with one of the values "
+                        "'system', 'user', 'assistant' or 'tool'!"
+                    )
 
-    def _completion(
-        self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any]
-    ) -> OpenAIObject:
-        """Call to the OpenAI chat completion."""
+        if completion_kwargs is not None:
+            # check keys of completion_kwargs
+            if "model" in completion_kwargs:
+                raise ValueError(
+                    "'model' can not be set via 'completion_kwargs'! Please set the 'model' in the initializer."
+                )
+            if "messages" in completion_kwargs:
+                raise ValueError(
+                    "'messages' can not be set via 'completion_kwargs'! Please set the 'prompt' argument."
+                )
+        else:
+            completion_kwargs = {}  # set default value
+        completion_kwargs["model"] = self.model
         messages = [{"role": "user", "content": prompt}] if isinstance(prompt, str) else prompt
-        open_ai_object: OpenAIObject = ChatCompletion.create(
-            messages=messages,
-            **completion_kwargs_for_this_call,
+        chat_completion = self.client.chat.completions.create(
+            messages=messages,  # type: ignore[arg-type]
+            **completion_kwargs,
         )
-        return open_ai_object
 
 
-def _check_mandatory_azure_completion_kwargs(completion_kwargs: Mapping[str, Any]) -> None:
-    """Check mandatory Azure ``completion_kwargs``."""
-    for mandatory_azure_completion_kwarg in ("api_base", "engine", "api_type", "api_version"):
-        if mandatory_azure_completion_kwarg not in completion_kwargs:
-            raise ValueError(f"You must set '{mandatory_azure_completion_kwarg}' for Azure completion!")
-    if completion_kwargs["api_type"] != "azure":
-        raise ValueError("You must set 'api_type' to 'azure' for Azure completion!")
+        result = OpenAiChatResult.from_chat_completion(chat_completion, completion_kwargs=completion_kwargs)
+        return result
 
 
 @dataclass
-class OpenAiAzureChatCompletion(OpenAiChatCompletion):
-    """OpenAI Azure chat completion.
-    This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`.
+class OpenAiAzureChat(OpenAiChat):
+    """Tool to interact with Azure OpenAI chat models.
-    Args:
-        completion_kwargs: The kwargs for the OpenAI completion function.
-            The following Azure specific properties must be specified:
-            * ``api_type``
-            * ``api_version``
-            * ``api_base``
-            * ``engine``
+    This can also be constructed with :meth:`~OpenAiChat.from_yaml`.
     See Also:
-        * `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
-        * `Quickstart: Get started using GPT-35-Turbo and GPT-4 with Azure OpenAI Service <https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line&pivots=programming-language-python>`_
-    """
-
-    def __post_init__(self) -> None:
-        """Do post init."""
-        _check_mandatory_azure_completion_kwargs(self.completion_kwargs)
 
 
-@dataclass
-class OpenAiCompletion(OpenAiBaseCompletion):
-    """OpenAI (non chat) completion.
-    This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`.
+        * OpenAI API reference: `Create chat completion <https://platform.openai.com/docs/api-reference/chat/create>`_
+        * `Quickstart: Get started generating text using Azure OpenAI Service <https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line&pivots=programming-language-python>`_
     Args:
-        completion_kwargs: The kwargs for the OpenAI completion function.
-    See Also:
-        `Create completion <https://platform.openai.com/docs/api-reference/completions/create>`_
+        api_key: The OpenAI API key.
+        model: The OpenAI model name.
+        api_version: The OpenAI API version.
+            A common value for this is ``2023-05-15``.
+        azure_endpoint: The Azure endpoint.
     """
 
-    def _completion(
-        self, prompt: Union[str, List[Dict[str, str]]], completion_kwargs_for_this_call: Mapping[str, Any]
-    ) -> OpenAIObject:
-        """Call to the OpenAI (not chat) completion."""
-        open_ai_object: OpenAIObject = Completion.create(
-            prompt=prompt,
-            **completion_kwargs_for_this_call,
-        )
-        return open_ai_object
-
-
-@dataclass
-class OpenAiAzureCompletion(OpenAiCompletion):
-    """OpenAI Azure (non chat) completion.
-    This also be constructed with :meth:`OpenAiBaseCompletion.from_yaml`.
-    Args:
-        completion_kwargs: The kwargs for the OpenAI completion function.
-            The following Azure specific properties must be specified:
-            * ``api_type``
-            * ``api_version``
-            * ``api_base``
-            * ``engine``
-    See Also:
-        * `Create completion <https://platform.openai.com/docs/api-reference/completions/create>`_
-        * `Quickstart: Get started generating text using Azure OpenAI Service <https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line&pivots=programming-language-python>`_
-    """
+    api_version: str
+    azure_endpoint: str
 
     def __post_init__(self) -> None:
         """Do post init."""
-        _check_mandatory_azure_completion_kwargs(self.completion_kwargs)
+        self.client = AzureOpenAI(
+            api_key=self.api_key,
+            api_version=self.api_version,
+            azure_endpoint=self.azure_endpoint,
+        )
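
For orientation, here is a minimal usage sketch of the classes added in this diff. The API key, model name, endpoint, and prompt values are placeholders and not part of the commit.

from dataclasses import asdict

from mltb2.openai import OpenAiAzureChat, OpenAiChat

# Placeholder key and model name - substitute real values.
chat = OpenAiChat(api_key="sk-...", model="gpt-3.5-turbo")

# A plain string becomes a single user message; alternatively a list of
# {"role": ..., "content": ...} dicts is validated and passed through.
result = chat("Say hello in German.", completion_kwargs={"temperature": 0.2})
print(result.content, result.finish_reason, result.total_tokens)
print(asdict(result))  # dict view, as suggested in the OpenAiChatResult docstring

# The Azure variant additionally needs api_version and azure_endpoint.
azure_chat = OpenAiAzureChat(
    api_key="...",  # placeholder
    model="my-deployment-name",  # placeholder deployment/model name
    api_version="2023-05-15",
    azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder
)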
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -69,7 +69,7 @@ torch = {version = "!=2.0.1,!=2.1.0", optional = true} # some versions have poetry issues
 transformers = {version = "*", optional = true}
 tiktoken = {version = "*", optional = true}
 safetensors = {version = "!=0.3.2", optional = true} # version 0.3.2 has poetry issues
-openai = {version = "^0", optional = true}
+openai = {version = "^1", optional = true}
 pyyaml = {version = "*", optional = true}
 pandas = {version = "*", optional = true}
 beautifulsoup4 = {version = "*", optional = true}
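
The from_yaml constructor shown in the mltb2/openai.py diff above simply safe_loads the file and unpacks the mapping into the initializer, so the YAML keys must match the dataclass fields. A small sketch, with placeholder values and a temporary file standing in for a real config:

import tempfile

import yaml

from mltb2.openai import OpenAiChat

# Keys must match the OpenAiChat init fields; values are placeholders.
# For OpenAiAzureChat the config additionally needs api_version and azure_endpoint.
config = {"api_key": "sk-...", "model": "gpt-4"}

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
    yaml.safe_dump(config, tmp)
    yaml_path = tmp.name

chat = OpenAiChat.from_yaml(yaml_path)  # equivalent to OpenAiChat(**config)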