From 026d97c41566c3bb2e1ae6d1cd842c955c43b0c1 Mon Sep 17 00:00:00 2001 From: Tonic Date: Tue, 6 Aug 2024 14:36:17 +0200 Subject: [PATCH 1/7] add githubllm client , wrapper and test --- autogen/oai/github.py | 242 +++++++++++++++++++++++++++++++++++++ test/oai/test_githubllm.py | 121 +++++++++++++++++++ 2 files changed, 363 insertions(+) create mode 100644 autogen/oai/github.py create mode 100644 test/oai/test_githubllm.py diff --git a/autogen/oai/github.py b/autogen/oai/github.py new file mode 100644 index 00000000000..31c0d07b1e7 --- /dev/null +++ b/autogen/oai/github.py @@ -0,0 +1,242 @@ +'''Create a Github LLM Client with Azure Fallback. + +# Usage example: +if __name__ == "__main__": + config = { + "model": "gpt-4o", + "system_prompt": "You are a knowledgeable history teacher.", + "use_azure_fallback": True + } + + wrapper = GithubWrapper(config_list=[config]) + + response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}]) + print(wrapper.message_retrieval(response)[0]) + + conversation = [ + {"role": "user", "content": "Tell me about the French Revolution."}, + {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."}, + {"role": "user", "content": "What were the main causes?"} + ] + + response = wrapper.create(messages=conversation) + print(wrapper.message_retrieval(response)[0]) +''' + +from __future__ import annotations + +import os +import logging +import time +import json +from typing import Any, Dict, List, Optional, Union, Tuple + +import requests +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from autogen.oai.client_utils import should_hide_tools, validate_parameter +from autogen.cache import Cache + +logger = logging.getLogger(__name__) + +class GithubClient: + """GitHub LLM Client with Azure Fallback""" + + SUPPORTED_MODELS = [ + "AI21-Jamba-Instruct", + "cohere-command-r", + "cohere-command-r-plus", + "cohere-embed-v3-english", + "cohere-embed-v3-multilingual", + "meta-llama-3-70b-instruct", + "meta-llama-3-8b-instruct", + "meta-llama-3.1-405b-instruct", + "meta-llama-3.1-70b-instruct", + "meta-llama-3.1-8b-instruct", + "mistral-large", + "mistral-large-2407", + "mistral-nemo", + "mistral-small", + "gpt-4o", + "gpt-4o-mini", + "phi-3-medium-instruct-128k", + "phi-3-medium-instruct-4k", + "phi-3-mini-instruct-128k", + "phi-3-mini-instruct-4k", + "phi-3-small-instruct-128k", + "phi-3-small-instruct-8k" + ] + + def __init__(self, **kwargs): + self.github_endpoint_url = "https://models.inference.ai.azure.com/chat/completions" + self.model = kwargs.get("model") + self.system_prompt = kwargs.get("system_prompt", "You are a helpful assistant.") + self.use_azure_fallback = kwargs.get("use_azure_fallback", True) + self.rate_limit_reset_time = 0 + self.request_count = 0 + self.max_requests_per_minute = 15 + self.max_requests_per_day = 150 + + self.github_token = os.environ.get("GITHUB_TOKEN") + self.azure_api_key = os.environ.get("AZURE_API_KEY") + + if not self.github_token: + raise ValueError("GITHUB_TOKEN environment variable is not set.") + if self.use_azure_fallback and not self.azure_api_key: + raise ValueError("AZURE_API_KEY environment variable is not set.") + + if self.model.lower() not in [model.lower() for model in self.SUPPORTED_MODELS]: + raise ValueError(f"Model {self.model} is not supported. Please choose from {self.SUPPORTED_MODELS}") + + def message_retrieval(self, response: ChatCompletion) -> List[str]: + """Retrieve the messages from the response.""" + return [choice.message.content for choice in response.choices] + + def create(self, params: Dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config.""" + messages = params.get("messages", []) + + if "system" not in [m["role"] for m in messages]: + messages.insert(0, {"role": "system", "content": self.system_prompt}) + + data = { + "messages": messages, + "model": self.model, + **params + } + + if self._check_rate_limit(): + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.github_token}" + } + + response = self._call_api(self.github_endpoint_url, headers, data) + self._increment_request_count() + return self._process_response(response) + except (requests.exceptions.RequestException, ValueError) as e: + logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.") + + if self.use_azure_fallback: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.azure_api_key}" + } + + response = self._call_api(self.github_endpoint_url, headers, data) + return self._process_response(response) + else: + raise ValueError("Rate limit reached and Azure fallback is disabled.") + + def _check_rate_limit(self) -> bool: + """Check if the rate limit has been reached.""" + current_time = time.time() + if current_time < self.rate_limit_reset_time: + return False + if self.request_count >= self.max_requests_per_minute: + self.rate_limit_reset_time = current_time + 60 + self.request_count = 0 + return False + return True + + def _increment_request_count(self): + """Increment the request count.""" + self.request_count += 1 + + def _call_api(self, endpoint_url: str, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]: + """Make an API call to either GitHub or Azure.""" + response = requests.post(endpoint_url, headers=headers, json=data) + response.raise_for_status() + return response.json() + + def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: + """Process the API response and return a ChatCompletion object.""" + choices = [ + Choice( + index=i, + message=ChatCompletionMessage( + role="assistant", + content=choice["message"]["content"] + ), + finish_reason=choice.get("finish_reason") + ) + for i, choice in enumerate(response_data["choices"]) + ] + + usage = CompletionUsage( + prompt_tokens=response_data["usage"]["prompt_tokens"], + completion_tokens=response_data["usage"]["completion_tokens"], + total_tokens=response_data["usage"]["total_tokens"] + ) + + return ChatCompletion( + id=response_data["id"], + model=response_data["model"], + created=response_data["created"], + object="chat.completion", + choices=choices, + usage=usage + ) + + def cost(self, response: ChatCompletion) -> float: + """Calculate the cost of the response.""" + # Pass + return 0.0 # Placeholder + + @staticmethod + def get_usage(response: ChatCompletion) -> Dict: + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost if hasattr(response, "cost") else 0, + "model": response.model, + } + +class GithubWrapper: + """Wrapper for GitHub LLM Client""" + + def __init__(self, config_list: Optional[List[Dict[str, Any]]] = None, **kwargs): + self._clients = [] + self._config_list = [] + + if config_list: + for config in config_list: + self._register_client(config) + self._config_list.append(config) + else: + self._register_client(kwargs) + self._config_list = [kwargs] + + def _register_client(self, config: Dict[str, Any]): + client = GithubClient(**config) + self._clients.append(client) + + def create(self, **params: Any) -> ChatCompletion: + """Create a completion using available clients.""" + for i, client in enumerate(self._clients): + try: + response = client.create(params) + response.config_id = i + return response + except Exception as e: + logger.warning(f"Error with client {i}: {str(e)}") + if i == len(self._clients) - 1: + raise + + def message_retrieval(self, response: ChatCompletion) -> List[str]: + """Retrieve messages from the response.""" + return self._clients[response.config_id].message_retrieval(response) + + def cost(self, response: ChatCompletion) -> float: + """Calculate the cost of the response.""" + return self._clients[response.config_id].cost(response) + + @staticmethod + def get_usage(response: ChatCompletion) -> Dict: + """Get usage information from the response.""" + return GithubClient.get_usage(response) + diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py new file mode 100644 index 00000000000..fd76c68446d --- /dev/null +++ b/test/oai/test_githubllm.py @@ -0,0 +1,121 @@ +import pytest +from unittest.mock import patch, MagicMock +from autogen.oai.github import GithubClient, GithubWrapper + +@pytest.fixture +def github_client(): + with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + return GithubClient(model="gpt-4o", system_prompt="Test prompt") + +@pytest.fixture +def github_wrapper(): + with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + config = { + "model": "gpt-4o", + "system_prompt": "Test prompt", + "use_azure_fallback": True + } + return GithubWrapper(config_list=[config]) + +def test_github_client_initialization(github_client): + assert github_client.model == "gpt-4o" + assert github_client.system_prompt == "Test prompt" + assert github_client.use_azure_fallback == True + +def test_github_client_unsupported_model(): + with pytest.raises(ValueError): + with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + GithubClient(model="unsupported-model") + +@patch('requests.post') +def test_github_client_create(mock_post, github_client): + mock_response = MagicMock() + mock_response.json.return_value = { + "id": "test_id", + "model": "gpt-4o", + "created": 1234567890, + "choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + } + mock_post.return_value = mock_response + + params = {"messages": [{"role": "user", "content": "Test message"}]} + response = github_client.create(params) + + assert response.id == "test_id" + assert response.model == "gpt-4o" + assert len(response.choices) == 1 + assert response.choices[0].message.content == "Test response" + +def test_github_client_message_retrieval(github_client): + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Response 1")), + MagicMock(message=MagicMock(content="Response 2")) + ] + + messages = github_client.message_retrieval(mock_response) + assert messages == ["Response 1", "Response 2"] + +def test_github_client_cost(github_client): + mock_response = MagicMock() + cost = github_client.cost(mock_response) + assert cost == 0.0 # Assuming the placeholder implementation + +def test_github_client_get_usage(github_client): + mock_response = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.usage.total_tokens = 30 + mock_response.model = "gpt-4o" + + usage = github_client.get_usage(mock_response) + assert usage["prompt_tokens"] == 10 + assert usage["completion_tokens"] == 20 + assert usage["total_tokens"] == 30 + assert usage["model"] == "gpt-4o" + +@patch('your_module.GithubClient.create') +def test_github_wrapper_create(mock_create, github_wrapper): + mock_response = MagicMock() + mock_create.return_value = mock_response + + params = {"messages": [{"role": "user", "content": "Test message"}]} + response = github_wrapper.create(**params) + + assert response == mock_response + assert hasattr(response, 'config_id') + mock_create.assert_called_once_with(params) + +def test_github_wrapper_message_retrieval(github_wrapper): + mock_response = MagicMock() + mock_response.config_id = 0 + + with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval: + mock_retrieval.return_value = ["Test message"] + messages = github_wrapper.message_retrieval(mock_response) + + assert messages == ["Test message"] + +def test_github_wrapper_cost(github_wrapper): + mock_response = MagicMock() + mock_response.config_id = 0 + + with patch.object(github_wrapper._clients[0], 'cost') as mock_cost: + mock_cost.return_value = 0.05 + cost = github_wrapper.cost(mock_response) + + assert cost == 0.05 + +def test_github_wrapper_get_usage(github_wrapper): + mock_response = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 20 + mock_response.usage.total_tokens = 30 + mock_response.model = "gpt-4o" + + usage = github_wrapper.get_usage(mock_response) + assert usage["prompt_tokens"] == 10 + assert usage["completion_tokens"] == 20 + assert usage["total_tokens"] == 30 + assert usage["model"] == "gpt-4o" \ No newline at end of file From 1448b9c35433ec3ecfaced7b42e2a9e2a34881cc Mon Sep 17 00:00:00 2001 From: Tonic Date: Tue, 6 Aug 2024 15:48:08 +0200 Subject: [PATCH 2/7] add test and ruff lint improvements --- test/oai/test_githubllm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py index fd76c68446d..0a3b879f1dc 100644 --- a/test/oai/test_githubllm.py +++ b/test/oai/test_githubllm.py @@ -75,7 +75,7 @@ def test_github_client_get_usage(github_client): assert usage["total_tokens"] == 30 assert usage["model"] == "gpt-4o" -@patch('your_module.GithubClient.create') +@patch('autogen.oai.github.GithubClient.create') def test_github_wrapper_create(mock_create, github_wrapper): mock_response = MagicMock() mock_create.return_value = mock_response @@ -118,4 +118,4 @@ def test_github_wrapper_get_usage(github_wrapper): assert usage["prompt_tokens"] == 10 assert usage["completion_tokens"] == 20 assert usage["total_tokens"] == 30 - assert usage["model"] == "gpt-4o" \ No newline at end of file + assert usage["model"] == "gpt-4o" From 100298abab7b9e37a6c9022ecb20b92a180c1ca6 Mon Sep 17 00:00:00 2001 From: Tonic Date: Tue, 6 Aug 2024 17:32:11 +0200 Subject: [PATCH 3/7] fix some kind of linting --- autogen/oai/github.py | 51 ++++++++++++++------------------------ test/oai/test_githubllm.py | 49 ++++++++++++++++++++---------------- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/autogen/oai/github.py b/autogen/oai/github.py index 31c0d07b1e7..4dfbdc7e94c 100644 --- a/autogen/oai/github.py +++ b/autogen/oai/github.py @@ -1,4 +1,4 @@ -'''Create a Github LLM Client with Azure Fallback. +"""Create a Github LLM Client with Azure Fallback. # Usage example: if __name__ == "__main__": @@ -7,9 +7,9 @@ "system_prompt": "You are a knowledgeable history teacher.", "use_azure_fallback": True } - + wrapper = GithubWrapper(config_list=[config]) - + response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}]) print(wrapper.message_retrieval(response)[0]) @@ -18,29 +18,29 @@ {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."}, {"role": "user", "content": "What were the main causes?"} ] - + response = wrapper.create(messages=conversation) print(wrapper.message_retrieval(response)[0]) -''' +""" from __future__ import annotations - -import os +import json import logging +import os import time -import json -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import requests from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice from openai.types.completion_usage import CompletionUsage -from autogen.oai.client_utils import should_hide_tools, validate_parameter from autogen.cache import Cache +from autogen.oai.client_utils import should_hide_tools, validate_parameter logger = logging.getLogger(__name__) + class GithubClient: """GitHub LLM Client with Azure Fallback""" @@ -66,7 +66,7 @@ class GithubClient: "phi-3-mini-instruct-128k", "phi-3-mini-instruct-4k", "phi-3-small-instruct-128k", - "phi-3-small-instruct-8k" + "phi-3-small-instruct-8k", ] def __init__(self, **kwargs): @@ -97,22 +97,15 @@ def message_retrieval(self, response: ChatCompletion) -> List[str]: def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config.""" messages = params.get("messages", []) - + if "system" not in [m["role"] for m in messages]: messages.insert(0, {"role": "system", "content": self.system_prompt}) - data = { - "messages": messages, - "model": self.model, - **params - } + data = {"messages": messages, "model": self.model, **params} if self._check_rate_limit(): try: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.github_token}" - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"} response = self._call_api(self.github_endpoint_url, headers, data) self._increment_request_count() @@ -121,10 +114,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.") if self.use_azure_fallback: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.azure_api_key}" - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"} response = self._call_api(self.github_endpoint_url, headers, data) return self._process_response(response) @@ -157,11 +147,8 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: choices = [ Choice( index=i, - message=ChatCompletionMessage( - role="assistant", - content=choice["message"]["content"] - ), - finish_reason=choice.get("finish_reason") + message=ChatCompletionMessage(role="assistant", content=choice["message"]["content"]), + finish_reason=choice.get("finish_reason"), ) for i, choice in enumerate(response_data["choices"]) ] @@ -169,7 +156,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: usage = CompletionUsage( prompt_tokens=response_data["usage"]["prompt_tokens"], completion_tokens=response_data["usage"]["completion_tokens"], - total_tokens=response_data["usage"]["total_tokens"] + total_tokens=response_data["usage"]["total_tokens"], ) return ChatCompletion( @@ -178,7 +165,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: created=response_data["created"], object="chat.completion", choices=choices, - usage=usage + usage=usage, ) def cost(self, response: ChatCompletion) -> float: diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py index 0a3b879f1dc..f1f65fa8e36 100644 --- a/test/oai/test_githubllm.py +++ b/test/oai/test_githubllm.py @@ -1,33 +1,34 @@ +from unittest.mock import MagicMock, patch import pytest -from unittest.mock import patch, MagicMock + from autogen.oai.github import GithubClient, GithubWrapper @pytest.fixture def github_client(): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): return GithubClient(model="gpt-4o", system_prompt="Test prompt") + @pytest.fixture def github_wrapper(): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): - config = { - "model": "gpt-4o", - "system_prompt": "Test prompt", - "use_azure_fallback": True - } + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): + config = {"model": "gpt-4o", "system_prompt": "Test prompt", "use_azure_fallback": True} return GithubWrapper(config_list=[config]) + def test_github_client_initialization(github_client): assert github_client.model == "gpt-4o" assert github_client.system_prompt == "Test prompt" assert github_client.use_azure_fallback == True + def test_github_client_unsupported_model(): with pytest.raises(ValueError): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): GithubClient(model="unsupported-model") -@patch('requests.post') + +@patch("requests.post") def test_github_client_create(mock_post, github_client): mock_response = MagicMock() mock_response.json.return_value = { @@ -35,7 +36,7 @@ def test_github_client_create(mock_post, github_client): "model": "gpt-4o", "created": 1234567890, "choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}], - "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, } mock_post.return_value = mock_response @@ -47,21 +48,24 @@ def test_github_client_create(mock_post, github_client): assert len(response.choices) == 1 assert response.choices[0].message.content == "Test response" + def test_github_client_message_retrieval(github_client): mock_response = MagicMock() mock_response.choices = [ MagicMock(message=MagicMock(content="Response 1")), - MagicMock(message=MagicMock(content="Response 2")) + MagicMock(message=MagicMock(content="Response 2")), ] - + messages = github_client.message_retrieval(mock_response) assert messages == ["Response 1", "Response 2"] + def test_github_client_cost(github_client): mock_response = MagicMock() cost = github_client.cost(mock_response) assert cost == 0.0 # Assuming the placeholder implementation + def test_github_client_get_usage(github_client): mock_response = MagicMock() mock_response.usage.prompt_tokens = 10 @@ -75,7 +79,8 @@ def test_github_client_get_usage(github_client): assert usage["total_tokens"] == 30 assert usage["model"] == "gpt-4o" -@patch('autogen.oai.github.GithubClient.create') + +@patch("autogen.oai.github.GithubClient.create") def test_github_wrapper_create(mock_create, github_wrapper): mock_response = MagicMock() mock_create.return_value = mock_response @@ -84,29 +89,31 @@ def test_github_wrapper_create(mock_create, github_wrapper): response = github_wrapper.create(**params) assert response == mock_response - assert hasattr(response, 'config_id') + assert hasattr(response, "config_id") mock_create.assert_called_once_with(params) def test_github_wrapper_message_retrieval(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 - - with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval: + + + with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval: mock_retrieval.return_value = ["Test message"] messages = github_wrapper.message_retrieval(mock_response) - + assert messages == ["Test message"] def test_github_wrapper_cost(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 - - with patch.object(github_wrapper._clients[0], 'cost') as mock_cost: + + with patch.object(github_wrapper._clients[0], "cost") as mock_cost: mock_cost.return_value = 0.05 cost = github_wrapper.cost(mock_response) - + assert cost == 0.05 + def test_github_wrapper_get_usage(github_wrapper): mock_response = MagicMock() mock_response.usage.prompt_tokens = 10 From c17dc9708a34726f9d69d63184a27c61d0626873 Mon Sep 17 00:00:00 2001 From: Tonic Date: Tue, 6 Aug 2024 19:19:07 +0200 Subject: [PATCH 4/7] Update test_githubllm.py --- test/oai/test_githubllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py index f1f65fa8e36..6b3c918a8ce 100644 --- a/test/oai/test_githubllm.py +++ b/test/oai/test_githubllm.py @@ -19,7 +19,7 @@ def github_wrapper(): def test_github_client_initialization(github_client): assert github_client.model == "gpt-4o" assert github_client.system_prompt == "Test prompt" - assert github_client.use_azure_fallback == True + assert github_client.use_azure_fallback def test_github_client_unsupported_model(): From 63c2df2e7182a3b515821c484a4ad33223042666 Mon Sep 17 00:00:00 2001 From: Tonic Date: Tue, 6 Aug 2024 19:19:07 +0200 Subject: [PATCH 5/7] Update test_githubllm.py --- autogen/oai/github.py | 3 ++- test/oai/test_githubllm.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/autogen/oai/github.py b/autogen/oai/github.py index 4dfbdc7e94c..9d448e8c39e 100644 --- a/autogen/oai/github.py +++ b/autogen/oai/github.py @@ -24,6 +24,7 @@ """ from __future__ import annotations + import json import logging import os @@ -183,6 +184,7 @@ def get_usage(response: ChatCompletion) -> Dict: "model": response.model, } + class GithubWrapper: """Wrapper for GitHub LLM Client""" @@ -226,4 +228,3 @@ def cost(self, response: ChatCompletion) -> float: def get_usage(response: ChatCompletion) -> Dict: """Get usage information from the response.""" return GithubClient.get_usage(response) - diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py index f1f65fa8e36..7d58561f443 100644 --- a/test/oai/test_githubllm.py +++ b/test/oai/test_githubllm.py @@ -1,8 +1,10 @@ from unittest.mock import MagicMock, patch + import pytest from autogen.oai.github import GithubClient, GithubWrapper + @pytest.fixture def github_client(): with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): @@ -19,7 +21,7 @@ def github_wrapper(): def test_github_client_initialization(github_client): assert github_client.model == "gpt-4o" assert github_client.system_prompt == "Test prompt" - assert github_client.use_azure_fallback == True + assert github_client.use_azure_fallback def test_github_client_unsupported_model(): @@ -92,17 +94,18 @@ def test_github_wrapper_create(mock_create, github_wrapper): assert hasattr(response, "config_id") mock_create.assert_called_once_with(params) + def test_github_wrapper_message_retrieval(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 - with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval: mock_retrieval.return_value = ["Test message"] messages = github_wrapper.message_retrieval(mock_response) assert messages == ["Test message"] + def test_github_wrapper_cost(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 From 28d047ea20074700c5ad83df64202134decba7c0 Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 23 Sep 2024 13:18:51 +0200 Subject: [PATCH 6/7] add a branch --- autogen/oai/{github.py => aiinference.py} | 0 test/oai/{test_githubllm.py => test_aiinference.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename autogen/oai/{github.py => aiinference.py} (100%) rename test/oai/{test_githubllm.py => test_aiinference.py} (100%) diff --git a/autogen/oai/github.py b/autogen/oai/aiinference.py similarity index 100% rename from autogen/oai/github.py rename to autogen/oai/aiinference.py diff --git a/test/oai/test_githubllm.py b/test/oai/test_aiinference.py similarity index 100% rename from test/oai/test_githubllm.py rename to test/oai/test_aiinference.py From 9473bdb3d7bc069a239b75b35f8df77f2f01ab46 Mon Sep 17 00:00:00 2001 From: Tonic Date: Mon, 23 Sep 2024 14:55:06 +0200 Subject: [PATCH 7/7] add azure ai inference api --- autogen/oai/aiinference.py | 175 +++++++++++++++++++------------------ autogen/oai/client.py | 13 +++ 2 files changed, 101 insertions(+), 87 deletions(-) diff --git a/autogen/oai/aiinference.py b/autogen/oai/aiinference.py index 7216998d106..438c74ed608 100644 --- a/autogen/oai/aiinference.py +++ b/autogen/oai/aiinference.py @@ -1,28 +1,3 @@ -"""Create a Github LLM Client with Azure Fallback. - -# Usage example: -if __name__ == "__main__": - config = { - "model": "gpt-4o", - "system_prompt": "You are a knowledgeable history teacher.", - "use_azure_fallback": True - } - - wrapper = GithubWrapper(config_list=[config]) - - response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}]) - print(wrapper.message_retrieval(response)[0]) - - conversation = [ - {"role": "user", "content": "Tell me about the French Revolution."}, - {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."}, - {"role": "user", "content": "What were the main causes?"} - ] - - response = wrapper.create(messages=conversation) - print(wrapper.message_retrieval(response)[0]) -""" - from __future__ import annotations import json @@ -36,14 +11,61 @@ from openai.types.chat.chat_completion import Choice from openai.types.completion_usage import CompletionUsage -from autogen.cache import Cache -from autogen.oai.client_utils import should_hide_tools, validate_parameter +from autogen.oai.client_utils import validate_parameter logger = logging.getLogger(__name__) +class AzureAIInferenceClient: + """Azure AI Inference Client + + This class provides an interface to interact with Azure AI Inference API for natural language processing tasks. + It supports various language models and handles API requests, response processing, and error handling. + + Key Features: + 1. Supports multiple AI models provided by Azure AI Inference. + 2. Handles authentication using API keys. + 3. Provides methods for creating chat completions. + 4. Processes and formats API responses into standardized ChatCompletion objects. + 5. Implements rate limiting and error handling for robust API interactions. + 6. Calculates usage statistics and estimated costs for API calls. + + Usage: + - Initialize the client with the desired model and API key. + - Use the 'create' method to generate chat completions. + - Retrieve messages and usage information from the responses. + + Note: Ensure that the AZURE_API_KEY is set in the environment variables or provided during initialization. + + # Example usage + if __name__ == "__main__": + import os + import autogen + + config_list = [ + { + "model": "gpt-4o", + "api_key": os.getenv("AZURE_API_KEY"), + } + ] -class GithubClient: - """GitHub LLM Client with Azure Fallback""" + assistant = autogen.AssistantAgent( + "assistant", + llm_config={"config_list": config_list, "cache_seed": 42}, + ) + + human = autogen.UserProxyAgent( + "human", + human_input_mode="TERMINATE", + max_consecutive_auto_reply=10, + code_execution_config={"work_dir": "coding"}, + llm_config={"config_list": config_list, "cache_seed": 42}, + ) + + human.initiate_chat( + assistant, + message="Would I be better off deploying multiple models on cloud or at home?", + ) + """ SUPPORTED_MODELS = [ "AI21-Jamba-Instruct", @@ -69,74 +91,54 @@ class GithubClient: ] def __init__(self, **kwargs): - self.github_endpoint_url = "https://models.inference.ai.azure.com/chat/completions" + self.endpoint_url = "https://models.inference.ai.azure.com/chat/completions" self.model = kwargs.get("model") - self.system_prompt = kwargs.get("system_prompt", "You are a helpful assistant.") - self.use_azure_fallback = kwargs.get("use_azure_fallback", True) - self.rate_limit_reset_time = 0 - self.request_count = 0 - self.max_requests_per_minute = 15 - self.max_requests_per_day = 150 - - self.github_token = os.environ.get("GITHUB_TOKEN") - self.azure_api_key = os.environ.get("AZURE_API_KEY") - - if not self.github_token: - raise ValueError("GITHUB_TOKEN environment variable is not set.") - if self.use_azure_fallback and not self.azure_api_key: - raise ValueError("AZURE_API_KEY environment variable is not set.") + self.api_key = kwargs.get("api_key") or os.environ.get("AZURE_API_KEY") + + if not self.api_key: + raise ValueError("AZURE_API_KEY is not set in environment variables or provided in kwargs.") if self.model.lower() not in [model.lower() for model in self.SUPPORTED_MODELS]: raise ValueError(f"Model {self.model} is not supported. Please choose from {self.SUPPORTED_MODELS}") + def load_config(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Load the configuration for the Azure AI Inference client.""" + config = {} + config["model"] = params.get("model", self.model) + config["temperature"] = validate_parameter(params, "temperature", (float, int), False, 1.0, (0.0, 2.0), None) + config["max_tokens"] = validate_parameter(params, "max_tokens", int, False, 4096, (1, None), None) + config["top_p"] = validate_parameter(params, "top_p", (float, int), True, None, (0.0, 1.0), None) + config["stop"] = validate_parameter(params, "stop", (str, list), True, None, None, None) + config["stream"] = validate_parameter(params, "stream", bool, False, False, None, None) + + return config + def message_retrieval(self, response: ChatCompletion) -> List[str]: """Retrieve the messages from the response.""" return [choice.message.content for choice in response.choices] def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config.""" + config = self.load_config(params) messages = params.get("messages", []) - if "system" not in [m["role"] for m in messages]: - messages.insert(0, {"role": "system", "content": self.system_prompt}) - - data = {"messages": messages, "model": self.model, **params} - - if self._check_rate_limit(): - try: - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"} + data = { + "messages": messages, + "model": config["model"], + "temperature": config["temperature"], + "max_tokens": config["max_tokens"], + "top_p": config["top_p"], + "stop": config["stop"], + "stream": config["stream"], + } - response = self._call_api(self.github_endpoint_url, headers, data) - self._increment_request_count() - return self._process_response(response) - except (requests.exceptions.RequestException, ValueError) as e: - logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} - if self.use_azure_fallback: - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"} - - response = self._call_api(self.github_endpoint_url, headers, data) - return self._process_response(response) - else: - raise ValueError("Rate limit reached and Azure fallback is disabled.") - - def _check_rate_limit(self) -> bool: - """Check if the rate limit has been reached.""" - current_time = time.time() - if current_time < self.rate_limit_reset_time: - return False - if self.request_count >= self.max_requests_per_minute: - self.rate_limit_reset_time = current_time + 60 - self.request_count = 0 - return False - return True - - def _increment_request_count(self): - """Increment the request count.""" - self.request_count += 1 + response = self._call_api(self.endpoint_url, headers, data) + return self._process_response(response) def _call_api(self, endpoint_url: str, headers: Dict[str, str], data: Dict[str, Any]) -> Dict[str, Any]: - """Make an API call to either GitHub or Azure.""" + """Make an API call to Azure AI Inference.""" response = requests.post(endpoint_url, headers=headers, json=data) response.raise_for_status() return response.json() @@ -169,8 +171,8 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: def cost(self, response: ChatCompletion) -> float: """Calculate the cost of the response.""" - # Pass - return 0.0 # Placeholder + # Implement cost calculation logic here if needed + return 0.0 @staticmethod def get_usage(response: ChatCompletion) -> Dict: @@ -182,9 +184,8 @@ def get_usage(response: ChatCompletion) -> Dict: "model": response.model, } - -class GithubWrapper: - """Wrapper for GitHub LLM Client""" +class AzureAIInferenceWrapper: + """Wrapper for Azure AI Inference Client""" def __init__(self, config_list: Optional[List[Dict[str, Any]]] = None, **kwargs): self._clients = [] @@ -199,7 +200,7 @@ def __init__(self, config_list: Optional[List[Dict[str, Any]]] = None, **kwargs) self._config_list = [kwargs] def _register_client(self, config: Dict[str, Any]): - client = GithubClient(**config) + client = AzureAIInferenceClient(**config) self._clients.append(client) def create(self, **params: Any) -> ChatCompletion: @@ -225,4 +226,4 @@ def cost(self, response: ChatCompletion) -> float: @staticmethod def get_usage(response: ChatCompletion) -> Dict: """Get usage information from the response.""" - return GithubClient.get_usage(response) + return AzureAIInferenceClient.get_usage(response) \ No newline at end of file diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fb13afdfcc6..6d053936f81 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -84,6 +84,13 @@ except ImportError as e: cohere_import_exception = e +try : + from autogen.oai.aiinference import AzureAIInferenceClient + + aiinference_import_exception : Optional[ImportError] = None +except ImportError as e: + aiinference_import_exception = e + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. @@ -522,6 +529,12 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s raise ImportError("Please install `cohere` to use the Groq API.") client = CohereClient(**openai_config) self._clients.append(client) + elif api_type is not None and api_type.startswith("aiinference"): + if aiinference_import_exception: + raise ImportError("Please install `azure-ai-inference` to use Azure Ai Inference API.") + client = AzureAIInferenceClient(**openai_config) + self._clients.append(client) + else: client = OpenAI(**openai_config) self._clients.append(OpenAIClient(client))