From 026d97c41566c3bb2e1ae6d1cd842c955c43b0c1 Mon Sep 17 00:00:00 2001
From: Tonic
Date: Tue, 6 Aug 2024 14:36:17 +0200
Subject: [PATCH 1/4] add githubllm client , wrapper and test

---
 autogen/oai/github.py      | 242 +++++++++++++++++++++++++++++++++++++
 test/oai/test_githubllm.py | 121 +++++++++++++++++++
 2 files changed, 363 insertions(+)
 create mode 100644 autogen/oai/github.py
 create mode 100644 test/oai/test_githubllm.py

diff --git a/autogen/oai/github.py b/autogen/oai/github.py
new file mode 100644
index 00000000000..31c0d07b1e7
--- /dev/null
+++ b/autogen/oai/github.py
@@ -0,0 +1,242 @@
+'''Create a Github LLM Client with Azure Fallback.
+
+# Usage example:
+if __name__ == "__main__":
+    config = {
+        "model": "gpt-4o",
+        "system_prompt": "You are a knowledgeable history teacher.",
+        "use_azure_fallback": True
+    }
+
+    wrapper = GithubWrapper(config_list=[config])
+
+    response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}])
+    print(wrapper.message_retrieval(response)[0])
+
+    conversation = [
+        {"role": "user", "content": "Tell me about the French Revolution."},
+        {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."},
+        {"role": "user", "content": "What were the main causes?"}
+    ]
+
+    response = wrapper.create(messages=conversation)
+    print(wrapper.message_retrieval(response)[0])
+'''
+
+from __future__ import annotations
+
+import os
+import logging
+import time
+import json
+from typing import Any, Dict, List, Optional, Union, Tuple
+
+import requests
+from openai.types.chat import ChatCompletion, ChatCompletionMessage
+from openai.types.chat.chat_completion import Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+from autogen.cache import Cache
+
+logger = logging.getLogger(__name__)
+
+class GithubClient:
+    """GitHub LLM Client with Azure Fallback"""
+
+    SUPPORTED_MODELS = [
+        "AI21-Jamba-Instruct",
+        "cohere-command-r",
+        "cohere-command-r-plus",
+        "cohere-embed-v3-english",
+        "cohere-embed-v3-multilingual",
+        "meta-llama-3-70b-instruct",
+        "meta-llama-3-8b-instruct",
+        "meta-llama-3.1-405b-instruct",
+        "meta-llama-3.1-70b-instruct",
+        "meta-llama-3.1-8b-instruct",
+        "mistral-large",
+        "mistral-large-2407",
+        "mistral-nemo",
+        "mistral-small",
+        "gpt-4o",
+        "gpt-4o-mini",
+        "phi-3-medium-instruct-128k",
+        "phi-3-medium-instruct-4k",
+        "phi-3-mini-instruct-128k",
+        "phi-3-mini-instruct-4k",
+        "phi-3-small-instruct-128k",
+        "phi-3-small-instruct-8k"
+    ]
+
+    def __init__(self, **kwargs):
+        self.github_endpoint_url = "https://models.inference.ai.azure.com/chat/completions"
+        self.model = kwargs.get("model")
+        self.system_prompt = kwargs.get("system_prompt", "You are a helpful assistant.")
+        self.use_azure_fallback = kwargs.get("use_azure_fallback", True)
+        self.rate_limit_reset_time = 0
+        self.request_count = 0
+        self.max_requests_per_minute = 15
+        self.max_requests_per_day = 150
+
+        self.github_token = os.environ.get("GITHUB_TOKEN")
+        self.azure_api_key = os.environ.get("AZURE_API_KEY")
+
+        if not self.github_token:
+            raise ValueError("GITHUB_TOKEN environment variable is not set.")
+        if self.use_azure_fallback and not self.azure_api_key:
+            raise ValueError("AZURE_API_KEY environment variable is not set.")
+
+        if self.model.lower() not in [model.lower() for model in self.SUPPORTED_MODELS]:
+            raise ValueError(f"Model {self.model} is not supported. Please choose from {self.SUPPORTED_MODELS}")
+
+    def message_retrieval(self, response: ChatCompletion) -> List[str]:
+        """Retrieve the messages from the response."""
+        return [choice.message.content for choice in response.choices]
+
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
+        """Create a completion for a given config."""
+        messages = params.get("messages", [])
+
+        if "system" not in [m["role"] for m in messages]:
+            messages.insert(0, {"role": "system", "content": self.system_prompt})
+
+        data = {
+            "messages": messages,
+            "model": self.model,
+            **params
+        }
+
+        if self._check_rate_limit():
+            try:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {self.github_token}"
+                }
+
+                response = self._call_api(self.github_endpoint_url, headers, data)
+                self._increment_request_count()
+                return self._process_response(response)
+            except (requests.exceptions.RequestException, ValueError) as e:
+                logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.")
+
+        if self.use_azure_fallback:
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.azure_api_key}"
+            }
+
+            response = self._call_api(self.github_endpoint_url, headers, data)
+            return self._process_response(response)
+        else:
+            raise ValueError("Rate limit reached and Azure fallback is disabled.")
+
+    def _check_rate_limit(self) -> bool:
+        """Check if the rate limit has been reached."""
+        current_time = time.time()
+        if current_time < self.rate_limit_reset_time:
+            return False
+        if self.request_count >= self.max_requests_per_minute:
+            self.rate_limit_reset_time = current_time + 60
+            self.request_count = 0
+            return False
+        return True
+
+    def _increment_request_count(self):
+        """Increment the request count."""
+        self.request_count += 1
+
+    def _call_api(self, endpoint_url: str, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make an API call to either GitHub or Azure."""
+        response = requests.post(endpoint_url, headers=headers, json=data)
+        response.raise_for_status()
+        return response.json()
+
+    def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
+        """Process the API response and return a ChatCompletion object."""
+        choices = [
+            Choice(
+                index=i,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    content=choice["message"]["content"]
+                ),
+                finish_reason=choice.get("finish_reason")
+            )
+            for i, choice in enumerate(response_data["choices"])
+        ]
+
+        usage = CompletionUsage(
+            prompt_tokens=response_data["usage"]["prompt_tokens"],
+            completion_tokens=response_data["usage"]["completion_tokens"],
+            total_tokens=response_data["usage"]["total_tokens"]
+        )
+
+        return ChatCompletion(
+            id=response_data["id"],
+            model=response_data["model"],
+            created=response_data["created"],
+            object="chat.completion",
+            choices=choices,
+            usage=usage
+        )
+
+    def cost(self, response: ChatCompletion) -> float:
+        """Calculate the cost of the response."""
+        # Pass
+        return 0.0  # Placeholder
+
+    @staticmethod
+    def get_usage(response: ChatCompletion) -> Dict:
+        return {
+            "prompt_tokens": response.usage.prompt_tokens,
+            "completion_tokens": response.usage.completion_tokens,
+            "total_tokens": response.usage.total_tokens,
+            "cost": response.cost if hasattr(response, "cost") else 0,
+            "model": response.model,
+        }
+
+class GithubWrapper:
+    """Wrapper for GitHub LLM Client"""
+
+    def __init__(self, config_list: Optional[List[Dict[str, Any]]] = None, **kwargs):
+        self._clients = []
+        self._config_list = []
+
+        if config_list:
+            for config in config_list:
+                self._register_client(config)
+                self._config_list.append(config)
+        else:
+            self._register_client(kwargs)
+            self._config_list = [kwargs]
+
+    def _register_client(self, config: Dict[str, Any]):
+        client = GithubClient(**config)
+        self._clients.append(client)
+
+    def create(self, **params: Any) -> ChatCompletion:
+        """Create a completion using available clients."""
+        for i, client in enumerate(self._clients):
+            try:
+                response = client.create(params)
+                response.config_id = i
+                return response
+            except Exception as e:
+                logger.warning(f"Error with client {i}: {str(e)}")
+                if i == len(self._clients) - 1:
+                    raise
+
+    def message_retrieval(self, response: ChatCompletion) -> List[str]:
+        """Retrieve messages from the response."""
+        return self._clients[response.config_id].message_retrieval(response)
+
+    def cost(self, response: ChatCompletion) -> float:
+        """Calculate the cost of the response."""
+        return self._clients[response.config_id].cost(response)
+
+    @staticmethod
+    def get_usage(response: ChatCompletion) -> Dict:
+        """Get usage information from the response."""
+        return GithubClient.get_usage(response)
diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py
new file mode 100644
index 00000000000..fd76c68446d
--- /dev/null
+++ b/test/oai/test_githubllm.py
@@ -0,0 +1,121 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from autogen.oai.github import GithubClient, GithubWrapper
+
+@pytest.fixture
+def github_client():
+    with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
+        return GithubClient(model="gpt-4o", system_prompt="Test prompt")
+
+@pytest.fixture
+def github_wrapper():
+    with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
+        config = {
+            "model": "gpt-4o",
+            "system_prompt": "Test prompt",
+            "use_azure_fallback": True
+        }
+        return GithubWrapper(config_list=[config])
+
+def test_github_client_initialization(github_client):
+    assert github_client.model == "gpt-4o"
+    assert github_client.system_prompt == "Test prompt"
+    assert github_client.use_azure_fallback == True
+
+def test_github_client_unsupported_model():
+    with pytest.raises(ValueError):
+        with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
+            GithubClient(model="unsupported-model")
+
+@patch('requests.post')
+def test_github_client_create(mock_post, github_client):
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "id": "test_id",
+        "model": "gpt-4o",
+        "created": 1234567890,
+        "choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}],
+        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
+    }
+    mock_post.return_value = mock_response
+
+    params = {"messages": [{"role": "user", "content": "Test message"}]}
+    response = github_client.create(params)
+
+    assert response.id == "test_id"
+    assert response.model == "gpt-4o"
+    assert len(response.choices) == 1
+    assert response.choices[0].message.content == "Test response"
+
+def test_github_client_message_retrieval(github_client):
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(message=MagicMock(content="Response 1")),
+        MagicMock(message=MagicMock(content="Response 2"))
+    ]
+
+    messages = github_client.message_retrieval(mock_response)
+    assert messages == ["Response 1", "Response 2"]
+
+def test_github_client_cost(github_client):
+    mock_response = MagicMock()
+    cost = github_client.cost(mock_response)
+    assert cost == 0.0  # Assuming the placeholder implementation
+
+def test_github_client_get_usage(github_client):
+    mock_response = MagicMock()
+    mock_response.usage.prompt_tokens = 10
+    mock_response.usage.completion_tokens = 20
+    mock_response.usage.total_tokens = 30
+    mock_response.model = "gpt-4o"
+
+    usage = github_client.get_usage(mock_response)
+    assert usage["prompt_tokens"] == 10
+    assert usage["completion_tokens"] == 20
+    assert usage["total_tokens"] == 30
+    assert usage["model"] == "gpt-4o"
+
+@patch('your_module.GithubClient.create')
+def test_github_wrapper_create(mock_create, github_wrapper):
+    mock_response = MagicMock()
+    mock_create.return_value = mock_response
+
+    params = {"messages": [{"role": "user", "content": "Test message"}]}
+    response = github_wrapper.create(**params)
+
+    assert response == mock_response
+    assert hasattr(response, 'config_id')
+    mock_create.assert_called_once_with(params)
+
+def test_github_wrapper_message_retrieval(github_wrapper):
+    mock_response = MagicMock()
+    mock_response.config_id = 0
+
+    with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval:
+        mock_retrieval.return_value = ["Test message"]
+        messages = github_wrapper.message_retrieval(mock_response)
+
+    assert messages == ["Test message"]
+
+def test_github_wrapper_cost(github_wrapper):
+    mock_response = MagicMock()
+    mock_response.config_id = 0
+
+    with patch.object(github_wrapper._clients[0], 'cost') as mock_cost:
+        mock_cost.return_value = 0.05
+        cost = github_wrapper.cost(mock_response)
+
+    assert cost == 0.05
+
+def test_github_wrapper_get_usage(github_wrapper):
+    mock_response = MagicMock()
+    mock_response.usage.prompt_tokens = 10
+    mock_response.usage.completion_tokens = 20
+    mock_response.usage.total_tokens = 30
+    mock_response.model = "gpt-4o"
+
+    usage = github_wrapper.get_usage(mock_response)
+    assert usage["prompt_tokens"] == 10
+    assert usage["completion_tokens"] == 20
+    assert usage["total_tokens"] == 30
+    assert usage["model"] == "gpt-4o"
\ No newline at end of file

From 1448b9c35433ec3ecfaced7b42e2a9e2a34881cc Mon Sep 17 00:00:00 2001
From: Tonic
Date: Tue, 6 Aug 2024 15:48:08 +0200
Subject: [PATCH 2/4] add test and ruff lint improvements

---
 test/oai/test_githubllm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py
index fd76c68446d..0a3b879f1dc 100644
--- a/test/oai/test_githubllm.py
+++ b/test/oai/test_githubllm.py
@@ -75,7 +75,7 @@ def test_github_client_get_usage(github_client):
     assert usage["total_tokens"] == 30
     assert usage["model"] == "gpt-4o"
 
-@patch('your_module.GithubClient.create')
+@patch('autogen.oai.github.GithubClient.create')
 def test_github_wrapper_create(mock_create, github_wrapper):
     mock_response = MagicMock()
     mock_create.return_value = mock_response
@@ -118,4 +118,4 @@ def test_github_wrapper_get_usage(github_wrapper):
     assert usage["prompt_tokens"] == 10
     assert usage["completion_tokens"] == 20
     assert usage["total_tokens"] == 30
-    assert usage["model"] == "gpt-4o"
\ No newline at end of file
+    assert usage["model"] == "gpt-4o"

From 100298abab7b9e37a6c9022ecb20b92a180c1ca6 Mon Sep 17 00:00:00 2001
From: Tonic
Date: Tue, 6 Aug 2024 17:32:11 +0200
Subject: [PATCH 3/4] fix some kind of linting

---
 autogen/oai/github.py      | 51 ++++++++++++++------------------------
 test/oai/test_githubllm.py | 49 ++++++++++++++++++++----------------
 2 files changed, 47 insertions(+), 53 deletions(-)

diff --git a/autogen/oai/github.py b/autogen/oai/github.py
index 31c0d07b1e7..4dfbdc7e94c 100644
--- a/autogen/oai/github.py
+++ b/autogen/oai/github.py
@@ -1,4 +1,4 @@
-'''Create a Github LLM Client with Azure Fallback.
+"""Create a Github LLM Client with Azure Fallback.
 
 # Usage example:
 if __name__ == "__main__":
@@ -7,9 +7,9 @@
         "system_prompt": "You are a knowledgeable history teacher.",
         "use_azure_fallback": True
     }
-    
+
     wrapper = GithubWrapper(config_list=[config])
-    
+
     response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}])
     print(wrapper.message_retrieval(response)[0])
 
@@ -18,29 +18,29 @@
         {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."},
         {"role": "user", "content": "What were the main causes?"}
     ]
-    
+
     response = wrapper.create(messages=conversation)
     print(wrapper.message_retrieval(response)[0])
-'''
+"""
 
 from __future__ import annotations
-
-import os
+import json
 import logging
+import os
 import time
-import json
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import requests
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.completion_usage import CompletionUsage
 
-from autogen.oai.client_utils import should_hide_tools, validate_parameter
 from autogen.cache import Cache
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
 
 logger = logging.getLogger(__name__)
 
+
 class GithubClient:
     """GitHub LLM Client with Azure Fallback"""
 
@@ -66,7 +66,7 @@ class GithubClient:
         "phi-3-mini-instruct-128k",
         "phi-3-mini-instruct-4k",
         "phi-3-small-instruct-128k",
-        "phi-3-small-instruct-8k"
+        "phi-3-small-instruct-8k",
     ]
 
     def __init__(self, **kwargs):
@@ -97,22 +97,15 @@ def message_retrieval(self, response: ChatCompletion) -> List[str]:
     def create(self, params: Dict[str, Any]) -> ChatCompletion:
         """Create a completion for a given config."""
         messages = params.get("messages", [])
-        
+
         if "system" not in [m["role"] for m in messages]:
             messages.insert(0, {"role": "system", "content": self.system_prompt})
 
-        data = {
-            "messages": messages,
-            "model": self.model,
-            **params
-        }
+        data = {"messages": messages, "model": self.model, **params}
 
         if self._check_rate_limit():
             try:
-                headers = {
-                    "Content-Type": "application/json",
-                    "Authorization": f"Bearer {self.github_token}"
-                }
+                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"}
 
                 response = self._call_api(self.github_endpoint_url, headers, data)
                 self._increment_request_count()
@@ -121,10 +114,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
                 logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.")
 
         if self.use_azure_fallback:
-            headers = {
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.azure_api_key}"
-            }
+            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"}
 
             response = self._call_api(self.github_endpoint_url, headers, data)
             return self._process_response(response)
@@ -157,11 +147,8 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
         choices = [
             Choice(
                 index=i,
-                message=ChatCompletionMessage(
-                    role="assistant",
-                    content=choice["message"]["content"]
-                ),
-                finish_reason=choice.get("finish_reason")
+                message=ChatCompletionMessage(role="assistant", content=choice["message"]["content"]),
+                finish_reason=choice.get("finish_reason"),
             )
             for i, choice in enumerate(response_data["choices"])
         ]
@@ -169,7 +156,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
         usage = CompletionUsage(
             prompt_tokens=response_data["usage"]["prompt_tokens"],
             completion_tokens=response_data["usage"]["completion_tokens"],
-            total_tokens=response_data["usage"]["total_tokens"]
+            total_tokens=response_data["usage"]["total_tokens"],
         )
 
         return ChatCompletion(
@@ -178,7 +165,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
             created=response_data["created"],
             object="chat.completion",
             choices=choices,
-            usage=usage
+            usage=usage,
         )
 
     def cost(self, response: ChatCompletion) -> float:
diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py
index 0a3b879f1dc..f1f65fa8e36 100644
--- a/test/oai/test_githubllm.py
+++ b/test/oai/test_githubllm.py
@@ -1,33 +1,34 @@
+from unittest.mock import MagicMock, patch
+
 import pytest
-from unittest.mock import patch, MagicMock
+
 from autogen.oai.github import GithubClient, GithubWrapper
 
+
 @pytest.fixture
 def github_client():
-    with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
+    with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
         return GithubClient(model="gpt-4o", system_prompt="Test prompt")
 
+
 @pytest.fixture
 def github_wrapper():
-    with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
-        config = {
-            "model": "gpt-4o",
-            "system_prompt": "Test prompt",
-            "use_azure_fallback": True
-        }
+    with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
+        config = {"model": "gpt-4o", "system_prompt": "Test prompt", "use_azure_fallback": True}
         return GithubWrapper(config_list=[config])
 
+
 def test_github_client_initialization(github_client):
     assert github_client.model == "gpt-4o"
     assert github_client.system_prompt == "Test prompt"
     assert github_client.use_azure_fallback == True
 
+
 def test_github_client_unsupported_model():
     with pytest.raises(ValueError):
-        with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
+        with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
             GithubClient(model="unsupported-model")
 
-@patch('requests.post')
+
+@patch("requests.post")
 def test_github_client_create(mock_post, github_client):
     mock_response = MagicMock()
     mock_response.json.return_value = {
@@ -35,7 +36,7 @@ def test_github_client_create(mock_post, github_client):
         "model": "gpt-4o",
         "created": 1234567890,
         "choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}],
-        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
+        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
     }
     mock_post.return_value = mock_response
 
@@ -47,21 +48,24 @@ def test_github_client_create(mock_post, github_client):
     assert len(response.choices) == 1
     assert response.choices[0].message.content == "Test response"
 
+
 def test_github_client_message_retrieval(github_client):
     mock_response = MagicMock()
     mock_response.choices = [
         MagicMock(message=MagicMock(content="Response 1")),
-        MagicMock(message=MagicMock(content="Response 2"))
+        MagicMock(message=MagicMock(content="Response 2")),
     ]
-    
+
     messages = github_client.message_retrieval(mock_response)
     assert messages == ["Response 1", "Response 2"]
 
+
 def test_github_client_cost(github_client):
     mock_response = MagicMock()
     cost = github_client.cost(mock_response)
     assert cost == 0.0  # Assuming the placeholder implementation
 
+
 def test_github_client_get_usage(github_client):
     mock_response = MagicMock()
     mock_response.usage.prompt_tokens = 10
@@ -75,7 +79,8 @@ def test_github_client_get_usage(github_client):
     assert usage["total_tokens"] == 30
     assert usage["model"] == "gpt-4o"
 
-@patch('autogen.oai.github.GithubClient.create')
+
+@patch("autogen.oai.github.GithubClient.create")
 def test_github_wrapper_create(mock_create, github_wrapper):
     mock_response = MagicMock()
     mock_create.return_value = mock_response
@@ -84,29 +89,31 @@ def test_github_wrapper_create(mock_create, github_wrapper):
     response = github_wrapper.create(**params)
 
     assert response == mock_response
-    assert hasattr(response, 'config_id')
+    assert hasattr(response, "config_id")
     mock_create.assert_called_once_with(params)
 
+
 def test_github_wrapper_message_retrieval(github_wrapper):
     mock_response = MagicMock()
     mock_response.config_id = 0
-
-    with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval:
+
+
+    with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval:
         mock_retrieval.return_value = ["Test message"]
         messages = github_wrapper.message_retrieval(mock_response)
-
+
     assert messages == ["Test message"]
 
 def test_github_wrapper_cost(github_wrapper):
     mock_response = MagicMock()
     mock_response.config_id = 0
-
-    with patch.object(github_wrapper._clients[0], 'cost') as mock_cost:
+
+    with patch.object(github_wrapper._clients[0], "cost") as mock_cost:
         mock_cost.return_value = 0.05
         cost = github_wrapper.cost(mock_response)
-
+
     assert cost == 0.05
 
+
 def test_github_wrapper_get_usage(github_wrapper):
     mock_response = MagicMock()
     mock_response.usage.prompt_tokens = 10

From 77e599a44b46dbd01b4d6a85f088739742673870 Mon Sep 17 00:00:00 2001
From: Tonic
Date: Tue, 6 Aug 2024 19:00:21 +0200
Subject: [PATCH 4/4] fix test

---
 test/oai/test_githubllm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py
index f1f65fa8e36..6b3c918a8ce 100644
--- a/test/oai/test_githubllm.py
+++ b/test/oai/test_githubllm.py
@@ -19,7 +19,7 @@ def github_wrapper():
 def test_github_client_initialization(github_client):
     assert github_client.model == "gpt-4o"
     assert github_client.system_prompt == "Test prompt"
-    assert github_client.use_azure_fallback == True
+    assert github_client.use_azure_fallback
 
 
 def test_github_client_unsupported_model():