Skip to content

Commit

Permalink
Merge pull request #214 from MadcowD/wguss/azure
Browse files Browse the repository at this point in the history
Azure
  • Loading branch information
MadcowD authored Sep 23, 2024
2 parents 816d2ea + f100098 commit b89c184
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 6 deletions.
37 changes: 37 additions & 0 deletions examples/providers/azure_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Example: using ell with Azure OpenAI.

Demonstrates two ways to route an ``@ell.simple`` LMP through an Azure
OpenAI deployment:

1. (Recommended) Register the client for a deployment name once via
   ``ell.config.register_model`` and refer to the model by name.
2. Pass the client explicitly to each decorator with ``client=``.
"""
import ell
import openai
import os

ell.init(verbose=True, store='./logdir')

# Your Azure subscription key (read from the environment; never hard-code it).
subscription_key = os.getenv("AZURE_OPENAI_API_KEY")
# Your Azure OpenAI resource endpoint: https://<your resource name>.openai.azure.com/
azure_endpoint = "https://<your resource name>.openai.azure.com/"

# Build the Azure OpenAI client. (The original comment here said "Option 2",
# but this client is the one registered for Option 1 below.)
azureClient = openai.AzureOpenAI(
    azure_endpoint = azure_endpoint,
    api_key = subscription_key,
    api_version = "2024-05-01-preview",
)

# (Recommended) Option 1: Register all the models on your Azure resource & use your models automatically
ell.config.register_model("<your-azure-model-deployment-name>", azureClient)

@ell.simple(model="<your-azure-model-deployment-name>")
def write_a_story(about : str):
    # ell uses the registered client because the model name was registered above.
    return f"write me a story about {about}!"

write_a_story("cats")


# Option 2: Use a client directly
azureClient = openai.AzureOpenAI(
    azure_endpoint = azure_endpoint,
    api_key = subscription_key,
    api_version = "2024-05-01-preview",
)

@ell.simple(model="<your-azure-model-deployment-name>", client=azureClient)
def write_a_story(about : str):
    # Note: this redefines write_a_story; only this version exists after here.
    return f"write me a story about {about}"

write_a_story("cats")
4 changes: 3 additions & 1 deletion examples/providers/groq_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


ell.init(verbose=True, store='./logdir')

# (Recommended) Option 1: Register all groq models.
ell.models.groq.register() # use GROQ_API_KEY env var
# ell.models.groq.register(api_key="gsk-") #

Expand All @@ -15,7 +17,7 @@ def write_a_story(about : str):

write_a_story("cats")

# or use the client directly
# Option 2: Use a client directly
client = groq.Groq()

@ell.simple(model="llama3-8b-8192", temperature=0.1, client=client)
Expand Down
7 changes: 5 additions & 2 deletions src/ell/configurator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import wraps
from functools import lru_cache, wraps
from typing import Dict, Any, Optional, Tuple, Union, Type
import openai
import logging
Expand Down Expand Up @@ -133,7 +133,10 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]:
"""

client_type = type(client) if not isinstance(client, type) else client
return self.providers.get(client_type)
for provider_type, provider in self.providers.items():
if issubclass(client_type, provider_type) or client_type == provider_type:
return provider
return None

# Single* instance
# XXX: Make a singleton
Expand Down
3 changes: 1 addition & 2 deletions src/ell/lmp/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Calla
default_client_from_decorator = client
default_model_from_decorator = model
default_api_params_from_decorator = api_params

def parameterized_lm_decorator(
prompt: LMP,
) -> Callable[..., Union[List[Message], Message]]:
Expand Down Expand Up @@ -66,7 +65,7 @@ def model_call(

if should_log: model_usage_logger_post_start(n)
with model_usage_logger_post_intermediate(n) as _logger:
(result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger)
(result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger if should_log else None)
if isinstance(result, list) and len(result) == 1:
result = result[0]

Expand Down
2 changes: 1 addition & 1 deletion src/ell/lmp/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False

def convert_multimodal_response_to_lstr(response):
return [x.content[0].text for x in response] if isinstance(response, list) else response.content[0].text
return complex(model, client, exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr)
return complex(model, client, exempt_from_tracking=exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr)



Expand Down

0 comments on commit b89c184

Please sign in to comment.