diff --git a/examples/providers/azure_ex.py b/examples/providers/azure_ex.py new file mode 100644 index 00000000..346104e7 --- /dev/null +++ b/examples/providers/azure_ex.py @@ -0,0 +1,37 @@ +import ell +import openai +import os +ell.init(verbose=True, store='./logdir') + +# your subscription key +subscription_key = os.getenv("AZURE_OPENAI_API_KEY") +# Your Azure OpenAI resource https://.openai.azure.com/ +azure_endpoint = "https://.openai.azure.com/" +# Option 2: Use a client directly +azureClient = openai.AzureOpenAI( + azure_endpoint = azure_endpoint, + api_key = subscription_key, + api_version = "2024-05-01-preview", +) +# (Recommended) Option 1: Register all the models on your Azure resource & use your models automatically +ell.config.register_model("", azureClient) + +@ell.simple(model="") +def write_a_story(about : str): + return f"write me a story about {about}!" + +write_a_story("cats") + + +# Option 2: Use a client directly +azureClient = openai.AzureOpenAI( + azure_endpoint = azure_endpoint, + api_key = subscription_key, + api_version = "2024-05-01-preview", +) + +@ell.simple(model="", client=azureClient) +def write_a_story(about : str): + return f"write me a story about {about}" + +write_a_story("cats") diff --git a/examples/providers/groq_ex.py b/examples/providers/groq_ex.py index a0a527c1..a2c445f7 100644 --- a/examples/providers/groq_ex.py +++ b/examples/providers/groq_ex.py @@ -6,6 +6,8 @@ ell.init(verbose=True, store='./logdir') + +# (Recomended) Option 1: Register all groq models. ell.models.groq.register() # use GROQ_API_KEY env var # ell.models.groq.register(api_key="gsk-") # @@ -15,7 +17,7 @@ def write_a_story(about : str): write_a_story("cats") -# or use the client directly +# Option 2: Use a client directly client = groq.Groq() @ell.simple(model="llama3-8b-8192", temperature=0.1, client=client) diff --git a/src/ell/configurator.py b/src/ell/configurator.py index a0758f5a..896181ce 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -1,4 +1,4 @@ -from functools import wraps +from functools import lru_cache, wraps from typing import Dict, Any, Optional, Tuple, Union, Type import openai import logging @@ -133,7 +133,10 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]: """ client_type = type(client) if not isinstance(client, type) else client - return self.providers.get(client_type) + for provider_type, provider in self.providers.items(): + if issubclass(client_type, provider_type) or client_type == provider_type: + return provider + return None # Single* instance # XXX: Make a singleton diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index 2d2c2bfd..f8b7b167 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -17,7 +17,6 @@ def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Calla default_client_from_decorator = client default_model_from_decorator = model default_api_params_from_decorator = api_params - def parameterized_lm_decorator( prompt: LMP, ) -> Callable[..., Union[List[Message], Message]]: @@ -66,7 +65,7 @@ def model_call( if should_log: model_usage_logger_post_start(n) with model_usage_logger_post_intermediate(n) as _logger: - (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger) + (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger if should_log else None) if isinstance(result, list) and len(result) == 1: result = result[0] diff --git a/src/ell/lmp/simple.py b/src/ell/lmp/simple.py index 086b3ade..b1cd5813 100644 --- a/src/ell/lmp/simple.py +++ b/src/ell/lmp/simple.py @@ -11,7 +11,7 @@ def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False def convert_multimodal_response_to_lstr(response): return [x.content[0].text for x in response] if isinstance(response, list) else response.content[0].text - return complex(model, client, exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr) + return complex(model, client, exempt_from_tracking=exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr)