From 569f9a996abe928fba13414dae1bbdf9626d82a2 Mon Sep 17 00:00:00 2001 From: William Guss Date: Wed, 18 Sep 2024 14:45:45 -0700 Subject: [PATCH 01/17] rambling.. --- docs/ramblings/providers.py | 228 ++++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 docs/ramblings/providers.py diff --git a/docs/ramblings/providers.py b/docs/ramblings/providers.py new file mode 100644 index 00000000..85ee3689 --- /dev/null +++ b/docs/ramblings/providers.py @@ -0,0 +1,228 @@ + + + +# Goal with this refactor +# - Force a clean provider interface so that implementers build compatible and maintainable interfaces +# - Automate testing of new providers +# - Make the code as understandable as possible. +# - Manage all the metadata around providers in one place. +# - Providers should specify what they are capable of so we can validate at compile time that it makese sense (what params are available) + + +def validate_call_params(self, model : str, client : Any, api_params : Dict[str, Any]) -> None: + """Validates the call parameters.""" + pass + + +class ProviderCapabilities(BaseModel): + """The capabilities of a provider. This allowes ell to validate at compile time that a provider supports the features it needs.""" + supports_streaming : bool + supports_structured_outputs : bool + supports_function_calling : bool + supports_tool_calling : bool + + +@abstractmethod +def capabilities(self, model : str, client : Any) -> ProviderCapabilities: + """Returns the capabilities of the provider.""" + pass + +@abstractmethod +def ell_call_to_provider_call(self, ell_call : EllCall) -> T: + """Converts an EllCall to a provider call.""" + pass + +@abstractmethod +def provider_response_to_ell_response(self, ell_call : EllCall, provider_response : Any) -> EllResponse: + """Converts a provider response to an Ell response.""" + pass + + +class Provider(ABC) + + @abstractmethod + def provider_call_function(self, client) -> Callable: + """Returns the function that makes the call to the provider.""" + return NotImplemented + + + +class OpenAIProvider(Provider): + def provider_call_function(self, client) -> Callable: + return client.chat.completions.create + + +import inspect +from typing import Any, Dict + +def validate_provider_call_params(self, ell_call: EllCall, client: Any): + provider_call_func = self.provider_call_function(client) + provider_call_params = inspect.signature(provider_call_func).parameters + + converted_params = self.ell_call_to_provider_call(ell_call) + + required_params = { + name: param for name, param in provider_call_params.items() + if param.default == param.empty and param.kind != param.VAR_KEYWORD + } + + for param_name in required_params: + assert param_name in converted_params, f"Required parameter '{param_name}' is missing in the converted call parameters." + + for param_name, param_value in converted_params.items(): + assert param_name in provider_call_params, f"Unexpected parameter '{param_name}' in the converted call parameters." + + param_type = provider_call_params[param_name].annotation + if param_type != inspect.Parameter.empty: + assert isinstance(param_value, param_type), f"Parameter '{param_name}' should be of type {param_type}." + + print("All parameters validated successfully.") + + + +# How do we force the nick scenario +# If we use response_format -> we sshould parse the resposne into the universal format. 
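# --- Editor's sketch (not part of this patch): a runnable toy version of the
# inspect-based parameter validation above. `FakeCompletions` and `check_api_params`
# are hypothetical names invented for illustration; they are not ell's real API.
import inspect
from typing import Any, Callable, Dict


class FakeCompletions:
    """Stand-in for an SDK call surface whose signature we introspect."""
    @staticmethod
    def create(model: str, messages: list, temperature: float = 1.0, stream: bool = False):
        return {"choices": []}


def check_api_params(call_fn: Callable[..., Any], api_params: Dict[str, Any]) -> None:
    # Reject any param the provider call function does not declare explicitly.
    sig = inspect.signature(call_fn)
    for name in api_params:
        if name not in sig.parameters:
            raise ValueError(f"'{name}' is not a parameter of {call_fn.__name__}")


check_api_params(FakeCompletions.create, {"temperature": 0.2})  # accepted
try:
    check_api_params(FakeCompletions.create, {"system": "hi"})  # not an API param -> rejected
except ValueError as err:
    print(err)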
+ + +# i like that u can use your proviers params in your @ell.call +# alterntively we coudl do the vercel shit + +# universal params: subset of params + +class UniversalParams(BaseModel): + messages : List[Message] + + +@ell.simple(openai("gpt-4", **openai params), tools=[], ell params.. ) + + + +# Trying to currently solve hte params problem. I dont want you to have to learn a new set of params. You should be able to use your API params however you want. +# Not even a universal set of params. But then we get ugly shit like + +@ell.simple("claude-3", system="hi") + + +# Process +# (messages + tools + widgets) -> (call params + messages) -> (resposne (no streaming)) -> (messages + metadata) + +# +# is that api params can live inside of messages +# Compoenents aroudn are + + + +# 1. we create the call parameters +# 2. we validate the call parameters + # Certain things arent allowed like stream=True for non-streaming providers +# 3. we send them to the api +# 4. we translate the response to universal format +# 5. we return the resposne toe hte api file. + + + +# Params +# eveyr api has their own set of params. the ell way right now is fine, but some should be prohibited and we want to know what params are available. +# can solve using + + + +class Provider2_0(ABC): + + """Universal Parameters""" + @abstractmethod + def provider_call_function(self, client : Optional[Any] = None, model : Optional[str] = None) -> Dict[str, Any]: + return NotImplemented + + # How do we prevent system param? + @abstractmethod + def disallowed_provider_params(self) -> List[str]: + """ + Returns a list of disallowed call params that ell will override. + """ + return {"system", "tools", "tool_choice", "stream", "functions", "function_call"} + + def available_params(self): + return inspect.signature(self.provider_call_function).parameters - self.disallowed_provider_params() + + """Universal Messages""" + @abstractmethod + def translate_ell_to_provider(self, ell_call : EllCall) -> Any: + """Converts universal messages to the provider-specific format.""" + return NotImplemented + + @abstractmethod + def translate_provider_to_ell(self, provider_response : Any, ell_call : EllCall) -> Tuple[List[Message], EllMetadata]: + """Converts provider responses to universal format.""" + return NotImplemented + + def call_model(self, client : Optional[Any] = None, model : Optional[str] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: + # Automatic validation of params + assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" + assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" + + # Call + call_params = self.translate_ell_to_provider(ell_call) + provider_resp = self.provider_call_function(client, model)(**call_params) + return self.translate_provider_to_ell(provider_resp, ell_call) + + +class CallMetadata(BaseModel): + """A universal metadata format for ell studio?""" + usage : Optional[Usage] = None + model : Optional[str] = None + provider : Optional[str] = None + provider_response : Optional[Any] = None + other : Optional[Dict[str, Any]] = None + + +# TODO: How does this interact with streaming? Cause isn't the full story + + + +# Translationc + +# How do we force implementers to implement parameter translation like tools etc. +# What about capabilities? Why do we need to know? Well if there aren't any tools available. 
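# --- Editor's sketch (not part of this patch): the pipeline described above
# (messages -> call params -> API response -> messages + metadata) wired through a
# toy provider. `EchoProvider` and its response shape are invented for illustration,
# plain dicts stand in for ell's Message objects, and the validation step is omitted
# for brevity.
from typing import Any, Dict, List, Optional, Tuple


class EchoProvider:
    def provider_call_function(self, client: Optional[Any] = None, model: Optional[str] = None):
        # A real implementation would return e.g. client.chat.completions.create.
        def _call(**params: Any) -> Dict[str, Any]:
            return {"text": params["prompt"].upper(), "usage": {"tokens": len(params["prompt"])}}
        return _call

    def translate_ell_to_provider(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        # Step 1: universal messages -> provider-specific call params.
        return {"prompt": "\n".join(m["content"] for m in messages)}

    def translate_provider_to_ell(self, resp: Dict[str, Any]) -> Tuple[List[Dict[str, str]], Dict[str, Any]]:
        # Step 4: provider response -> universal messages + metadata.
        return [{"role": "assistant", "content": resp["text"]}], {"usage": resp["usage"]}

    def call_model(self, messages: List[Dict[str, str]], client: Optional[Any] = None, model: Optional[str] = None):
        call_params = self.translate_ell_to_provider(messages)                     # step 1
        provider_resp = self.provider_call_function(client, model)(**call_params)  # step 3
        return self.translate_provider_to_ell(provider_resp)                       # steps 4-5


print(EchoProvider().call_model([{"role": "user", "content": "hi"}]))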
+ + +def translate_provider_to_ell( + ell_call : EllCall, + provider_response : Any +) -> Tuple[[Message], CallMetadata]: + """Converts provider responses to universal format.""" + return NotImplemented + +# We have to actually test with a known provider response which we cant automatically do +# We could force providers to extract toolcalls from the response and then we wouldnt have to do it for every provider. + + +@ell.simple(tools=[mytool], system="hi") +def my_prompt(self, client, model, messages, tools, **api_params): + return "usethist tool" + + +# This is bad because we providers have different levels of multimodality etc. +class Provider(ABC): + + @abstractmethod + def response_to_tool_calls(self, provider_response : Any) -> List[ToolCall]: + """Extracts tool calls from the provider response.""" + return NotImplemented + + @abstractmethod + def response_to_content(self, provider_response : Any) -> str: + """Extracts the content from the provider response.""" + return NotImplemented + +# How would you guarantee that a provider? Respond with a tool call if a tool call occurs within the provider. +# Without actually knowing the details of the provider, there's no way To guarantee this. It almost has to be like A required argument of the response construction + +So you could. Require the implementer to say if there were A tool call or not in the response. +It's not possible to prevent people from writing **** code. Like we can't know if they're stupid provider has a type of a response that's not a tool call. +Unless we really explicitly add them mark what was in the response. + + +# Models (maybe models should live close to providers) + +# This prevents us from doing routing but that's actualyl openrouters purpose \ No newline at end of file From b4fb9086dd150de6ea6405d7153c52ac2190f587 Mon Sep 17 00:00:00 2001 From: William Guss Date: Wed, 18 Sep 2024 20:54:04 -0700 Subject: [PATCH 02/17] wip --- docs/ramblings/providers.py | 1 - src/ell/provider.py | 115 +++++++++++++++++++++++++++--------- 2 files changed, 87 insertions(+), 29 deletions(-) diff --git a/docs/ramblings/providers.py b/docs/ramblings/providers.py index 85ee3689..0fa37df8 100644 --- a/docs/ramblings/providers.py +++ b/docs/ramblings/providers.py @@ -222,7 +222,6 @@ def response_to_content(self, provider_response : Any) -> str: It's not possible to prevent people from writing **** code. Like we can't know if they're stupid provider has a type of a response that's not a tool call. Unless we really explicitly add them mark what was in the response. 
- # Models (maybe models should live close to providers) # This prevents us from doing routing but that's actualyl openrouters purpose \ No newline at end of file diff --git a/src/ell/provider.py b/src/ell/provider.py index 9350be20..f995f3c4 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -1,13 +1,12 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Type, TypedDict, Union from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr import json from dataclasses import dataclass from ell.types.message import LMP - @dataclass class APICallResult: response: Any @@ -15,42 +14,102 @@ class APICallResult: actual_n: int final_call_params: Dict[str, Any] +class EllCall(TypedDict): + messages: List[Message] + client: Optional[Any] = None + tools: Optional[List[LMP]] = None + response_format: Optional[Dict[str, Any]] = None + +e = EllCall(messages=[], client=None, tools=None, response_format=None) + + +class Metadata(TypedDict): + """First class metadata so that ell studio can work, you can add more stuff here if you want""" + class Provider(ABC): """ Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI. + In Ell. We hate abstractions. The only reason this exists is to force implementers to implement their own provider correctly -_-. """ - @classmethod + ################################ + ### API PARAMETERS ############# + ################################ @abstractmethod - def call_model( - cls, - client: Any, - model: str, - messages: List[Any], - api_params: Dict[str, Any], - tools: Optional[list[LMP]] = None, - ) -> APICallResult: - """Make the API call to the language model and return the result along with actual streaming, n values, and final call parameters.""" - pass - - @classmethod + def provider_call_function(self, **ell_call : EllCall) -> Dict[str, Any]: + """ + Implement this method to return the function that makes the API call to the language model. + For example, if you're implementing the OpenAI provider, you would return the function that makes the API call to OpenAI's API. + ```python + return openai.Completion.create + ``` + """ + return NotImplemented + @abstractmethod - def process_response( - cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None, - ) -> Tuple[List[Message], Dict[str, Any]]: - """Process the API response and convert it to ell format.""" - pass + def disallowed_provider_params(self) -> FrozenSet[str]: + """ + Returns a list of disallowed call params that ell will override. 
+ """ + return frozenset({"system", "tools", "tool_choice", "stream", "functions", "function_call", "response_format"}) - @classmethod + def available_params(self) -> APICallParams: + return get_params_of_call_function + EllCall.__required_keys__ + + + ################################ + ### TRANSLATION ############### + ################################ @abstractmethod - def supports_streaming(cls) -> bool: - """Check if the provider supports streaming.""" - pass + def translate_to_provider(self, ell_call : EllCall) -> APICallParams: + """Converts an ell call to provider call params!""" + return NotImplemented + + @abstractmethod + def translate_from_provider(self, provider_response : Any, ell_call : EllCall) -> Tuple[List[Message], Metadata]: + """Converts provider responses to universal format.""" + return NotImplemented + + ################################ + ### CALL MODEL ################ + ################################ + def call_model(self, client : Optional[Any] = None, model : Optional[str] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: + # Automatic validation of params + assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" + assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" - @classmethod + # Call + call_params = self.translate_to_provider(ell_call) + provider_resp = self.provider_call_function(client, model)(**call_params) + return self.translate_from_provider(provider_resp, ell_call) + + +class Provider2_0(ABC): + + + # How do we prevent system param? @abstractmethod - def get_client_type(cls) -> Type: - """Return the type of client this provider supports.""" - pass + def disallowed_provider_params(self) -> List[str]: + """ + Returns a list of disallowed call params that ell will override. 
+ """ + return {"system", "tools", "tool_choice", "stream", "functions", "function_call"} + + + @abstractmethod + def translate_provider_to_ell(self, provider_response : Any, ell_call : EllCall) -> Tuple[List[Message], EllMetadata]: + """Converts provider responses to universal format.""" + return NotImplemented + + def call_model(self, client : Optional[Any] = None, model : Optional[str] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: + # Automatic validation of params + assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" + assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" + + # Call + call_params = self.translate_ell_to_provider(ell_call) + provider_resp = self.provider_call_function(client, model)(**call_params) + return self.translate_provider_to_ell(provider_resp, ell_call) + \ No newline at end of file From e4c8ce5fee2184ed001efbf89d0a4e46dd9a09ab Mon Sep 17 00:00:00 2001 From: William Guss Date: Thu, 19 Sep 2024 12:52:15 -0700 Subject: [PATCH 03/17] experiemting --- docs/ramblings/0.1.0/autostreamprevention.py | 28 +++++ docs/ramblings/0.1.0/mypytest.py | 13 +++ src/ell/lmp/complex.py | 53 +++++---- src/ell/provider.py | 110 +++++++++++++------ src/ell/util/api.py | 13 +-- src/ell/util/verbosity.py | 6 +- 6 files changed, 161 insertions(+), 62 deletions(-) create mode 100644 docs/ramblings/0.1.0/autostreamprevention.py create mode 100644 docs/ramblings/0.1.0/mypytest.py diff --git a/docs/ramblings/0.1.0/autostreamprevention.py b/docs/ramblings/0.1.0/autostreamprevention.py new file mode 100644 index 00000000..0a568e27 --- /dev/null +++ b/docs/ramblings/0.1.0/autostreamprevention.py @@ -0,0 +1,28 @@ +import openai +import os + +# Define the function to stream the response +def stream_openai_response(prompt): + try: + # Make the API call + response = openai.chat.completions.create( + model="o1-mini", # Specify the model + messages=[{"role": "user", "content": prompt}], + stream=True # Enable streaming + ) + + # Stream the response + for chunk in response: + if chunk.choices[0].delta.get("content"): + print(chunk.choices[0].delta.content, end="", flush=True) + + print() # Print a newline at the end + + except Exception as e: + print(f"An error occurred: {e}") + +# Example usage +prompt = "Tell me a short joke." +stream_openai_response(prompt) + +# This shows that openai won't fake streaming, it will just fail on the request \ No newline at end of file diff --git a/docs/ramblings/0.1.0/mypytest.py b/docs/ramblings/0.1.0/mypytest.py new file mode 100644 index 00000000..824c1794 --- /dev/null +++ b/docs/ramblings/0.1.0/mypytest.py @@ -0,0 +1,13 @@ +from typing import TypedDict + + +class Test(TypedDict): + name: str + age: int + + +def test(**t: Test): + print(t) + +# no type hinting like ts thats unfortunate. 
+test( ) diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index f571720c..6b7a09d9 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -12,7 +12,7 @@ from functools import wraps from typing import Any, Dict, Optional, List, Callable, Union -def complex(model: str, client: Optional[Any] = None, exempt_from_tracking=False, tools: Optional[List[Callable]] = None, post_callback: Optional[Callable] = None, **api_params): +def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Callable]] = None, exempt_from_tracking=False, post_callback: Optional[Callable] = None, **api_params): """ A sophisticated language model programming decorator for complex LLM interactions. @@ -213,47 +213,62 @@ def parallel_assistant(message_history: List[Message]) -> List[Message]: - ell.studio: For visualizing and analyzing LMP executions. """ default_client_from_decorator = client + default_model_from_decorator = model def parameterized_lm_decorator( prompt: LMP, ) -> Callable[..., Union[List[Message], Message]]: - color = compute_color(prompt) _warnings(model, prompt, default_client_from_decorator) - @wraps(prompt) def model_call( - *fn_args, + *prompt_args, _invocation_origin : str = None, client: Optional[Any] = None, lm_params: Optional[LMPParams] = {}, - invocation_api_params=False, - **fn_kwargs, + **prompt_kwargs, ) -> _lstr_generic: - res = prompt(*fn_args, **fn_kwargs) - - assert exempt_from_tracking or _invocation_origin is not None, "Invocation origin is required when using a tracked LMP" + # promt -> str + res = prompt(*prompt_args, **prompt_kwargs) + # Convert prompt into ell messages messages = _get_messages(res, prompt) - - if config.verbose and not exempt_from_tracking: model_usage_logger_pre(prompt, fn_args, fn_kwargs, "notimplemented", messages, color) - - (result, _api_params, metadata) = call(model=model, messages=messages, api_params={**config.default_lm_params, **api_params, **lm_params}, client=client or default_client_from_decorator, _invocation_origin=_invocation_origin, _exempt_from_tracking=exempt_from_tracking, _logging_color=color, _name=prompt.__name__, tools=tools) - + # done. + + # Cute verbose logging. + if config.verbose and not exempt_from_tracking: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, model_call.__ell_hash__, messages) + + # Call the model. We use this data class because we have so many params! + merged_call_params = {**config.default_lm_params, **api_params, **lm_params} + ell_call = EllCall( + model=merged_call_params.get("model", default_model_from_decorator), + messages=messages, + client = client or default_client_from_decorator, + api_params=merged_call_params, + tools=tools, + invocation_id=_invocation_origin, + ) + # Get the provider for the model + provider = config.get_provider_for(ell_call) + (result, _api_params, metadata) = provider.call_model(ell_call) + + (result, _api_params, metadata) = call(client=client or default_client_from_decorator, _invocation_origin=_invocation_origin, should_log=config.verbose and not exempt_from_tracking, _name=prompt.__name__, tools=tools) + + # Finish result = post_callback(result) if post_callback else result - - return result, api_params, metadata + # omg bug spotted! + # These get sent to track. + # This is wack. 
+ return result, _api_params, metadata - # TODO: # we'll deal with type safety here later model_call.__ell_api_params__ = api_params model_call.__ell_func__ = prompt model_call.__ell_type__ = LMPType.LM model_call.__ell_exempt_from_tracking = exempt_from_tracking - # model_call.__ell_uses__ = prompt.__ell_uses__ - # model_call.__ell_hash__ = prompt.__ell_hash__ + if exempt_from_tracking: return model_call diff --git a/src/ell/provider.py b/src/ell/provider.py index f995f3c4..a116ceee 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -15,10 +15,11 @@ class APICallResult: final_call_params: Dict[str, Any] class EllCall(TypedDict): - messages: List[Message] - client: Optional[Any] = None - tools: Optional[List[LMP]] = None - response_format: Optional[Dict[str, Any]] = None + model : str + messages : List[Message] + client : Any + tools : Optional[List[LMP]] + response_format : Optional[Dict[str, Any]] e = EllCall(messages=[], client=None, tools=None, response_format=None) @@ -32,13 +33,14 @@ class Provider(ABC): Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI. In Ell. We hate abstractions. The only reason this exists is to force implementers to implement their own provider correctly -_-. + """ ################################ ### API PARAMETERS ############# ################################ @abstractmethod - def provider_call_function(self, **ell_call : EllCall) -> Dict[str, Any]: + def provider_call_function(self) -> Dict[str, Any]: """ Implement this method to return the function that makes the API call to the language model. For example, if you're implementing the OpenAI provider, you would return the function that makes the API call to OpenAI's API. 
@@ -55,15 +57,15 @@ def disallowed_provider_params(self) -> FrozenSet[str]: """ return frozenset({"system", "tools", "tool_choice", "stream", "functions", "function_call", "response_format"}) - def available_params(self) -> APICallParams: - return get_params_of_call_function + EllCall.__required_keys__ + def available_params(self) -> Partial[APICallParams]: + return frozenset(get_params_of_call_function(provider_call_params.keys())) + EllCall.__required_keys__ - disallowed_params ################################ ### TRANSLATION ############### ################################ @abstractmethod - def translate_to_provider(self, ell_call : EllCall) -> APICallParams: + def translate_to_provider(self, ) -> APICallParams: """Converts an ell call to provider call params!""" return NotImplemented @@ -75,41 +77,83 @@ def translate_from_provider(self, provider_response : Any, ell_call : EllCall) - ################################ ### CALL MODEL ################ ################################ - def call_model(self, client : Optional[Any] = None, model : Optional[str] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: + def call_model(self, model : Optional[str] = None, client : Optional[Any] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: # Automatic validation of params - assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" + assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" + assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" # Call call_params = self.translate_to_provider(ell_call) provider_resp = self.provider_call_function(client, model)(**call_params) return self.translate_from_provider(provider_resp, ell_call) + def default_models(self) -> List[str]: + """Returns a list of default models for this provider.""" + return [ + ] + + def register_all_models(self, client : Any): + """Registers all default models for this provider.""" + for model in self.default_models(): + self.register_model(model, client) + + def validate_call(self, call : EllCall): + if model == "o1-preview" or model == "o1-mini": + # Ensure no system messages are present + assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models" + + if self.model_is_available(call.model): + return + else: + raise ValueError(f"Model {call.model} not available for provider {self.name}") + + +class OpenAIClientProvider(Provider): + """Use this for providers that are a wrapper around an OpenAI client e.g. mistral, groq, azure, etc.""" + + ... 
+ +class OpenAIProvider(OpenAIClientProvider): + def default_models(self) -> List[str]: + return [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-07-18", + "gpt-4o-2024-06-20", + "gpt-4o-2024-04-09", + "gpt-4o-2024-03-13", + "gpt-4o-2024-02-29", + ] + + def validate_call(self, call : EllCall): + super().validate_call(call) + if model == "o1-preview" or model == "o1-mini": + # Ensure no system messages are present + assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models" + + def provider_call_function(self, EllCall) -> Dict[str, Any]: + if EllCall['response_format']: + return EllCall['client'].beta.chat.completions.parse(**EllCall) + else: + return EllCall['client'].chat.completions.create(**EllCall) + + def available_params(self, ell_call : EllCall) -> Partial[APICallParams]: + defualt_params = get_params_of_call_function(self.provider_call_function(ell_call)) + + if ell_call['response_format']: + # no streaming currently + eturn defualt_params - {'stream'} + else: + return defualt_params + +class OllamaProvider(OpenAIClientProvider): + def default_models(self) -> List[str]: + -class Provider2_0(ABC): - # How do we prevent system param? - @abstractmethod - def disallowed_provider_params(self) -> List[str]: - """ - Returns a list of disallowed call params that ell will override. - """ - return {"system", "tools", "tool_choice", "stream", "functions", "function_call"} - - @abstractmethod - def translate_provider_to_ell(self, provider_response : Any, ell_call : EllCall) -> Tuple[List[Message], EllMetadata]: - """Converts provider responses to universal format.""" - return NotImplemented - - def call_model(self, client : Optional[Any] = None, model : Optional[str] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: - # Automatic validation of params - assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" - assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" - # Call - call_params = self.translate_ell_to_provider(ell_call) - provider_resp = self.provider_call_function(client, model)(**call_params) - return self.translate_provider_to_ell(provider_resp, ell_call) - \ No newline at end of file diff --git a/src/ell/util/api.py b/src/ell/util/api.py index 40493344..e4dad014 100644 --- a/src/ell/util/api.py +++ b/src/ell/util/api.py @@ -24,8 +24,7 @@ def call( tools: Optional[list[LMP]] = None, client: Optional[Any] = None, _invocation_origin: str, - _exempt_from_tracking: bool, - _logging_color: Optional[str] = None, + should_log: bool, _name: Optional[str] = None, ) -> Tuple[Union[Message, List[Message]], Dict[str, Any], Dict[str, Any]]: """ @@ -48,14 +47,14 @@ def call( # XXX: Could actually delete htis call_result = provider_class.call_model(client, model, messages, api_params, tools) - if config.verbose and not _exempt_from_tracking: - model_usage_logger_post_start(_logging_color, call_result.actual_n) + if should_log: + model_usage_logger_post_start(call_result.actual_n) - with model_usage_logger_post_intermediate(_logging_color, call_result.actual_n) as _logger: - tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if config.verbose and not _exempt_from_tracking else None, tools) + with model_usage_logger_post_intermediate(call_result.actual_n) as _logger: + 
tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if should_log else None, tools) - if config.verbose and not _exempt_from_tracking: + if config.verbose and not should_log: model_usage_logger_post_end() diff --git a/src/ell/util/verbosity.py b/src/ell/util/verbosity.py index 14df32a9..d061bc33 100644 --- a/src/ell/util/verbosity.py +++ b/src/ell/util/verbosity.py @@ -115,17 +115,17 @@ def print_wrapped_messages(messages: List[Message], max_role_length: int, color: if i < len(messages) - 1: print(f"{PIPE_COLOR}│{RESET}") + def model_usage_logger_pre( invoking_lmp: LMP, lmp_args: Tuple, lmp_kwargs: Dict, lmp_hash: str, messages: List[Message], - color: str = "", arg_max_length: int = 8 ): """Log model usage before execution with customizable argument display length and ASCII box.""" - color = color or compute_color(invoking_lmp) + color = compute_color(invoking_lmp) formatted_args = [format_arg(arg, arg_max_length) for arg in lmp_args] formatted_kwargs = [format_kwarg(key, lmp_kwargs[key], arg_max_length) for key in lmp_kwargs] formatted_params = ', '.join(formatted_args + formatted_kwargs) @@ -157,7 +157,7 @@ def model_usage_logger_post_start(color: str = "", n: int = 1): from contextlib import contextmanager @contextmanager -def model_usage_logger_post_intermediate(color: str = "", n: int = 1): +def model_usage_logger_post_intermediate( n: int = 1): """Context manager to log intermediate model output without wrapping, only indenting if necessary.""" terminal_width = get_terminal_width() prefix = f"{PIPE_COLOR}│ " From 4fcd3d407b68f2c9910e006ff918463cbc274243 Mon Sep 17 00:00:00 2001 From: William Guss Date: Thu, 19 Sep 2024 15:53:04 -0700 Subject: [PATCH 04/17] remove configverbose true --- docs/ramblings/parsing_example.py | 2 +- examples/bv.py | 3 +-- examples/calculator_structured.py | 4 ++-- examples/chord_progression_writer.py | 4 ++-- examples/client_example.py | 2 +- examples/diamond_depencies.py | 4 ++-- examples/future/limbo.py | 3 +-- examples/future/meme_maker.py | 4 ++-- examples/future/structured.py | 2 +- examples/future/tool_using_chatbot.py | 3 +-- examples/git_issue.py | 2 +- examples/hello_postgres.py | 3 +-- examples/joke.py | 3 +-- examples/multilmp.py | 4 ++-- examples/output_freezing.py | 7 +++---- examples/quick_chat.py | 4 ++-- examples/rag.py | 2 +- 17 files changed, 25 insertions(+), 31 deletions(-) diff --git a/docs/ramblings/parsing_example.py b/docs/ramblings/parsing_example.py index fe5206bc..87e238e2 100644 --- a/docs/ramblings/parsing_example.py +++ b/docs/ramblings/parsing_example.py @@ -3,7 +3,7 @@ from typing import Callable, List, Tuple import ell from ell.types._lstr import _lstr -ell.config.verbose = True + diff --git a/examples/bv.py b/examples/bv.py index 53e8b418..fd85ace1 100644 --- a/examples/bv.py +++ b/examples/bv.py @@ -30,8 +30,7 @@ def write_a_complete_python_class(user_spec : str): if __name__ == "__main__": - ell.config.verbose = True - ell.set_store(SQLiteStore("./logdir"), autocommit=True) + ell.init(verbose=True, store=SQLiteStore("./logdir"), autocommit=True) # test[0] = "modified at execution :O" w = get_lmp(z=20) cls_Def = w("A class that represents a bank") diff --git a/examples/calculator_structured.py b/examples/calculator_structured.py index 2ee7c987..bdd02b0d 100644 --- a/examples/calculator_structured.py +++ b/examples/calculator_structured.py @@ -5,7 +5,7 @@ from ell.stores.sql import SQLiteStore -ell.config.verbose = True + @dataclasses.dataclass @@ -69,5 +69,5 @@ def 
calc_structured(task: str) -> float: if __name__ == "__main__": - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) print(calc_structured("What is two plus two?")) diff --git a/examples/chord_progression_writer.py b/examples/chord_progression_writer.py index 0a69ac7f..07c3ddbf 100644 --- a/examples/chord_progression_writer.py +++ b/examples/chord_progression_writer.py @@ -4,7 +4,7 @@ import pygame import time -ell.config.verbose = True + CHORD_FORMAT = "| Chord | Chord | ... |" @@ -57,7 +57,7 @@ def play_midi_file(file_path): from ell.stores.sql import SQLiteStore if __name__ == "__main__": - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) genre = input("Enter the genre of the song (or press Enter to skip): ").strip() or None key = input("Enter the key of the song (or press Enter to skip): ").strip() or None diff --git a/examples/client_example.py b/examples/client_example.py index fa1dec1b..1ab4cc3b 100644 --- a/examples/client_example.py +++ b/examples/client_example.py @@ -5,7 +5,7 @@ import ell.lmp.simple -ell.config.verbose = True + client = openai.Client(api_key=open(os.path.expanduser("~/.oaikey")).read().strip()) diff --git a/examples/diamond_depencies.py b/examples/diamond_depencies.py index 49d590a3..3ae8edb5 100644 --- a/examples/diamond_depencies.py +++ b/examples/diamond_depencies.py @@ -1,7 +1,7 @@ import random from typing import List, Tuple import ell -ell.config.verbose = True + @ell.simple(model="gpt-4o-mini", temperature=1.0) def random_number() -> str: @@ -31,7 +31,7 @@ def choose_which_is_a_better_piece_of_writing(poem : str, story : str) -> str: if __name__ == "__main__": from ell.stores.sql import SQLiteStore - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) num = random_number() diff --git a/examples/future/limbo.py b/examples/future/limbo.py index dec4b804..06a76ab1 100644 --- a/examples/future/limbo.py +++ b/examples/future/limbo.py @@ -4,8 +4,7 @@ -ell.set_store('./logdir', autocommit=True) -ell.config.verbose = True +ell.init(verbose=True, store='./logdir', autocommit=True) @ell.tool(autogenerate=True) diff --git a/examples/future/meme_maker.py b/examples/future/meme_maker.py index 92e8e4b0..105c212c 100644 --- a/examples/future/meme_maker.py +++ b/examples/future/meme_maker.py @@ -4,7 +4,7 @@ import os import ell -ell.config.verbose = True + # Load the cat meme image using PIL cat_meme_pil = Image.open(os.path.join(os.path.dirname(__file__), "catmeme.jpg")) @@ -17,6 +17,6 @@ def make_a_joke_about_the_image(image: Image.Image) -> str: if __name__ == "__main__": - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) joke = make_a_joke_about_the_image(cat_meme_pil) print(joke) \ No newline at end of file diff --git a/examples/future/structured.py b/examples/future/structured.py index 5724f26f..162cafde 100644 --- a/examples/future/structured.py +++ b/examples/future/structured.py @@ -2,7 +2,7 @@ import ell from pydantic import BaseModel, Field -ell.config.verbose = True + class Test(BaseModel): diff --git a/examples/future/tool_using_chatbot.py b/examples/future/tool_using_chatbot.py index 930c3205..cd5ce512 100644 --- a/examples/future/tool_using_chatbot.py +++ b/examples/future/tool_using_chatbot.py @@ -6,8 +6,7 @@ -ell.set_store('./logdir', autocommit=True) -ell.config.verbose = True +ell.init(verbose=True, store='./logdir', autocommit=True) 
@ell.tool() diff --git a/examples/git_issue.py b/examples/git_issue.py index b345fd68..fb796df2 100644 --- a/examples/git_issue.py +++ b/examples/git_issue.py @@ -3,7 +3,7 @@ from ell.stores.sql import SQLiteStore -ell.config.verbose = True + @ell.simple(model="gpt-4o-mini", temperature=0.1) def generate_description(about : str): diff --git a/examples/hello_postgres.py b/examples/hello_postgres.py index 81d66e42..e36dfb5f 100644 --- a/examples/hello_postgres.py +++ b/examples/hello_postgres.py @@ -20,8 +20,7 @@ def hello(world : str): if __name__ == "__main__": - ell.config.verbose = True - ell.set_store(PostgresStore('postgresql://postgres:postgres@localhost:5432/ell'), autocommit=True) + ell.init(verbose=True, store=PostgresStore('postgresql://postgres:postgres@localhost:5432/ell'), autocommit=True) greeting = hello("sam altman") # > "hello sama! ... " diff --git a/examples/joke.py b/examples/joke.py index 303ebe37..212924f1 100644 --- a/examples/joke.py +++ b/examples/joke.py @@ -22,7 +22,6 @@ def joke(topic : str): if __name__ == "__main__": - ell.config.verbose = True - ell.set_store('./logdir', autocommit=False) + ell.init(verbose=True, store='./logdir', autocommit=False) # Todo: Figure configuration for automcommititng. joke("minecraft") # \ No newline at end of file diff --git a/examples/multilmp.py b/examples/multilmp.py index 2068cf73..6a36c850 100644 --- a/examples/multilmp.py +++ b/examples/multilmp.py @@ -3,7 +3,7 @@ from ell.stores.sql import SQLiteStore -ell.config.verbose = True + @ell.simple(model="gpt-4o-mini", temperature=1.0) def generate_story_ideas(about : str): @@ -33,7 +33,7 @@ def write_a_really_good_story(about : str): if __name__ == "__main__": from ell.stores.sql import SQLiteStore - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) # with ell.cache(write_a_really_good_story): story = write_a_really_good_story("a dog") diff --git a/examples/output_freezing.py b/examples/output_freezing.py index 9bdb5a1d..7ccf7461 100644 --- a/examples/output_freezing.py +++ b/examples/output_freezing.py @@ -1,6 +1,6 @@ import ell from ell.stores.sql import SQLiteStore -ell.config.verbose = True + BASE_PROMPT = """You are an adept python programmer. Only answer in python code. 
Avoid markdown formatting at all costs.""" @@ -28,9 +28,8 @@ def write_unit_for_a_class(class_def : str): if __name__ == "__main__": - store = SQLiteStore("./logdir") - ell.set_store(store, autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) - with store.freeze(create_a_python_class): + with ell.get_store().freeze(create_a_python_class): _class_def = create_a_python_class("A class that represents a bank") _unit_tests = write_unit_for_a_class(_class_def) \ No newline at end of file diff --git a/examples/quick_chat.py b/examples/quick_chat.py index 89f1ffc5..4f756e98 100644 --- a/examples/quick_chat.py +++ b/examples/quick_chat.py @@ -1,7 +1,7 @@ import random from typing import List, Tuple import ell -ell.config.verbose = True + names_list = [ @@ -49,7 +49,7 @@ def chat(message_history : List[Tuple[str, str]], *, personality : str): if __name__ == "__main__": from ell.stores.sql import SQLiteStore - ell.set_store('./logdir', autocommit=True) + ell.init(store='./logdir', autocommit=True, verbose=True) for __ in range(100): messages : List[Tuple[str, str]]= [] diff --git a/examples/rag.py b/examples/rag.py index 22659ac1..db7b89b7 100644 --- a/examples/rag.py +++ b/examples/rag.py @@ -48,7 +48,7 @@ def rag(query: str, context: str) -> str: if __name__ == "__main__": - ell.config.verbose = True + documents = [ "ell is a cool new framework written by will", From 29c4497fad942735fa765fe4f64087eef5025f25 Mon Sep 17 00:00:00 2001 From: William Guss Date: Thu, 19 Sep 2024 20:57:03 -0700 Subject: [PATCH 05/17] minimal configurator changes. --- docs/principles.md | 63 +++++++++++++++ src/ell/configurator.py | 170 ++++++++++++++++------------------------ 2 files changed, 130 insertions(+), 103 deletions(-) create mode 100644 docs/principles.md diff --git a/docs/principles.md b/docs/principles.md new file mode 100644 index 00000000..0f18ea53 --- /dev/null +++ b/docs/principles.md @@ -0,0 +1,63 @@ +# Principles for developing ell + +Some principles for developing ell that we pick up along the way. + +1. went missing +2. went missing.. +1. the user shouldn't wait to find out they're missing something: + Consider caching + ``` + import ell + + @ell.simple + def fn(): return "prompt" + + with ell.cache(fn): + fn() + ``` + If I don't have a store installed, this shit will break when i get to the ell.cache. + + We prefer to have store enable caching; that is the cache contextmanager is only enabled if we have a store: + + ``` + import ell + + store = ell.stores.SQLiteStore("mystore") + ell.use_store(store) + + @ell.simple + def fn(): return "prompt" + + with ell.store.cache(lmp): + fn() + ``` + +2. no unreadable side-effects. + ``` + store = ell.stores.SQLiteStore("mystore") + ell.use_store(store) + ``` + is preferred to: + ``` + store = ell.stores.SQLiteStore("mystore") + store.install() + ``` + This is a side-effect. + + +4. api providers are the single source of truth for model information + - we will never implement Model("gpt-4", Capabilities(vision=True)) + - always rely on the api to tell you if you're using something a model can't do + - in that sense ell.simple should be the thinnest possible wrapper around the api + +5. ell is a library not a framework + - we are building pytorch not keras. nice agent frameworks etc can exist on top of ell, but are not a part of ell itself. ell is meant to give you all of the building blocks to build systems. 
+ - in the meta programming space, we will support standardized building blocks (optimizers, established prompt compilers, etc) but not too frameworky. + (this is actually is a sticky point and drawing the line will always be hard, but initially this is good.) + +6. less abstraction is better + - more single files , less multi file abstractions + - you should just be able to read the source & understand. + +7. ell studio is not ell + - ell studio is an exception in that we can bloat it as much as we need to make the dx beautiful. \ No newline at end of file diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 79e6afb5..9f76843f 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -7,11 +7,29 @@ from pydantic import BaseModel, ConfigDict, Field from ell.store import Store from ell.provider import Provider +from dataclasses import dataclass, field _config_logger = logging.getLogger(__name__) + +@dataclass(frozen=True) +class _Model: + name: str + default_client: Optional[Union[openai.Client, Any]] = None + #XXX: Deprecation in 0.1.0 + #XXX: We will depreciate this when streaming is implemented. + # Currently we stream by default for the verbose renderer, + # but in the future we will not support streaming by default + # and stream=True must be passed which will then make API providers the + # single source of truth for whether or not a model supports an api parameter. + # This makes our implementation extremely light, only requiring us to provide + # a list of model names in registration. + supports_streaming : Optional[bool] = field(default=None) + + + class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - registry: Dict[str, openai.Client] = Field(default_factory=dict, description="A dictionary mapping model names to OpenAI clients.") + registry: Dict[str, _Model] = Field(default_factory=dict, description="A dictionary mapping model names to their configurations.") verbose: bool = Field(default=False, description="If True, enables verbose logging.") wrapped_logging: bool = Field(default=True, description="If True, enables wrapped logging for better readability.") override_wrapped_logging_width: Optional[int] = Field(default=None, description="If set, overrides the default width for wrapped logging.") @@ -21,41 +39,38 @@ class Config(BaseModel): default_lm_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.") default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") - def __init__(self, **data): super().__init__(**data) self._lock = threading.Lock() self._local = threading.local() - def register_model(self, model_name: str, client: Any) -> None: + + def register_model( + self, + name: str, + default_client: Optional[Union[openai.Client, Any]] = None, + supports_streaming: Optional[bool] = None + ) -> None: """ - Register an OpenAI client for a specific model name. - - :param model_name: The name of the model to register. - :type model_name: str - :param client: The OpenAI client to associate with the model. - :type client: openai.Client + Register a model with its configuration. 
""" with self._lock: - self.registry[model_name] = client + # XXX: Will be deprecated in 0.1.0 + self.registry[name] = _Model( + name=name, + default_client=default_client, + supports_streaming=supports_streaming + ) - @property - def has_store(self) -> bool: - """ - Check if a store is set. - :return: True if a store is set, False otherwise. - :rtype: bool - """ - return self.store is not None @contextmanager - def model_registry_override(self, overrides: Dict[str, Any]): + def model_registry_override(self, overrides: Dict[str, _Model]): """ - Temporarily override the model registry with new client mappings. + Temporarily override the model registry with new model configurations. - :param overrides: A dictionary of model names to OpenAI clients to override. - :type overrides: Dict[str, openai.Client] + :param overrides: A dictionary of model names to ModelConfig instances to override. + :type overrides: Dict[str, ModelConfig] """ if not hasattr(self._local, 'stack'): self._local.stack = [] @@ -71,19 +86,19 @@ def model_registry_override(self, overrides: Dict[str, Any]): finally: self._local.stack.pop() - def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]: + def get_client_for(self, model_name: str) -> Tuple[Optional[openai.Client], bool]: """ Get the OpenAI client for a specific model name. :param model_name: The name of the model to get the client for. :type model_name: str - :return: The OpenAI client for the specified model, or None if not found. - :rtype: Optional[openai.Client] + :return: The OpenAI client for the specified model, or None if not found, and a fallback flag. + :rtype: Tuple[Optional[openai.Client], bool] """ current_registry = self._local.stack[-1] if hasattr(self._local, 'stack') and self._local.stack else self.registry - client = current_registry.get(model_name) + model_config = current_registry.get(model_name) fallback = False - if model_name not in current_registry.keys(): + if not model_config: warning_message = f"Warning: A default provider for model '{model_name}' could not be found. Falling back to default OpenAI client from environment variables." if self.verbose: from colorama import Fore, Style @@ -92,68 +107,16 @@ def get_client_for(self, model_name: str) -> Tuple[Optional[Any], bool]: _config_logger.debug(warning_message) client = self.default_client fallback = True - return client, fallback - - def reset(self) -> None: - """ - Reset the configuration to its initial state. - """ - with self._lock: - self.__init__() - if hasattr(self._local, 'stack'): - del self._local.stack - - def set_store(self, store: Union[Store, str], autocommit: bool = True) -> None: - """ - Set the store for the configuration. - - :param store: The store to set. Can be a Store instance or a string path for SQLiteStore. - :type store: Union[Store, str] - :param autocommit: Whether to enable autocommit for the store. - :type autocommit: bool - """ - if isinstance(store, str): - from ell.stores.sql import SQLiteStore - self.store = SQLiteStore(store) else: - self.store = store - self.autocommit = autocommit or self.autocommit - - def get_store(self) -> Store: - """ - Get the current store. - - :return: The current store. - :rtype: Store - """ - return self.store - - def set_default_lm_params(self, **params: Dict[str, Any]) -> None: - """ - Set default parameters for language models. - - :param params: Keyword arguments representing the default parameters. 
- :type params: Dict[str, Any] - """ - self.default_lm_params = params - - - - def set_default_client(self, client: openai.Client) -> None: - """ - Set the default OpenAI client. - - :param client: The default OpenAI client to set. - :type client: openai.Client - """ - self.default_client = client + client = model_config.default_client + return client, fallback def register_provider(self, provider_class: Type[Provider]) -> None: """ Register a provider class for a specific client type. :param provider_class: The provider class to register. - :type provider_class: Type[AbstractProvider] + :type provider_class: Type[Provider] """ with self._lock: self.providers[provider_class.get_client_type()] = provider_class @@ -165,11 +128,12 @@ def get_provider_for(self, client: Any) -> Optional[Type[Provider]]: :param client: The client instance to get the provider for. :type client: Any :return: The provider class for the specified client, or None if not found. - :rtype: Optional[Type[AbstractProvider]] + :rtype: Optional[Type[Provider]] """ return next((provider for client_type, provider in self.providers.items() if isinstance(client, client_type)), None) -# Singleton instance +# Single* instance +# XXX: Make a singleton config = Config() def init( @@ -178,7 +142,7 @@ def init( autocommit: bool = True, lazy_versioning: bool = True, default_lm_params: Optional[Dict[str, Any]] = None, - default_openai_client: Optional[openai.Client] = None + default_client: Optional[Any] = None ) -> None: """ Initialize the ELL configuration with various settings. @@ -196,34 +160,34 @@ def init( :param default_openai_client: Set the default OpenAI client. :type default_openai_client: openai.Client, optional """ + # XXX: prevent double init config.verbose = verbose config.lazy_versioning = lazy_versioning - if store is not None: - config.set_store(store, autocommit) + if isinstance(store, str): + from ell.stores.sql import SQLiteStore + config.store = SQLiteStore(store) + else: + config.store = store + config.autocommit = autocommit or config.autocommit if default_lm_params is not None: - config.set_default_lm_params(**default_lm_params) - + config.default_lm_params.update(default_lm_params) - - if default_openai_client is not None: - config.set_default_client(default_openai_client) + if default_client is not None: + config.default_client = default_client # Existing helper functions -@wraps(config.get_store) -def get_store() -> Store: - return config.get_store() - -@wraps(config.set_store) -def set_store(*args, **kwargs) -> None: - return config.set_store(*args, **kwargs) +def get_store() -> Union[Store, None]: + return config.store -@wraps(config.set_default_lm_params) -def set_default_lm_params(*args, **kwargs) -> None: - return config.set_default_lm_params(*args, **kwargs) +# Will be deprecated at 0.1.0 # You can add more helper functions here if needed @wraps(config.register_provider) def register_provider(*args, **kwargs) -> None: - return config.register_provider(*args, **kwargs) \ No newline at end of file + return config.register_provider(*args, **kwargs) + +# Deprecated now (remove at 0.1.0) +def set_store(*args, **kwargs) -> None: + raise DeprecationWarning("The set_store function is deprecated and will be removed in a future version. Use ell.init(store=...) 
instead.") \ No newline at end of file From f02648a5b8dffebb83bc29bb8f312089f1f60d70 Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 00:14:32 -0700 Subject: [PATCH 06/17] almost there --- .../a new provider_api.md} | 31 +++- docs/ramblings/principles.md | 43 ------ docs/src/core_concepts/configuration.rst | 2 +- examples/quick_chat.py | 8 +- src/ell/lmp/_track.py | 2 +- src/ell/lmp/complex.py | 91 ++++++++---- src/ell/models/openai.py | 65 ++++---- src/ell/provider.py | 139 +++++++----------- src/ell/util/_warnings.py | 9 +- src/ell/util/api.py | 61 -------- 10 files changed, 184 insertions(+), 267 deletions(-) rename docs/ramblings/{providers.py => 0.1.0/a new provider_api.md} (90%) delete mode 100644 docs/ramblings/principles.md delete mode 100644 src/ell/util/api.py diff --git a/docs/ramblings/providers.py b/docs/ramblings/0.1.0/a new provider_api.md similarity index 90% rename from docs/ramblings/providers.py rename to docs/ramblings/0.1.0/a new provider_api.md index 0fa37df8..ae583de0 100644 --- a/docs/ramblings/providers.py +++ b/docs/ramblings/0.1.0/a new provider_api.md @@ -218,10 +218,33 @@ def response_to_content(self, provider_response : Any) -> str: # How would you guarantee that a provider? Respond with a tool call if a tool call occurs within the provider. # Without actually knowing the details of the provider, there's no way To guarantee this. It almost has to be like A required argument of the response construction -So you could. Require the implementer to say if there were A tool call or not in the response. -It's not possible to prevent people from writing **** code. Like we can't know if they're stupid provider has a type of a response that's not a tool call. -Unless we really explicitly add them mark what was in the response. +# So you could. Require the implementer to say if there were A tool call or not in the response. +# It's not possible to prevent people from writing **** code. Like we can't know if they're stupid provider has a type of a response that's not a tool call. +# Unless we really explicitly add them mark what was in the response. # Models (maybe models should live close to providers) -# This prevents us from doing routing but that's actualyl openrouters purpose \ No newline at end of file +# This prevents us from doing routing but that's actualyl openrouters purpose + + + + + +# right now we stream by default +# but this a problemn for models dont support it we'd ahve to make two requests which imo is a nono. + +# Future todo stream=False is default. We don't log steaming completions with verbose mode. +# Set verbose_stream=False to stop background streaming, or pass stream=False + + +register_model( + name="", + default_client=client, + disallowed_params={"stream", "stream_options"}, + default_params={"stream": False, "stream_options": {}}, +) + + +# if you set stream=False we dont log streaming completions + + diff --git a/docs/ramblings/principles.md b/docs/ramblings/principles.md deleted file mode 100644 index da9001bf..00000000 --- a/docs/ramblings/principles.md +++ /dev/null @@ -1,43 +0,0 @@ -# Principles for developing ell - -1. -2. -3. the user shouldn't wait to find out they're missing something: - Consider caching - ``` - import ell - - @ell.simple - def fn(): return "prompt" - - with ell.cache(fn): - fn() - ``` - If I don't have a store installed, this shit will break when i get to the ell.cache. 
- - We prefer to have store enable caching; that is the cache contextmanager is only enabled if we have a store: - - ``` - import ell - - store = ell.stores.SQLiteStore("mystore") - ell.use_store(store) - - @ell.simple - def fn(): return "prompt" - - with ell.store.cache(lmp): - fn() - ``` - -4. no unreadable side-effects. - ``` - store = ell.stores.SQLiteStore("mystore") - ell.use_store(store) - ``` - is preferred to: - ``` - store = ell.stores.SQLiteStore("mystore") - store.install() - ``` - This is a side-effect. diff --git a/docs/src/core_concepts/configuration.rst b/docs/src/core_concepts/configuration.rst index 45dc13d3..eb1afd7e 100644 --- a/docs/src/core_concepts/configuration.rst +++ b/docs/src/core_concepts/configuration.rst @@ -12,7 +12,7 @@ You can modify the global configuration using the ``ell.config`` object which is .. autopydantic_model:: ell.Config :members: - :exclude-members: default_client, registry, store, has_store + :exclude-members: default_client, registry, store :model-show-json: false :model-show-validator-members: false :model-show-config-summary: false diff --git a/examples/quick_chat.py b/examples/quick_chat.py index 4f756e98..1945df8b 100644 --- a/examples/quick_chat.py +++ b/examples/quick_chat.py @@ -23,8 +23,8 @@ def create_personality() -> str: """You are backstoryGPT. You come up with a backstory for a character incljuding name. Choose a completely random name from the list. Format as follows. -Name: -Backstory: <3 sentence backstory>'""" # System prompt + Name: + Backstory: <3 sentence backstory>'""" # System prompt return "Come up with a backstory about " + random.choice(names_list) # User prompt @@ -39,9 +39,9 @@ def chat(message_history : List[Tuple[str, str]], *, personality : str): return [ ell.system(f"""Here is your description. -{personality}. + {personality}. -Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""), + Your goal is to come up with a response to a chat. Only respond in one sentence (should be like a text message in informality.) Never use Emojis."""), ell.user(format_message_history(message_history)), ] diff --git a/src/ell/lmp/_track.py b/src/ell/lmp/_track.py index fb6294e8..bed59374 100644 --- a/src/ell/lmp/_track.py +++ b/src/ell/lmp/_track.py @@ -94,7 +94,7 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: if len(cached_invocations) > 0: - # TODO THis is bad? + # XXX: Fix caching. 
results = [d.deserialize() for d in cached_invocations[0].results] logger.info(f"Using cached result for {func_to_track.__qualname__} with state cache key: {state_cache_key}") diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index 6b7a09d9..3a290e97 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -1,17 +1,20 @@ from ell.configurator import config from ell.lmp._track import _track +from ell.provider import EllCallParams from ell.types._lstr import _lstr from ell.types import Message, ContentBlock from ell.types.message import LMP, InvocableLM, LMPParams, MessageOrDict, _lstr_generic from ell.types.studio import LMPType -from ell.util._warnings import _warnings +from ell.util._warnings import _no_api_key_warning, _warnings from ell.util.api import call from ell.util.verbosity import compute_color, model_usage_logger_pre +from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start from functools import wraps -from typing import Any, Dict, Optional, List, Callable, Union +from typing import Any, Dict, Optional, List, Callable, Tuple, Union +#XXX: Remove the docstirng here. def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Callable]] = None, exempt_from_tracking=False, post_callback: Optional[Callable] = None, **api_params): """ A sophisticated language model programming decorator for complex LLM interactions. @@ -152,7 +155,7 @@ def extract_person_info(text: str) -> List[Message]: text = "John Doe is a 30-year-old software engineer." result : ell.Message = extract_person_info(text) - person_info = result.structured[0] + person_info = result.parsed print(f"Name: {person_info.name}, Age: {person_info.age}") 5. Multimodal Input: @@ -214,6 +217,7 @@ def parallel_assistant(message_history: List[Message]) -> List[Message]: """ default_client_from_decorator = client default_model_from_decorator = model + default_api_params_from_decorator = api_params def parameterized_lm_decorator( @@ -224,50 +228,55 @@ def parameterized_lm_decorator( @wraps(prompt) def model_call( *prompt_args, - _invocation_origin : str = None, + _invocation_origin : Optional[str] = None, client: Optional[Any] = None, - lm_params: Optional[LMPParams] = {}, + api_params: Optional[Dict[str, Any]] = None, **prompt_kwargs, - ) -> _lstr_generic: + ) -> Tuple[Any, Any, Any]: # promt -> str res = prompt(*prompt_args, **prompt_kwargs) # Convert prompt into ell messages messages = _get_messages(res, prompt) - # done. - + + # XXX: move should log to a logger. + should_log = not exempt_from_tracking and config.verbose # Cute verbose logging. - if config.verbose and not exempt_from_tracking: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, model_call.__ell_hash__, messages) - - # Call the model. We use this data class because we have so many params! - merged_call_params = {**config.default_lm_params, **api_params, **lm_params} - ell_call = EllCall( - model=merged_call_params.get("model", default_model_from_decorator), + if should_log: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, model_call.__ell_hash__, messages) #type: ignore + + # Call the model. 
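        # The dict-unpack below means later sources win; a minimal sketch of the
        # resulting precedence (example values are made up for illustration only):
        #   config.default_lm_params                      {"temperature": 0.0}
        #   @ell.complex(..., temperature=0.7)            decorator params override config defaults
        #   my_lmp(..., api_params={"temperature": 0.9})  call-time params win -> 0.9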
+ # Merge API params + merged_api_params = {**config.default_lm_params, **default_api_params_from_decorator, **(api_params or {})} + n = merged_api_params.get("n", 1) + # Merge client overrides & client registry + merged_client = _client_for_model(model, client or default_client_from_decorator) + ell_call = EllCallParams( + model=merged_api_params.get("model", default_model_from_decorator), messages=messages, client = client or default_client_from_decorator, - api_params=merged_call_params, + api_params=merged_api_params, + origin_id=_invocation_origin, tools=tools, - invocation_id=_invocation_origin, ) # Get the provider for the model provider = config.get_provider_for(ell_call) - (result, _api_params, metadata) = provider.call_model(ell_call) - (result, _api_params, metadata) = call(client=client or default_client_from_decorator, _invocation_origin=_invocation_origin, should_log=config.verbose and not exempt_from_tracking, _name=prompt.__name__, tools=tools) + if should_log: model_usage_logger_post_start(n) + with model_usage_logger_post_intermediate(n) as _logger: + (result, final_api_params, metadata) = provider.call_model(ell_call, _logger) - # Finish result = post_callback(result) if post_callback else result - - # omg bug spotted! - # These get sent to track. - # This is wack. - return result, _api_params, metadata + if should_log: + model_usage_logger_post_end() + # + # These get sent to track. This is wack. + return result, final_api_params, metadata - model_call.__ell_api_params__ = api_params - model_call.__ell_func__ = prompt - model_call.__ell_type__ = LMPType.LM - model_call.__ell_exempt_from_tracking = exempt_from_tracking + model_call.__ell_api_params__ = default_api_params_from_decorator #type: ignore + model_call.__ell_func__ = prompt #type: ignore + model_call.__ell_type__ = LMPType.LM #type: ignore + model_call.__ell_exempt_from_tracking = exempt_from_tracking #type: ignore if exempt_from_tracking: @@ -276,13 +285,15 @@ def model_call( return _track(model_call, forced_dependencies=dict(tools=tools)) return parameterized_lm_decorator + + def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> list[Message]: """ Helper function to convert the output of an LMP into a list of Messages. """ if isinstance(prompt_ret, str): has_system_prompt = prompt.__doc__ is not None and prompt.__doc__.strip() != "" - messages = [Message(role="system", content=[ContentBlock(text=_lstr(prompt.__doc__) )])] if has_system_prompt else [] + messages = [Message(role="system", content=[ContentBlock(text=_lstr(prompt.__doc__ ) )])] if has_system_prompt else [] return messages + [ Message(role="user", content=[ContentBlock(text=prompt_ret)]) ] @@ -291,3 +302,25 @@ def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> l prompt_ret, list ), "Need to pass a list of Messages to the language model" return prompt_ret + + +def _client_for_model( + model: str, + client: Optional[Any] = None, + _name: Optional[str] = None, +) -> Any: + # XXX: Move to config to centralize api keys etc. + if not client: + client, was_fallback = config.get_client_for(model) + # XXX: Wrong. + if not client and not was_fallback: + raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True)) + + if client is None: + raise ValueError(f"No client found for model '{model}'. 
Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.") + + # compatibility with local models necessetates no api key. + # if not client.api_key: + # raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True)) + + return client \ No newline at end of file diff --git a/src/ell/models/openai.py b/src/ell/models/openai.py index 676b25b0..63830e8b 100644 --- a/src/ell/models/openai.py +++ b/src/ell/models/openai.py @@ -46,40 +46,43 @@ def register(client: openai.Client): The function doesn't return anything but updates the global configuration with the registered models. """ - model_data = [ - ('gpt-4-1106-preview', 'system'), - ('gpt-4-32k-0314', 'openai'), - ('text-embedding-3-large', 'system'), - ('gpt-4-0125-preview', 'system'), - ('babbage-002', 'system'), - ('gpt-4-turbo-preview', 'system'), - ('gpt-4o', 'system'), - ('gpt-4o-2024-05-13', 'system'), - ('gpt-4o-mini-2024-07-18', 'system'), - ('gpt-4o-mini', 'system'), - ('gpt-4o-2024-08-06', 'system'), - ('gpt-3.5-turbo-0301', 'openai'), - ('gpt-3.5-turbo-0613', 'openai'), - ('tts-1', 'openai-internal'), - ('gpt-3.5-turbo', 'openai'), - ('gpt-3.5-turbo-16k', 'openai-internal'), - ('davinci-002', 'system'), - ('gpt-3.5-turbo-16k-0613', 'openai'), - ('gpt-4-turbo-2024-04-09', 'system'), - ('gpt-3.5-turbo-0125', 'system'), - ('gpt-4-turbo', 'system'), - ('gpt-3.5-turbo-1106', 'system'), - ('gpt-3.5-turbo-instruct-0914', 'system'), - ('gpt-3.5-turbo-instruct', 'system'), - ('gpt-4-0613', 'openai'), - ('gpt-4', 'openai'), - ('gpt-4-0314', 'openai'), - ('o1-preview', 'system'), - ('o1-mini', 'system'), + #XXX: Deprecation in 0.1.0 + standard_models = [ + 'gpt-4-1106-preview', + 'gpt-4-32k-0314', + 'text-embedding-3-large', + 'gpt-4-0125-preview', + 'babbage-002', + 'gpt-4-turbo-preview', + 'gpt-4o', + 'gpt-4o-2024-05-13', + 'gpt-4o-mini-2024-07-18', + 'gpt-4o-mini', + 'gpt-4o-2024-08-06', + 'gpt-3.5-turbo-0301', + 'gpt-3.5-turbo-0613', + 'tts-1', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-16k', + 'davinci-002', + 'gpt-3.5-turbo-16k-0613', + 'gpt-4-turbo-2024-04-09', + 'gpt-3.5-turbo-0125', + 'gpt-4-turbo', + 'gpt-3.5-turbo-1106', + 'gpt-3.5-turbo-instruct-0914', + 'gpt-3.5-turbo-instruct', + 'gpt-4-0613', + 'gpt-4', + 'gpt-4-0314', ] - for model_id, owned_by in model_data: + for model_id in standard_models: config.register_model(model_id, client) + #XXX: Deprecation in 0.1.0 + config.register_model('o1-preview', client, supports_streaming=False) + config.register_model('o1-mini', client, supports_streaming=False) + default_client = None try: default_client = openai.Client() diff --git a/src/ell/provider.py b/src/ell/provider.py index a116ceee..1d779e44 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -1,33 +1,29 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Type, TypedDict, Union +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, TypedDict, Union + +from pydantic import BaseModel, ConfigDict, Field from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr import json from dataclasses import dataclass from ell.types.message import LMP -@dataclass -class APICallResult: - response: Any - actual_streaming: bool - actual_n: int - final_call_params: Dict[str, Any] - -class EllCall(TypedDict): - model : str - messages : List[Message] - client : Any - 
tools : Optional[List[LMP]] - response_format : Optional[Dict[str, Any]] -e = EllCall(messages=[], client=None, tools=None, response_format=None) +class EllCallParams(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + model: str = Field(..., description="Model identifier") + messages: List[Message] = Field(..., description="Conversation context") + client: Any = Field(..., description="API client") + tools: Optional[List[LMP]] = Field(None, description="Available tools") + api_params: Dict[str, Any] = Field(default_factory=dict, description="API parameters") + origin_id: Optional[str] = Field(None, description="Tracking ID") class Metadata(TypedDict): """First class metadata so that ell studio can work, you can add more stuff here if you want""" - +#XXX: Needs a better name. class Provider(ABC): """ Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. @@ -40,7 +36,7 @@ class Provider(ABC): ### API PARAMETERS ############# ################################ @abstractmethod - def provider_call_function(self) -> Dict[str, Any]: + def provider_call_function(self, api_call_params : Dict[str, Any]) -> Callable[..., Any]: """ Implement this method to return the function that makes the API call to the language model. For example, if you're implementing the OpenAI provider, you would return the function that makes the API call to OpenAI's API. @@ -55,105 +51,70 @@ def disallowed_provider_params(self) -> FrozenSet[str]: """ Returns a list of disallowed call params that ell will override. """ - return frozenset({"system", "tools", "tool_choice", "stream", "functions", "function_call", "response_format"}) + pass - def available_params(self) -> Partial[APICallParams]: - return frozenset(get_params_of_call_function(provider_call_params.keys())) + EllCall.__required_keys__ - disallowed_params + def available_params(self) -> APICallParams: + return frozenset(get_params_of_call_function(provider_call_params.keys())) + EllCallParams.__required_keys__ - disallowed_params ################################ ### TRANSLATION ############### ################################ @abstractmethod - def translate_to_provider(self, ) -> APICallParams: + def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: """Converts an ell call to provider call params!""" return NotImplemented @abstractmethod - def translate_from_provider(self, provider_response : Any, ell_call : EllCall) -> Tuple[List[Message], Metadata]: + def translate_from_provider(self, provider_response : Any, ell_call : EllCallParams, logger : Optional[Callable[[str], None]] = None) -> Tuple[List[Message], Metadata]: """Converts provider responses to universal format.""" return NotImplemented ################################ ### CALL MODEL ################ ################################ - def call_model(self, model : Optional[str] = None, client : Optional[Any] = None, messages : Optional[List[Message]] = None, tools : Optional[List[LMP]] = None, **api_params) -> Any: + # Be careful to override this method in your provider. 
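    # The default implementation below: validate ell_call.api_params against the
    # disallowed/available parameter sets, translate the EllCallParams into
    # provider-specific kwargs, invoke provider_call_function with those kwargs,
    # and translate the raw provider response back into ell Messages plus Metadata.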
+ def call_model(self, ell_call : EllCallParams, logger : Optional[Any] = None) -> Tuple[List[Message], Dict[str, Any], Metadata]: # Automatic validation of params - assert api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {api_params}" - assert api_params.keys() in self.available_params(), f"Invalid parameters: {api_params}" + assert ell_call.api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {ell_call.api_params}" + assert ell_call.api_params.keys() in self.available_params(), f"Invalid parameters: {ell_call.api_params}" # Call - call_params = self.translate_to_provider(ell_call) - provider_resp = self.provider_call_function(client, model)(**call_params) - return self.translate_from_provider(provider_resp, ell_call) - - def default_models(self) -> List[str]: - """Returns a list of default models for this provider.""" - return [ - ] - - def register_all_models(self, client : Any): - """Registers all default models for this provider.""" - for model in self.default_models(): - self.register_model(model, client) - - def validate_call(self, call : EllCall): - if model == "o1-preview" or model == "o1-mini": - # Ensure no system messages are present - assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models" - - if self.model_is_available(call.model): - return - else: - raise ValueError(f"Model {call.model} not available for provider {self.name}") - - -class OpenAIClientProvider(Provider): - """Use this for providers that are a wrapper around an OpenAI client e.g. mistral, groq, azure, etc.""" - - ... - -class OpenAIProvider(OpenAIClientProvider): - def default_models(self) -> List[str]: - return [ - "gpt-4o", - "gpt-4o-mini", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-2024-07-18", - "gpt-4o-2024-06-20", - "gpt-4o-2024-04-09", - "gpt-4o-2024-03-13", - "gpt-4o-2024-02-29", - ] - - def validate_call(self, call : EllCall): - super().validate_call(call) - if model == "o1-preview" or model == "o1-mini": - # Ensure no system messages are present - assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models" - - def provider_call_function(self, EllCall) -> Dict[str, Any]: - if EllCall['response_format']: - return EllCall['client'].beta.chat.completions.parse(**EllCall) - else: - return EllCall['client'].chat.completions.create(**EllCall) + api_call_params = self.translate_to_provider(ell_call) + provider_resp = self.provider_call_function(api_call_params)(**api_call_params) + messages, metadata = self.translate_from_provider(provider_resp, ell_call, logger) - def available_params(self, ell_call : EllCall) -> Partial[APICallParams]: - defualt_params = get_params_of_call_function(self.provider_call_function(ell_call)) + return messages, api_call_params, metadata - if ell_call['response_format']: - # no streaming currently - eturn defualt_params - {'stream'} - else: - return defualt_params -class OllamaProvider(OpenAIClientProvider): - def default_models(self) -> List[str]: +# # +# def validate_provider_call_params(self, ell_call: EllCall, client: Any): +# provider_call_func = self.provider_call_function(client) +# provider_call_params = inspect.signature(provider_call_func).parameters + +# converted_params = self.ell_call_to_provider_call(ell_call) + +# required_params = { +# name: param for name, param in provider_call_params.items() +# if param.default 
== param.empty and param.kind != param.VAR_KEYWORD +# } + +# for param_name in required_params: +# assert param_name in converted_params, f"Required parameter '{param_name}' is missing in the converted call parameters." + +# for param_name, param_value in converted_params.items(): +# assert param_name in provider_call_params, f"Unexpected parameter '{param_name}' in the converted call parameters." + +# param_type = provider_call_params[param_name].annotation +# if param_type != inspect.Parameter.empty: +# assert isinstance(param_value, param_type), f"Parameter '{param_name}' should be of type {param_type}." + +# print("All parameters validated successfully.") + diff --git a/src/ell/util/_warnings.py b/src/ell/util/_warnings.py index 260dccd7..9418fa66 100644 --- a/src/ell/util/_warnings.py +++ b/src/ell/util/_warnings.py @@ -6,13 +6,14 @@ logger = logging.getLogger(__name__) -def _no_api_key_warning(model, name, client_to_use : Optional[Any], long=False, error=False): +def _no_api_key_warning(model, client_to_use : Optional[Any], name = None, long=False, error=False): color = Fore.RED if error else Fore.LIGHTYELLOW_EX prefix = "ERROR" if error else "WARNING" # openai default client_to_use_name = client_to_use.__class__.__name__ if (client_to_use) else "OpenAI" client_to_use_module = client_to_use.__class__.__module__ if (client_to_use) else "openai" - return f"""{color}{prefix}: No API key found for model `{model}` used by LMP `{name}` using client `{client_to_use_name}`""" + (f""". + lmp_name = f"used by LMP `{name}` " if name else "" + return f"""{color}{prefix}: No API key found for model `{model}` {lmp_name}using client `{client_to_use_name}`""" + (f""". To fix this: * Set your API key in the appropriate environment variable for your chosen provider @@ -22,13 +23,13 @@ def _no_api_key_warning(model, name, client_to_use : Optional[Any], long=False, from {client_to_use_module} import {client_to_use_name} @ell.simple(model="{model}", client={client_to_use_name}(api_key=your_api_key)) - def {name}(...): + def your_lmp_name(...): ... ``` * Or explicitly specify the client when calling the LMP: ``` - {name}(..., client={client_to_use_name}(api_key=your_api_key)) + your_lmp_name(..., client={client_to_use_name}(api_key=your_api_key)) ``` """ if long else " at time of definition. Can be okay if custom client specified later! 
https://docs.ell.so/core_concepts/models_and_api_clients.html ") + f"{Style.RESET_ALL}" diff --git a/src/ell/util/api.py b/src/ell/util/api.py deleted file mode 100644 index e4dad014..00000000 --- a/src/ell/util/api.py +++ /dev/null @@ -1,61 +0,0 @@ -from functools import partial - -from ell.configurator import config - -from collections import defaultdict -from ell.types._lstr import _lstr -from ell.types import Message, ContentBlock, ToolCall - -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type -from ell.types.message import LMP, LMPParams, MessageOrDict - -from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start -from ell.util._warnings import _no_api_key_warning -from ell.provider import APICallResult, Provider - -import logging -logger = logging.getLogger(__name__) - -def call( - *, - model: str, - messages: list[Message], - api_params: Dict[str, Any], - tools: Optional[list[LMP]] = None, - client: Optional[Any] = None, - _invocation_origin: str, - should_log: bool, - _name: Optional[str] = None, -) -> Tuple[Union[Message, List[Message]], Dict[str, Any], Dict[str, Any]]: - """ - Helper function to run the language model with the provided messages and parameters. - """ - if not client: - client, was_fallback = config.get_client_for(model) - if not client and not was_fallback: - raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True)) - - if client is None: - raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.") - - if not client.api_key: - raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True)) - - provider_class: Type[Provider] = config.get_provider_for(client) - - - # XXX: Could actually delete htis - call_result = provider_class.call_model(client, model, messages, api_params, tools) - - if should_log: - model_usage_logger_post_start(call_result.actual_n) - - with model_usage_logger_post_intermediate(call_result.actual_n) as _logger: - tracked_results, metadata = provider_class.process_response(call_result, _invocation_origin, _logger if should_log else None, tools) - - - if config.verbose and not should_log: - model_usage_logger_post_end() - - - return (tracked_results[0] if len(tracked_results) == 1 else tracked_results), call_result.final_call_params, metadata \ No newline at end of file From 8eacee22aa0f49a65696713921329b519d5a78bb Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 13:27:55 -0700 Subject: [PATCH 07/17] lstr refactor & good provider implementation --- src/ell/configurator.py | 18 +- src/ell/lmp/complex.py | 412 ++++++++++++++++----------------- src/ell/lmp/simple.py | 160 ++++++------- src/ell/lmp/tool.py | 6 +- src/ell/provider.py | 92 ++++---- src/ell/providers/anthropic.py | 8 +- src/ell/providers/openai.py | 8 +- src/ell/types/_lstr.py | 268 ++++++++++----------- tests/test_lstr.py | 56 ++--- tests/test_openai_provider.py | 2 +- 10 files changed, 523 insertions(+), 507 deletions(-) diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 9f76843f..b6be398c 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -38,7 +38,7 @@ class Config(BaseModel): lazy_versioning: bool = Field(default=True, description="If True, enables lazy versioning for improved performance.") default_lm_params: Dict[str, Any] = 
Field(default_factory=dict, description="Default parameters for language models.") default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") - providers: Dict[Type, Type[Provider]] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") + providers: Dict[Type, Provider] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") def __init__(self, **data): super().__init__(**data) self._lock = threading.Lock() @@ -111,26 +111,28 @@ def get_client_for(self, model_name: str) -> Tuple[Optional[openai.Client], bool client = model_config.default_client return client, fallback - def register_provider(self, provider_class: Type[Provider]) -> None: + def register_provider(self, provider: Provider, client_type: Type[Any]) -> None: """ Register a provider class for a specific client type. :param provider_class: The provider class to register. :type provider_class: Type[Provider] """ + assert isinstance(client_type, type), "client_type must be a type (e.g. openai.Client), not an an instance (myclient := openai.Client()))" with self._lock: - self.providers[provider_class.get_client_type()] = provider_class + self.providers[client_type] = provider - def get_provider_for(self, client: Any) -> Optional[Type[Provider]]: + def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]: """ - Get the provider class for a specific client instance. + Get the provider instance for a specific client instance. :param client: The client instance to get the provider for. :type client: Any - :return: The provider class for the specified client, or None if not found. - :rtype: Optional[Type[Provider]] + :return: The provider instance for the specified client, or None if not found. + :rtype: Optional[Provider] """ - return next((provider for client_type, provider in self.providers.items() if isinstance(client, client_type)), None) + client_type = type(client) if not isinstance(client, type) else client + return self.providers.get(client_type) # Single* instance # XXX: Make a singleton diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index 3a290e97..d5aa3880 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -6,7 +6,6 @@ from ell.types.message import LMP, InvocableLM, LMPParams, MessageOrDict, _lstr_generic from ell.types.studio import LMPType from ell.util._warnings import _no_api_key_warning, _warnings -from ell.util.api import call from ell.util.verbosity import compute_color, model_usage_logger_pre from ell.util.verbosity import model_usage_logger_post_end, model_usage_logger_post_intermediate, model_usage_logger_post_start @@ -14,212 +13,11 @@ from functools import wraps from typing import Any, Dict, Optional, List, Callable, Tuple, Union -#XXX: Remove the docstirng here. def complex(model: str, client: Optional[Any] = None, tools: Optional[List[Callable]] = None, exempt_from_tracking=False, post_callback: Optional[Callable] = None, **api_params): - """ - A sophisticated language model programming decorator for complex LLM interactions. - - This decorator transforms a function into a Language Model Program (LMP) capable of handling - multi-turn conversations, tool usage, and various output formats. It's designed for advanced - use cases where full control over the LLM's capabilities is needed. - - :param model: The name or identifier of the language model to use. 
- :type model: str - :param client: An optional OpenAI client instance. If not provided, a default client will be used. - :type client: Optional[openai.Client] - :param tools: A list of tool functions that can be used by the LLM. Only available for certain models. - :type tools: Optional[List[Callable]] - :param response_format: The response format for the LLM. Only available for certain models. - :type response_format: Optional[Dict[str, Any]] - :param n: The number of responses to generate for the LLM. Only available for certain models. - :type n: Optional[int] - :param temperature: The temperature parameter for controlling the randomness of the LLM. - :type temperature: Optional[float] - :param max_tokens: The maximum number of tokens to generate for the LLM. - :type max_tokens: Optional[int] - :param top_p: The top-p sampling parameter for controlling the diversity of the LLM. - :type top_p: Optional[float] - :param frequency_penalty: The frequency penalty parameter for controlling the LLM's repetition. - :type frequency_penalty: Optional[float] - :param presence_penalty: The presence penalty parameter for controlling the LLM's relevance. - :type presence_penalty: Optional[float] - :param stop: The stop sequence for the LLM. - :type stop: Optional[List[str]] - :param exempt_from_tracking: If True, the LMP usage won't be tracked. Default is False. - :type exempt_from_tracking: bool - :param post_callback: An optional function to process the LLM's output before returning. - :type post_callback: Optional[Callable] - :param api_params: Additional keyword arguments to pass to the underlying API call. - :type api_params: Any - - :return: A decorator that can be applied to a function, transforming it into a complex LMP. - :rtype: Callable - - Functionality: - - 1. Advanced LMP Creation: - - Supports multi-turn conversations and stateful interactions. - - Enables tool usage within the LLM context. - - Allows for various output formats, including structured data and function calls. - - 2. Flexible Input Handling: - - Can process both single prompts and conversation histories. - - Supports multimodal inputs (text, images, etc.) in the prompt. - - 3. Comprehensive Integration: - - Integrates with ell's tracking system for monitoring LMP versions, usage, and performance. - - Supports various language models and API configurations. - - 4. Output Processing: - - Can return raw LLM outputs or process them through a post-callback function. - - Supports returning multiple message types (e.g., text, function calls, tool results). - - Usage Modes and Examples: - - 1. Basic Prompt: - - .. code-block:: python - - @ell.complex(model="gpt-4") - def generate_story(prompt: str) -> List[Message]: - '''You are a creative story writer''' # System prompt - return [ - ell.user(f"Write a short story based on this prompt: {prompt}") - ] - - story : ell.Message = generate_story("A robot discovers emotions") - print(story.text) # Access the text content of the last message - - 2. Multi-turn Conversation: - - .. code-block:: python - - @ell.complex(model="gpt-4") - def chat_bot(message_history: List[Message]) -> List[Message]: - return [ - ell.system("You are a helpful assistant."), - ] + message_history - - conversation = [ - ell.user("Hello, who are you?"), - ell.assistant("I'm an AI assistant. How can I help you today?"), - ell.user("Can you explain quantum computing?") - ] - response : ell.Message = chat_bot(conversation) - print(response.text) # Print the assistant's response - - 3. Tool Usage: - - .. 
code-block:: python - - @ell.tool() - def get_weather(location: str) -> str: - # Implementation to fetch weather - return f"The weather in {location} is sunny." - - @ell.complex(model="gpt-4", tools=[get_weather]) - def weather_assistant(message_history: List[Message]) -> List[Message]: - return [ - ell.system("You are a weather assistant. Use the get_weather tool when needed."), - ] + message_history - - conversation = [ - ell.user("What's the weather like in New York?") - ] - response : ell.Message = weather_assistant(conversation) - - if response.tool_calls: - tool_results = response.call_tools_and_collect_as_message() - print("Tool results:", tool_results.text) - - # Continue the conversation with tool results - final_response = weather_assistant(conversation + [response, tool_results]) - print("Final response:", final_response.text) - - 4. Structured Output: - - .. code-block:: python - - from pydantic import BaseModel - - class PersonInfo(BaseModel): - name: str - age: int - - @ell.complex(model="gpt-4", response_format=PersonInfo) - def extract_person_info(text: str) -> List[Message]: - return [ - ell.system("Extract person information from the given text."), - ell.user(text) - ] - - text = "John Doe is a 30-year-old software engineer." - result : ell.Message = extract_person_info(text) - person_info = result.parsed - print(f"Name: {person_info.name}, Age: {person_info.age}") - - 5. Multimodal Input: - - .. code-block:: python - - @ell.complex(model="gpt-4-vision-preview") - def describe_image(image: PIL.Image.Image) -> List[Message]: - return [ - ell.system("Describe the contents of the image in detail."), - ell.user([ - ContentBlock(text="What do you see in this image?"), - ContentBlock(image=image) - ]) - ] - - image = PIL.Image.open("example.jpg") - description = describe_image(image) - print(description.text) - - 6. Parallel Tool Execution: - - .. code-block:: python - - @ell.complex(model="gpt-4", tools=[tool1, tool2, tool3]) - def parallel_assistant(message_history: List[Message]) -> List[Message]: - return [ - ell.system("You can use multiple tools in parallel."), - ] + message_history - - response = parallel_assistant([ell.user("Perform tasks A, B, and C simultaneously.")]) - if response.tool_calls: - tool_results : ell.Message = response.call_tools_and_collect_as_message(parallel=True, max_workers=3) - print("Parallel tool results:", tool_results.text) - - Helper Functions for Output Processing: - - - response.text: Get the full text content of the last message. - - response.text_only: Get only the text content, excluding non-text elements. - - response.tool_calls: Access the list of tool calls in the message. - - response.tool_results: Access the list of tool results in the message. - - response.structured: Access structured data outputs. - - response.call_tools_and_collect_as_message(): Execute tool calls and collect results. - - Message(role="user", content=[...]).to_openai_message(): Convert to OpenAI API format. - - Notes: - - - The decorated function should return a list of Message objects. - - For tool usage, ensure that tools are properly decorated with @ell.tool(). - - When using structured outputs, specify the response_format in the decorator. - - The complex decorator supports all features of simpler decorators like @ell.simple. - - Use helper functions and properties to easily access and process different types of outputs. - - See Also: - - - ell.simple: For simpler text-only LMP interactions. 
- - ell.tool: For defining tools that can be used within complex LMPs. - - ell.studio: For visualizing and analyzing LMP executions. - """ default_client_from_decorator = client default_model_from_decorator = model default_api_params_from_decorator = api_params - def parameterized_lm_decorator( prompt: LMP, ) -> Callable[..., Union[List[Message], Message]]: @@ -250,19 +48,20 @@ def model_call( # Merge client overrides & client registry merged_client = _client_for_model(model, client or default_client_from_decorator) ell_call = EllCallParams( - model=merged_api_params.get("model", default_model_from_decorator), + # XXX: Could change behaviour of overriding ell params for dyanmic tool calls. + model=merged_api_params.pop("model", default_model_from_decorator), messages=messages, client = client or default_client_from_decorator, api_params=merged_api_params, - origin_id=_invocation_origin, tools=tools, ) # Get the provider for the model provider = config.get_provider_for(ell_call) + assert provider is not None, f"No provider found for model {ell_call.client}." if should_log: model_usage_logger_post_start(n) with model_usage_logger_post_intermediate(n) as _logger: - (result, final_api_params, metadata) = provider.call_model(ell_call, _logger) + (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger) result = post_callback(result) if post_callback else result if should_log: @@ -303,7 +102,6 @@ def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> l ), "Need to pass a list of Messages to the language model" return prompt_ret - def _client_for_model( model: str, client: Optional[Any] = None, @@ -323,4 +121,204 @@ def _client_for_model( # if not client.api_key: # raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True)) - return client \ No newline at end of file + return client + + +complex.__doc__ = """A sophisticated language model programming decorator for complex LLM interactions. + +This decorator transforms a function into a Language Model Program (LMP) capable of handling +multi-turn conversations, tool usage, and various output formats. It's designed for advanced +use cases where full control over the LLM's capabilities is needed. + +:param model: The name or identifier of the language model to use. +:type model: str +:param client: An optional OpenAI client instance. If not provided, a default client will be used. +:type client: Optional[openai.Client] +:param tools: A list of tool functions that can be used by the LLM. Only available for certain models. +:type tools: Optional[List[Callable]] +:param response_format: The response format for the LLM. Only available for certain models. +:type response_format: Optional[Dict[str, Any]] +:param n: The number of responses to generate for the LLM. Only available for certain models. +:type n: Optional[int] +:param temperature: The temperature parameter for controlling the randomness of the LLM. +:type temperature: Optional[float] +:param max_tokens: The maximum number of tokens to generate for the LLM. +:type max_tokens: Optional[int] +:param top_p: The top-p sampling parameter for controlling the diversity of the LLM. +:type top_p: Optional[float] +:param frequency_penalty: The frequency penalty parameter for controlling the LLM's repetition. +:type frequency_penalty: Optional[float] +:param presence_penalty: The presence penalty parameter for controlling the LLM's relevance. 
+:type presence_penalty: Optional[float] +:param stop: The stop sequence for the LLM. +:type stop: Optional[List[str]] +:param exempt_from_tracking: If True, the LMP usage won't be tracked. Default is False. +:type exempt_from_tracking: bool +:param post_callback: An optional function to process the LLM's output before returning. +:type post_callback: Optional[Callable] +:param api_params: Additional keyword arguments to pass to the underlying API call. +:type api_params: Any + +:return: A decorator that can be applied to a function, transforming it into a complex LMP. +:rtype: Callable + +Functionality: + +1. Advanced LMP Creation: + - Supports multi-turn conversations and stateful interactions. + - Enables tool usage within the LLM context. + - Allows for various output formats, including structured data and function calls. + +2. Flexible Input Handling: + - Can process both single prompts and conversation histories. + - Supports multimodal inputs (text, images, etc.) in the prompt. + +3. Comprehensive Integration: + - Integrates with ell's tracking system for monitoring LMP versions, usage, and performance. + - Supports various language models and API configurations. + +4. Output Processing: + - Can return raw LLM outputs or process them through a post-callback function. + - Supports returning multiple message types (e.g., text, function calls, tool results). + +Usage Modes and Examples: + +1. Basic Prompt: + +.. code-block:: python + + @ell.complex(model="gpt-4") + def generate_story(prompt: str) -> List[Message]: + '''You are a creative story writer''' # System prompt + return [ + ell.user(f"Write a short story based on this prompt: {prompt}") + ] + + story : ell.Message = generate_story("A robot discovers emotions") + print(story.text) # Access the text content of the last message + +2. Multi-turn Conversation: + +.. code-block:: python + + @ell.complex(model="gpt-4") + def chat_bot(message_history: List[Message]) -> List[Message]: + return [ + ell.system("You are a helpful assistant."), + ] + message_history + + conversation = [ + ell.user("Hello, who are you?"), + ell.assistant("I'm an AI assistant. How can I help you today?"), + ell.user("Can you explain quantum computing?") + ] + response : ell.Message = chat_bot(conversation) + print(response.text) # Print the assistant's response + +3. Tool Usage: + +.. code-block:: python + + @ell.tool() + def get_weather(location: str) -> str: + # Implementation to fetch weather + return f"The weather in {location} is sunny." + + @ell.complex(model="gpt-4", tools=[get_weather]) + def weather_assistant(message_history: List[Message]) -> List[Message]: + return [ + ell.system("You are a weather assistant. Use the get_weather tool when needed."), + ] + message_history + + conversation = [ + ell.user("What's the weather like in New York?") + ] + response : ell.Message = weather_assistant(conversation) + + if response.tool_calls: + tool_results = response.call_tools_and_collect_as_message() + print("Tool results:", tool_results.text) + + # Continue the conversation with tool results + final_response = weather_assistant(conversation + [response, tool_results]) + print("Final response:", final_response.text) + +4. Structured Output: + +.. 
code-block:: python + + from pydantic import BaseModel + + class PersonInfo(BaseModel): + name: str + age: int + + @ell.complex(model="gpt-4", response_format=PersonInfo) + def extract_person_info(text: str) -> List[Message]: + return [ + ell.system("Extract person information from the given text."), + ell.user(text) + ] + + text = "John Doe is a 30-year-old software engineer." + result : ell.Message = extract_person_info(text) + person_info = result.parsed + print(f"Name: {person_info.name}, Age: {person_info.age}") + +5. Multimodal Input: + +.. code-block:: python + + @ell.complex(model="gpt-4-vision-preview") + def describe_image(image: PIL.Image.Image) -> List[Message]: + return [ + ell.system("Describe the contents of the image in detail."), + ell.user([ + ContentBlock(text="What do you see in this image?"), + ContentBlock(image=image) + ]) + ] + + image = PIL.Image.open("example.jpg") + description = describe_image(image) + print(description.text) + +6. Parallel Tool Execution: + +.. code-block:: python + + @ell.complex(model="gpt-4", tools=[tool1, tool2, tool3]) + def parallel_assistant(message_history: List[Message]) -> List[Message]: + return [ + ell.system("You can use multiple tools in parallel."), + ] + message_history + + response = parallel_assistant([ell.user("Perform tasks A, B, and C simultaneously.")]) + if response.tool_calls: + tool_results : ell.Message = response.call_tools_and_collect_as_message(parallel=True, max_workers=3) + print("Parallel tool results:", tool_results.text) + +Helper Functions for Output Processing: + +- response.text: Get the full text content of the last message. +- response.text_only: Get only the text content, excluding non-text elements. +- response.tool_calls: Access the list of tool calls in the message. +- response.tool_results: Access the list of tool results in the message. +- response.structured: Access structured data outputs. +- response.call_tools_and_collect_as_message(): Execute tool calls and collect results. +- Message(role="user", content=[...]).to_openai_message(): Convert to OpenAI API format. + +Notes: + +- The decorated function should return a list of Message objects. +- For tool usage, ensure that tools are properly decorated with @ell.tool(). +- When using structured outputs, specify the response_format in the decorator. +- The complex decorator supports all features of simpler decorators like @ell.simple. +- Use helper functions and properties to easily access and process different types of outputs. + +See Also: + +- ell.simple: For simpler text-only LMP interactions. +- ell.tool: For defining tools that can be used within complex LMPs. +- ell.studio: For visualizing and analyzing LMP executions. + """ \ No newline at end of file diff --git a/src/ell/lmp/simple.py b/src/ell/lmp/simple.py index a57566f4..6a5ae5a0 100644 --- a/src/ell/lmp/simple.py +++ b/src/ell/lmp/simple.py @@ -5,86 +5,6 @@ def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params): - """ - The fundamental unit of language model programming in ell. - - This decorator simplifies the process of creating Language Model Programs (LMPs) - that return text-only outputs from language models, while supporting multimodal inputs. - It wraps the more complex 'complex' decorator, providing a streamlined interface for common use cases. - - :param model: The name or identifier of the language model to use. - :type model: str - :param client: An optional OpenAI client instance. If not provided, a default client will be used. 
- :type client: Optional[openai.Client] - :param exempt_from_tracking: If True, the LMP usage won't be tracked. Default is False. - :type exempt_from_tracking: bool - :param api_params: Additional keyword arguments to pass to the underlying API call. - :type api_params: Any - - Usage: - The decorated function can return either a single prompt or a list of ell.Message objects: - - .. code-block:: python - - @ell.simple(model="gpt-4", temperature=0.7) - def summarize_text(text: str) -> str: - '''You are an expert at summarizing text.''' # System prompt - return f"Please summarize the following text:\\n\\n{text}" # User prompt - - - @ell.simple(model="gpt-4", temperature=0.7) - def describe_image(image : PIL.Image.Image) -> List[ell.Message]: - '''Describe the contents of an image.''' # unused because we're returning a list of Messages - return [ - # helper function for ell.Message(text="...", role="system") - ell.system("You are an AI trained to describe images."), - # helper function for ell.Message(content="...", role="user") - ell.user(["Describe this image in detail.", image]), - ] - - - image_description = describe_image(PIL.Image.open("https://example.com/image.jpg")) - print(image_description) - # Output will be a string text-only description of the image - - summary = summarize_text("Long text to summarize...") - print(summary) - # Output will be a text-only summary - - Notes: - - - This decorator is designed for text-only model outputs, but supports multimodal inputs. - - It simplifies complex responses from language models to text-only format, regardless of - the model's capability for structured outputs, function calling, or multimodal outputs. - - For preserving complex model outputs (e.g., structured data, function calls, or multimodal - outputs), use the @ell.complex decorator instead. @ell.complex returns a Message object (role='assistant') - - The decorated function can return a string or a list of ell.Message objects for more - complex prompts, including multimodal inputs. - - If called with n > 1 in api_params, the wrapped LMP will return a list of strings for the n parallel outputs - of the model instead of just one string. Otherwise, it will return a single string. - - You can pass LM API parameters either in the decorator or when calling the decorated function. - Parameters passed during the function call will override those set in the decorator. - - Example of passing LM API params: - - .. code-block:: python - - @ell.simple(model="gpt-4", temperature=0.7) - def generate_story(prompt: str) -> str: - return f"Write a short story based on this prompt: {prompt}" - - # Using default parameters - story1 = generate_story("A day in the life of a time traveler") - - # Overriding parameters during function call - story2 = generate_story("An AI's first day of consciousness", lm_params={"temperature": 0.9, "max_tokens": 500}) - - See Also: - - - :func:`ell.complex`: For LMPs that preserve full structure of model responses, including multimodal outputs. - - :func:`ell.tool`: For defining tools that can be used within complex LMPs. - - :mod:`ell.studio`: For visualizing and analyzing LMP executions. 
- """ assert 'tools' not in api_params, "tools are not supported in lm decorator, use multimodal decorator instead" assert 'tool_choice' not in api_params, "tool_choice is not supported in lm decorator, use multimodal decorator instead" assert 'response_format' not in api_params, "response_format is not supported in lm decorator, use multimodal decorator instead" @@ -94,3 +14,83 @@ def convert_multimodal_response_to_lstr(response): return complex(model, client, exempt_from_tracking, **api_params, post_callback=convert_multimodal_response_to_lstr) + +simple.__doc__ = """The fundamental unit of language model programming in ell. + + This decorator simplifies the process of creating Language Model Programs (LMPs) + that return text-only outputs from language models, while supporting multimodal inputs. + It wraps the more complex 'complex' decorator, providing a streamlined interface for common use cases. + + :param model: The name or identifier of the language model to use. + :type model: str + :param client: An optional OpenAI client instance. If not provided, a default client will be used. + :type client: Optional[openai.Client] + :param exempt_from_tracking: If True, the LMP usage won't be tracked. Default is False. + :type exempt_from_tracking: bool + :param api_params: Additional keyword arguments to pass to the underlying API call. + :type api_params: Any + + Usage: + The decorated function can return either a single prompt or a list of ell.Message objects: + + .. code-block:: python + + @ell.simple(model="gpt-4", temperature=0.7) + def summarize_text(text: str) -> str: + '''You are an expert at summarizing text.''' # System prompt + return f"Please summarize the following text:\\n\\n{text}" # User prompt + + + @ell.simple(model="gpt-4", temperature=0.7) + def describe_image(image : PIL.Image.Image) -> List[ell.Message]: + '''Describe the contents of an image.''' # unused because we're returning a list of Messages + return [ + # helper function for ell.Message(text="...", role="system") + ell.system("You are an AI trained to describe images."), + # helper function for ell.Message(content="...", role="user") + ell.user(["Describe this image in detail.", image]), + ] + + + image_description = describe_image(PIL.Image.open("https://example.com/image.jpg")) + print(image_description) + # Output will be a string text-only description of the image + + summary = summarize_text("Long text to summarize...") + print(summary) + # Output will be a text-only summary + + Notes: + + - This decorator is designed for text-only model outputs, but supports multimodal inputs. + - It simplifies complex responses from language models to text-only format, regardless of + the model's capability for structured outputs, function calling, or multimodal outputs. + - For preserving complex model outputs (e.g., structured data, function calls, or multimodal + outputs), use the @ell.complex decorator instead. @ell.complex returns a Message object (role='assistant') + - The decorated function can return a string or a list of ell.Message objects for more + complex prompts, including multimodal inputs. + - If called with n > 1 in api_params, the wrapped LMP will return a list of strings for the n parallel outputs + of the model instead of just one string. Otherwise, it will return a single string. + - You can pass LM API parameters either in the decorator or when calling the decorated function. + Parameters passed during the function call will override those set in the decorator. 
+ + Example of passing LM API params: + + .. code-block:: python + + @ell.simple(model="gpt-4", temperature=0.7) + def generate_story(prompt: str) -> str: + return f"Write a short story based on this prompt: {prompt}" + + # Using default parameters + story1 = generate_story("A day in the life of a time traveler") + + # Overriding parameters during function call + story2 = generate_story("An AI's first day of consciousness", lm_params={"temperature": 0.9, "max_tokens": 500}) + + See Also: + + - :func:`ell.complex`: For LMPs that preserve full structure of model responses, including multimodal outputs. + - :func:`ell.tool`: For defining tools that can be used within complex LMPs. + - :mod:`ell.studio`: For visualizing and analyzing LMP executions. + """ \ No newline at end of file diff --git a/src/ell/lmp/tool.py b/src/ell/lmp/tool.py index bab0df99..fbe5bae5 100644 --- a/src/ell/lmp/tool.py +++ b/src/ell/lmp/tool.py @@ -160,7 +160,7 @@ def wrapper( # Similar to how it's done in the lm decorator # Use _invocation_origin if isinstance(result, str) and _invocation_origin: - result = _lstr(result, _origin_trace=_invocation_origin) + result = _lstr(result,origin_trace=_invocation_origin) #XXX: This _tool_call_id thing is a hack. Tracking should happen via params in the api if _tool_call_id: @@ -168,7 +168,7 @@ def wrapper( content_results = coerce_content_list(result) except ValueError as e: # XXX: TODO: MOVE TRACKING CODE TO _TRACK AND OUT OF HERE AND API. - content_results = [ContentBlock(text=_lstr(json.dumps(result), _origin_trace=_invocation_origin))] + content_results = [ContentBlock(text=_lstr(json.dumps(result),origin_trace=_invocation_origin))] # TODO: poolymorphic validation here is important (cant have tool_call or formatted_response in the result) # XXX: Should we put this coercion here or in the tool call/result area. @@ -179,7 +179,7 @@ def wrapper( # Warning: Formatted response in tool result will be converted to text # TODO: Logging needs to produce not print. print(f"Warning: Formatted response in tool result will be converted to text. Original: {c.parsed}") - c.text = _lstr(c.parsed.model_dump_json(), _origin_trace=_invocation_origin) + c.text = _lstr(c.parsed.model_dump_json(),origin_trace=_invocation_origin) c.parsed = None assert not c.audio, "Audio in tool result" return ToolResult(tool_call_id=_tool_call_id, result=content_results), _invocation_api_params, {} diff --git a/src/ell/provider.py b/src/ell/provider.py index 1d779e44..2f0653ac 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -1,5 +1,8 @@ from abc import ABC, abstractmethod from collections import defaultdict +from functools import lru_cache +import inspect +from types import MappingProxyType from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, TypedDict, Union from pydantic import BaseModel, ConfigDict, Field @@ -10,14 +13,17 @@ from ell.types.message import LMP +# XXX: Might leave this internal to providers so that the complex code is simpler & +# we can literally jsut call provider.call like any openai fn. 
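# A rough sketch of what that direct usage could look like (the variable names and
# parameter values are illustrative only, not part of the package API):
#
#   client = openai.Client()
#   provider = config.get_provider_for(client)
#   ell_call = EllCallParams(model="gpt-4o", messages=[ell.user("hi")],
#                            client=client, api_params={"temperature": 0.1})
#   messages, final_api_params, metadata = provider.call(ell_call)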
class EllCallParams(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) model: str = Field(..., description="Model identifier") messages: List[Message] = Field(..., description="Conversation context") client: Any = Field(..., description="API client") tools: Optional[List[LMP]] = Field(None, description="Available tools") api_params: Dict[str, Any] = Field(default_factory=dict, description="API parameters") - origin_id: Optional[str] = Field(None, description="Tracking ID") + + + model_config = ConfigDict(arbitrary_types_allowed=True) class Metadata(TypedDict): @@ -29,32 +35,29 @@ class Provider(ABC): Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI. In Ell. We hate abstractions. The only reason this exists is to force implementers to implement their own provider correctly -_-. - """ ################################ ### API PARAMETERS ############# ################################ @abstractmethod - def provider_call_function(self, api_call_params : Dict[str, Any]) -> Callable[..., Any]: + def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: """ Implement this method to return the function that makes the API call to the language model. For example, if you're implementing the OpenAI provider, you would return the function that makes the API call to OpenAI's API. - ```python - return openai.Completion.create - ``` """ return NotImplemented - @abstractmethod - def disallowed_provider_params(self) -> FrozenSet[str]: + + def disallowed_api_params(self) -> FrozenSet[str]: """ Returns a list of disallowed call params that ell will override. """ - pass + return frozenset({"messages", "tools", "model"}) - def available_params(self) -> APICallParams: - return frozenset(get_params_of_call_function(provider_call_params.keys())) + EllCallParams.__required_keys__ - disallowed_params + def available_api_params(self, api_params : Optional[Dict[str, Any]] = None): + params = _call_params(self.provider_call_function(api_params)) + return frozenset(params.keys()) - self.disallowed_api_params() ################################ @@ -66,7 +69,7 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: return NotImplemented @abstractmethod - def translate_from_provider(self, provider_response : Any, ell_call : EllCallParams, logger : Optional[Callable[[str], None]] = None) -> Tuple[List[Message], Metadata]: + def translate_from_provider(self, provider_response : Any, ell_call : EllCallParams, origin_id : Optional[str] = None, logger : Optional[Callable[[str], None]] = None) -> Tuple[List[Message], Metadata]: """Converts provider responses to universal format.""" return NotImplemented @@ -74,47 +77,58 @@ def translate_from_provider(self, provider_response : Any, ell_call : EllCallPar ### CALL MODEL ################ ################################ # Be careful to override this method in your provider. 
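    # call() below validates api_params against disallowed_api_params(), translates
    # the EllCallParams into provider kwargs, cross-checks those kwargs against the
    # provider call function's signature, makes the request, and converts the raw
    # response into tracked ell Messages (checking their origin id) before returning
    # them together with the final kwargs and metadata.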
- def call_model(self, ell_call : EllCallParams, logger : Optional[Any] = None) -> Tuple[List[Message], Dict[str, Any], Metadata]: + def call(self, ell_call : EllCallParams, origin_id : Optional[str] = None, logger : Optional[Any] = None) -> Tuple[List[Message], Dict[str, Any], Metadata]: # Automatic validation of params - - assert ell_call.api_params.keys() not in self.disallowed_provider_params(), f"Disallowed parameters: {ell_call.api_params}" - assert ell_call.api_params.keys() in self.available_params(), f"Invalid parameters: {ell_call.api_params}" + assert ell_call.api_params.keys() not in self.disallowed_api_params(), f"Disallowed parameters: {ell_call.api_params}" # Call - api_call_params = self.translate_to_provider(ell_call) - provider_resp = self.provider_call_function(api_call_params)(**api_call_params) - messages, metadata = self.translate_from_provider(provider_resp, ell_call, logger) + - return messages, api_call_params, metadata + final_api_call_params = self.translate_to_provider(ell_call) + call = self.provider_call_function(final_api_call_params) + _validate_provider_call_params(final_api_call_params, call) + provider_resp = call(final_api_call_params)(**final_api_call_params) + messages, metadata = self.translate_from_provider(provider_resp, ell_call, origin_id, logger) + _validate_messages_are_tracked(messages, origin_id) + # TODO: Validate messages are tracked. + return messages, final_api_call_params, metadata + -# # -# def validate_provider_call_params(self, ell_call: EllCall, client: Any): -# provider_call_func = self.provider_call_function(client) -# provider_call_params = inspect.signature(provider_call_func).parameters - -# converted_params = self.ell_call_to_provider_call(ell_call) + +# handhold the the implementer, in production mode we can turn these off for speed. +@lru_cache(maxsize=None) +def _call_params(call : Callable[..., Any]) -> MappingProxyType[str, inspect.Parameter]: + return inspect.signature(call).parameters + +def _validate_provider_call_params(api_call_params: Dict[str, Any], call : Callable[..., Any]): + provider_call_params = _call_params(call) -# required_params = { -# name: param for name, param in provider_call_params.items() -# if param.default == param.empty and param.kind != param.VAR_KEYWORD -# } + required_params = { + name: param for name, param in provider_call_params.items() + if param.default == param.empty and param.kind != param.VAR_KEYWORD + } -# for param_name in required_params: -# assert param_name in converted_params, f"Required parameter '{param_name}' is missing in the converted call parameters." + for param_name in required_params: + assert param_name in api_call_params, f"Provider implementation error: Required parameter '{param_name}' is missing in the converted call parameters converted from ell call." -# for param_name, param_value in converted_params.items(): -# assert param_name in provider_call_params, f"Unexpected parameter '{param_name}' in the converted call parameters." + for param_name, param_value in api_call_params.items(): + assert param_name in provider_call_params, f"Provider implementation error: Unexpected parameter '{param_name}' in the converted call parameters." -# param_type = provider_call_params[param_name].annotation -# if param_type != inspect.Parameter.empty: -# assert isinstance(param_value, param_type), f"Parameter '{param_name}' should be of type {param_type}." 
+ param_type = provider_call_params[param_name].annotation + if param_type != inspect.Parameter.empty: + assert isinstance(param_value, param_type), f"Provider implementation error: Parameter '{param_name}' should be of type {param_type}." -# print("All parameters validated successfully.") +def _validate_messages_are_tracked(messages: List[Message], origin_id: Optional[str] = None): + if origin_id is None: return + + for message in messages: + assert isinstance(message.text, _lstr), f"Provider implementation error: Message text should be an instance of _lstr, got {type(message.text)}" + assert message.text._or == origin_id, f"Provider implementation error: Message origin_id {message.text.origin_id} does not match the provided origin_id {origin_id}" diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index c4970a4e..1505cc05 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -15,7 +15,7 @@ class AnthropicProvider(Provider): @classmethod - def call_model( + def call( cls, client: Anthropic, model: str, @@ -95,7 +95,7 @@ def process_response( elif chunk.type == "content_block_stop": if current_block is not None: if current_block["type"] == "text": - content.append(ContentBlock(text=_lstr(current_block["content"], _origin_trace=_invocation_origin))) + content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=_invocation_origin))) elif current_block["type"] == "tool_use": try: final_cb = chunk.content_block @@ -116,7 +116,7 @@ def process_response( tool_call=ToolCall( tool=matching_tool, tool_call_id=_lstr( - final_cb.id, _origin_trace=_invocation_origin + final_cb.id,origin_trace=_invocation_origin ), params=params, ) @@ -151,7 +151,7 @@ def process_response( cbs = [] for content_block in call_result.response.content: if content_block.type == "text": - cbs.append(ContentBlock(text=_lstr(content_block.text, _origin_trace=_invocation_origin))) + cbs.append(ContentBlock(text=_lstr(content_block.text,origin_trace=_invocation_origin))) elif content_block.type == "tool_use": assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use." tool_call = ToolCall( diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 2921062c..30797274 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -77,7 +77,7 @@ def message_to_openai_format(message: Message) -> Dict[str, Any]: return openai_message @classmethod - def call_model( + def call( cls, client: Any, model: str, @@ -178,7 +178,7 @@ def process_response( content.append( ContentBlock( text=_lstr( - content=text_content, _origin_trace=_invocation_origin + content=text_content,origin_trace=_invocation_origin ) ) ) @@ -192,7 +192,7 @@ def process_response( content.append( ContentBlock( text=_lstr( - content=choice.content, _origin_trace=_invocation_origin + content=choice.content,origin_trace=_invocation_origin ) ) ) @@ -217,7 +217,7 @@ def process_response( tool_call=ToolCall( tool=matching_tool, tool_call_id=_lstr( - tool_call.id, _origin_trace=_invocation_origin + tool_call.id,origin_trace=_invocation_origin ), params=params, ) diff --git a/src/ell/types/_lstr.py b/src/ell/types/_lstr.py index b77647cd..93d4d582 100644 --- a/src/ell/types/_lstr.py +++ b/src/ell/types/_lstr.py @@ -1,6 +1,7 @@ """ -LM string that supports logits and keeps track of it's _origin_trace even after mutation. +LM string that supports logits and keeps track of it'sorigin_trace even after mutation. 
""" + import numpy as np from typing import ( Optional, @@ -20,87 +21,88 @@ from pydantic_core import CoreSchema, core_schema + class _lstr(str): """ - A string class that supports logits and keeps track of its _origin_trace even after mutation. - This class is designed to be used in prompt engineering libraries where it is essential to associate - logits with generated text and track the origin of the text. - - The `lstr` class inherits from the built-in `str` class and adds two additional attributes: `logits` and `_origin_trace`. - The `_origin_trace` attribute is a frozen set of strings that represents the _origin_trace(s) of the string. - - The class provides various methods for manipulating the string, such as concatenation, slicing, splitting, and joining. - These methods ensure that the logits and _origin_trace(s) are updated correctly based on the operation performed. - - The `lstr` class is particularly useful in LLM libraries for tracing the flow of prompts through various language model calls. - By tracking the _origin_trace of each string, it is possible to visualize how outputs from one language model program influence - the inputs of another, allowing for a detailed analysis of interactions between different large language models. This capability - is crucial for understanding the propagation of prompts in complex LLM workflows and for building visual graphs that depict these interactions. - - It is important to note that any modification to the string (such as concatenation or replacement) will invalidate the associated logits. - This is because the logits are specifically tied to the original string content, and any change would require a new computation of logits. - The logic behind this is detailed elsewhere in this file. - - Example usage: - ``` - # Create an lstr instance with logits and an _origin_trace - logits = np.array([1.0, 2.0, 3.0]) - _origin_trace = "4e9b7ec9" - lstr_instance = lstr("Hello", logits, _origin_trace) - - # Concatenate two lstr instances - lstr_instance2 = lstr("World", None, "7f4d2c3a") - concatenated_lstr = lstr_instance + lstr_instance2 - - # Get the logits and _origin_trace of the concatenated lstr - print(concatenated_lstr.logits) # Output: None - print(concatenated_lstr._origin_trace) # Output: frozenset({'4e9b7ec9', '7f4d2c3a'}) - - # Split the concatenated lstr into two parts - parts = concatenated_lstr.split() - print(parts) # Output: [lstr('Hello', None, frozenset({'4e9b7ec9', '7f4d2c3a'})), lstr('World', None, frozenset({'4e9b7ec9', '7f4d2c3a'}))] - ``` - Attributes: - _origin_trace (FrozenSet[str]): A frozen set of strings representing the _origin_trace(s) of the string. - - Methods: - __new__: Create a new instance of lstr. - __repr__: Return a string representation of the lstr instance. - __add__: Concatenate this lstr instance with another string or lstr instance. - __mod__: Perform a modulo operation between this lstr instance and another string, lstr, or a tuple of strings and lstrs. - __mul__: Perform a multiplication operation between this lstr instance and an integer or another lstr. - __rmul__: Perform a right multiplication operation between an integer or another lstr and this lstr instance. - __getitem__: Get a slice or index of this lstr instance. - __getattr__: Get an attribute from this lstr instance. - join: Join a sequence of strings or lstr instances into a single lstr instance. - split: Split this lstr instance into a list of lstr instances based on a separator. 
- rsplit: Split this lstr instance into a list of lstr instances based on a separator, starting from the right. - splitlines: Split this lstr instance into a list of lstr instances based on line breaks. - partition: Partition this lstr instance into three lstr instances based on a separator. - rpartition: Partition this lstr instance into three lstr instances based on a separator, starting from the right. + A string class that supports logits and keeps track of itsorigin_trace even after mutation. + This class is designed to be used in prompt engineering libraries where it is essential to associate + logits with generated text and track the origin of the text. + + The `lstr` class inherits from the built-in `str` class and adds two additional attributes: `logits` and `origin_trace`. + The `origin_trace` attribute is a frozen set of strings that represents theorigin_trace(s) of the string. + + The class provides various methods for manipulating the string, such as concatenation, slicing, splitting, and joining. + These methods ensure that the logits andorigin_trace(s) are updated correctly based on the operation performed. + + The `lstr` class is particularly useful in LLM libraries for tracing the flow of prompts through various language model calls. + By tracking theorigin_trace of each string, it is possible to visualize how outputs from one language model program influence + the inputs of another, allowing for a detailed analysis of interactions between different large language models. This capability + is crucial for understanding the propagation of prompts in complex LLM workflows and for building visual graphs that depict these interactions. + + It is important to note that any modification to the string (such as concatenation or replacement) will invalidate the associated logits. + This is because the logits are specifically tied to the original string content, and any change would require a new computation of logits. + The logic behind this is detailed elsewhere in this file. + + Example usage: + ``` + # Create an lstr instance with logits and anorigin_trace + logits = np.array([1.0, 2.0, 3.0]) + origin_trace = "4e9b7ec9" + lstr_instance = lstr("Hello", logits,origin_trace) + + # Concatenate two lstr instances + lstr_instance2 = lstr("World", None, "7f4d2c3a") + concatenated_lstr = lstr_instance + lstr_instance2 + + # Get the logits andorigin_trace of the concatenated lstr + print(concatenated_lstr.logits) # Output: None + print(concatenated_lstr.origin_trace) # Output: frozenset({'4e9b7ec9', '7f4d2c3a'}) + + # Split the concatenated lstr into two parts + parts = concatenated_lstr.split() + print(parts) # Output: [lstr('Hello', None, frozenset({'4e9b7ec9', '7f4d2c3a'})), lstr('World', None, frozenset({'4e9b7ec9', '7f4d2c3a'}))] + ``` + Attributes: + origin_trace (FrozenSet[str]): A frozen set of strings representing theorigin_trace(s) of the string. + + Methods: + __new__: Create a new instance of lstr. + __repr__: Return a string representation of the lstr instance. + __add__: Concatenate this lstr instance with another string or lstr instance. + __mod__: Perform a modulo operation between this lstr instance and another string, lstr, or a tuple of strings and lstrs. + __mul__: Perform a multiplication operation between this lstr instance and an integer or another lstr. + __rmul__: Perform a right multiplication operation between an integer or another lstr and this lstr instance. + __getitem__: Get a slice or index of this lstr instance. 
+ __getattr__: Get an attribute from this lstr instance. + join: Join a sequence of strings or lstr instances into a single lstr instance. + split: Split this lstr instance into a list of lstr instances based on a separator. + rsplit: Split this lstr instance into a list of lstr instances based on a separator, starting from the right. + splitlines: Split this lstr instance into a list of lstr instances based on line breaks. + partition: Partition this lstr instance into three lstr instances based on a separator. + rpartition: Partition this lstr instance into three lstr instances based on a separator, starting from the right. """ def __new__( cls, content: str, logits: Optional[np.ndarray] = None, - _origin_trace: Optional[Union[str, FrozenSet[str]]] = None, + origin_trace: Optional[Union[str, FrozenSet[str]]] = None, ): """ - Create a new instance of lstr. The `logits` should be a numpy array and `_origin_trace` should be a frozen set of strings or a single string. + Create a new instance of lstr. The `logits` should be a numpy array and `origin_trace` should be a frozen set of strings or a single string. - Args: - content (str): The string content of the lstr. - logits (np.ndarray, optional): The logits associated with this string. Defaults to None. - _origin_trace (Union[str, FrozenSet[str]], optional): The _origin_trace(s) of this string. Defaults to None. + Args: + content (str): The string content of the lstr. + logits (np.ndarray, optional): The logits associated with this string. Defaults to None. + origin_trace (Union[str, FrozenSet[str]], optional): Theorigin_trace(s) of this string. Defaults to None. """ instance = super(_lstr, cls).__new__(cls, content) # instance._logits = logits - if isinstance(_origin_trace, str): - instance.__origin_trace__ = frozenset({_origin_trace}) + if isinstance(origin_trace, str): + instance.__origin_trace__ = frozenset({origin_trace}) else: instance.__origin_trace__ = ( - frozenset(_origin_trace) if _origin_trace is not None else frozenset() + frozenset(origin_trace) if origin_trace is not None else frozenset() ) return instance @@ -112,10 +114,10 @@ def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: def validate_lstr(value): - if isinstance(value, dict) and value.get('__lstr', False): - content = value['content'] - _origin_trace = value['__origin_trace__'].split(',') - return cls(content, _origin_trace=_origin_trace) + if isinstance(value, dict) and value.get("__lstr", False): + content = value["content"] + origin_trace = value["__origin_trace__"].split(",") + return cls(content, origin_trace=origin_trace) elif isinstance(value, str): return cls(value) elif isinstance(value, cls): @@ -124,32 +126,37 @@ def validate_lstr(value): raise ValueError(f"Invalid value for lstr: {value}") return core_schema.json_or_python_schema( - json_schema=core_schema.typed_dict_schema({ - 'content': core_schema.typed_dict_field(core_schema.str_schema()), - '__origin_trace__': core_schema.typed_dict_field(core_schema.str_schema()), - '__lstr': core_schema.typed_dict_field(core_schema.bool_schema()), - }), - python_schema=core_schema.union_schema([ - core_schema.is_instance_schema(cls), - core_schema.no_info_plain_validator_function(validate_lstr), - ]), + json_schema=core_schema.typed_dict_schema( + { + "content": core_schema.typed_dict_field(core_schema.str_schema()), + "__origin_trace__": core_schema.typed_dict_field( + core_schema.str_schema() + ), + "__lstr": 
core_schema.typed_dict_field(core_schema.bool_schema()), + } + ), + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(cls), + core_schema.no_info_plain_validator_function(validate_lstr), + ] + ), serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: { + lambda instance: { "content": str(instance), "__origin_trace__": (instance.__origin_trace__), - "__lstr": True + "__lstr": True, } - ) + ), ) - @property - def _origin_trace(self) -> FrozenSet[str]: + def origin_trace(self) -> FrozenSet[str]: """ - Get the _origin_trace(s) of this lstr instance. + Get theorigin_trace(s) of this lstr instance. Returns: - FrozenSet[str]: A frozen set of strings representing the _origin_trace(s) of this lstr instance. + FrozenSet[str]: A frozen set of strings representing theorigin_trace(s) of this lstr instance. """ return self.__origin_trace__ @@ -161,7 +168,7 @@ def __repr__(self) -> str: Return a string representation of this lstr instance. Returns: - str: A string representation of this lstr instance, including its content, logits, and _origin_trace(s). + str: A string representation of this lstr instance, including its content, logits, andorigin_trace(s). """ return super().__repr__() @@ -173,19 +180,18 @@ def __add__(self, other: Union[str, "_lstr"]) -> "_lstr": other (Union[str, "lstr"]): The string or lstr instance to concatenate with this instance. Returns: - lstr: A new lstr instance containing the concatenated content, with the _origin_trace(s) updated accordingly. + lstr: A new lstr instance containing the concatenated content, with theorigin_trace(s) updated accordingly. """ new_content = super(_lstr, self).__add__(other) self_origin = self.__origin_trace__ - + if isinstance(other, _lstr): - new_origin = set(self_origin) - new_origin.update(other.__origin_trace__) - new_origin = frozenset(new_origin) + new_origin = self_origin + new_origin = new_origin.union(other.__origin_trace__) else: new_origin = self_origin - - return _lstr(new_content, None, new_origin) + + return _lstr(new_content, None, frozenset(new_origin)) def __mod__( self, other: Union[str, "_lstr", Tuple[Union[str, "_lstr"], ...]] @@ -198,7 +204,7 @@ def __mod__( other (Union[str, "lstr", Tuple[Union[str, "lstr"], ...]]): The right operand in the modulo operation. Returns: - lstr: A new lstr instance containing the result of the modulo operation, with the _origin_trace(s) updated accordingly. + lstr: A new lstr instance containing the result of the modulo operation, with theorigin_trace(s) updated accordingly. """ # If 'other' is a tuple, we need to handle each element if isinstance(other, tuple): @@ -211,7 +217,9 @@ def __mod__( else: result_content = super(_lstr, self).__mod__(other) if isinstance(other, _lstr): - new__origin_trace__ = self.__origin_trace__.union(other.__origin_trace__) + new__origin_trace__ = self.__origin_trace__.union( + other.__origin_trace__ + ) else: new__origin_trace__ = self.__origin_trace__ @@ -226,7 +234,7 @@ def __mul__(self, other: SupportsIndex) -> "_lstr": other (Union[SupportsIndex, "lstr"]): The right operand in the multiplication operation. Returns: - lstr: A new lstr instance containing the result of the multiplication operation, with the _origin_trace(s) updated accordingly. + lstr: A new lstr instance containing the result of the multiplication operation, with theorigin_trace(s) updated accordingly. 
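The `__add__` override above is the heart of origin tracking: traces are unioned whenever two tracked strings meet, and plain strings contribute nothing. A minimal stand-in for that bookkeeping (not the real `_lstr`, which also handles logits, slicing, pydantic schemas, and the rest of the `str` surface):

```python
from typing import FrozenSet, Optional, Union


class TracedStr(str):
    """Toy string subclass that carries a frozenset of origin ids."""

    def __new__(cls, content: str,
                origin_trace: Optional[Union[str, FrozenSet[str]]] = None):
        inst = super().__new__(cls, content)
        inst._trace = (frozenset({origin_trace}) if isinstance(origin_trace, str)
                       else frozenset(origin_trace or ()))
        return inst

    @property
    def origin_trace(self) -> FrozenSet[str]:
        return self._trace

    def __add__(self, other):
        # Union the traces when both operands are traced; plain str adds nothing.
        other_trace = other.origin_trace if isinstance(other, TracedStr) else frozenset()
        return TracedStr(str(self) + str(other), self._trace | other_trace)


a = TracedStr("Hello", "4e9b7ec9")
b = TracedStr(" world", "7f4d2c3a")
print((a + b).origin_trace)  # frozenset containing both '4e9b7ec9' and '7f4d2c3a'
```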
""" if isinstance(other, SupportsIndex): result_content = super(_lstr, self).__mul__(other) @@ -245,7 +253,7 @@ def __rmul__(self, other: SupportsIndex) -> "_lstr": other (Union[SupportsIndex, "lstr"]): The left operand in the multiplication operation. Returns: - lstr: A new lstr instance containing the result of the multiplication operation, with the _origin_trace(s) updated accordingly. + lstr: A new lstr instance containing the result of the multiplication operation, with theorigin_trace(s) updated accordingly. """ return self.__mul__(other) # Multiplication is commutative in this context @@ -257,7 +265,7 @@ def __getitem__(self, key: Union[SupportsIndex, slice]) -> "_lstr": key (Union[SupportsIndex, slice]): The index or slice to retrieve. Returns: - lstr: A new lstr instance containing the sliced or indexed content, with the _origin_trace(s) preserved. + lstr: A new lstr instance containing the sliced or indexed content, with theorigin_trace(s) preserved. """ result = super(_lstr, self).__getitem__(key) # This is a matter of opinon. I believe that when you Index into a language model output, you or divorcing the lodges of the indexed result from their contacts which produce them. Therefore, it is only reasonable to directly index into the lodges without changing the original context, and so any mutation on the string should invalidate the logits. @@ -295,14 +303,14 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: result = attr(*args, **kwargs) # If the result is a string, return an lstr instance if isinstance(result, str): - _origin_traces = self.__origin_trace__ + origin_traces = self.__origin_trace__ for arg in args: if isinstance(arg, _lstr): - _origin_traces = _origin_traces.union(arg.__origin_trace__) + origin_traces = origin_traces.union(arg.__origin_trace__) for key, value in kwargs.items(): if isinstance(value, _lstr): - _origin_traces = _origin_traces.union(value.__origin_trace__) - return _lstr(result, None, _origin_traces) + origin_traces = origin_traces.union(value.__origin_trace__) + return _lstr(result, None, origin_traces) return result @@ -319,7 +327,7 @@ def join(self, iterable: Iterable[Union[str, "_lstr"]]) -> "_lstr": iterable (Iterable[Union[str, "lstr"]]): The sequence of strings or lstr instances to join. Returns: - lstr: A new lstr instance containing the joined content, with the _origin_trace(s) updated accordingly. + lstr: A new lstr instance containing the joined content, with theorigin_trace(s) updated accordingly. """ parts = [str(item) for item in iterable] new_content = super(_lstr, self).join(parts) @@ -341,7 +349,7 @@ def split( maxsplit (SupportsIndex, optional): The maximum number of splits to perform. Defaults to -1. Returns: - List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved. + List["lstr"]: A list of lstr instances containing the split content, with theorigin_trace(s) preserved. """ return self._split_helper(super(_lstr, self).split, sep, maxsplit) @@ -357,7 +365,7 @@ def rsplit( maxsplit (SupportsIndex, optional): The maximum number of splits to perform. Defaults to -1. Returns: - List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved. + List["lstr"]: A list of lstr instances containing the split content, with theorigin_trace(s) preserved. 
""" return self._split_helper(super(_lstr, self).rsplit, sep, maxsplit) @@ -370,7 +378,7 @@ def splitlines(self, keepends: bool = False) -> List["_lstr"]: keepends (bool, optional): Whether to include the line breaks in the resulting lstr instances. Defaults to False. Returns: - List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved. + List["lstr"]: A list of lstr instances containing the split content, with theorigin_trace(s) preserved. """ return [ _lstr(p, None, self.__origin_trace__) @@ -386,7 +394,7 @@ def partition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr" sep (Union[str, "lstr"]): The separator to partition on. Returns: - Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly. + Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with theorigin_trace(s) updated accordingly. """ return self._partition_helper(super(_lstr, self).partition, sep) @@ -399,7 +407,7 @@ def rpartition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr sep (Union[str, "lstr"]): The separator to partition on. Returns: - Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly. + Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with theorigin_trace(s) updated accordingly. """ return self._partition_helper(super(_lstr, self).rpartition, sep) @@ -414,7 +422,7 @@ def _partition_helper( sep (Union[str, "lstr"]): The separator to partition on. Returns: - Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly. + Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with theorigin_trace(s) updated accordingly. """ part1, part2, part3 = method(sep) new__origin_trace__ = ( @@ -443,15 +451,15 @@ def _split_helper( maxsplit (SupportsIndex, optional): The maximum number of splits to perform. Defaults to -1. Returns: - List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved. + List["lstr"]: A list of lstr instances containing the split content, with theorigin_trace(s) preserved. 
""" - _origin_traces = ( + origin_traces = ( self.__origin_trace__ | sep.__origin_trace__ if isinstance(sep, _lstr) else self.__origin_trace__ ) parts = method(sep, maxsplit) - return [_lstr(part, None, _origin_traces) for part in parts] + return [_lstr(part, None, origin_traces) for part in parts] if __name__ == "__main__": @@ -460,42 +468,42 @@ def _split_helper( import string def generate_random_string(length): - return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) def test_concatenation(): s1 = generate_random_string(1000) s2 = generate_random_string(1000) - + lstr_time = timeit.timeit(lambda: _lstr(s1) + _lstr(s2), number=10000) str_time = timeit.timeit(lambda: s1 + s2, number=10000) - + print(f"Concatenation: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s") def test_slicing(): s = generate_random_string(10000) ls = _lstr(s) - + lstr_time = timeit.timeit(lambda: ls[1000:2000], number=10000) str_time = timeit.timeit(lambda: s[1000:2000], number=10000) - + print(f"Slicing: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s") def test_splitting(): s = generate_random_string(10000) ls = _lstr(s) - + lstr_time = timeit.timeit(lambda: ls.split(), number=1000) str_time = timeit.timeit(lambda: s.split(), number=1000) - + print(f"Splitting: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s") def test_joining(): words = [generate_random_string(10) for _ in range(1000)] lwords = [_lstr(word) for word in words] - - lstr_time = timeit.timeit(lambda: _lstr(' ').join(lwords), number=1000) - str_time = timeit.timeit(lambda: ' '.join(words), number=1000) - + + lstr_time = timeit.timeit(lambda: _lstr(" ").join(lwords), number=1000) + str_time = timeit.timeit(lambda: " ".join(words), number=1000) + print(f"Joining: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s") print("Running performance tests...") @@ -513,7 +521,7 @@ def test_add(): s2 = generate_random_string(1000) ls1 = _lstr(s1, None, "origin1") ls2 = _lstr(s2, None, "origin2") - + for _ in range(100000): result = ls1 + ls2 @@ -522,14 +530,8 @@ def test_add(): profiler.enable() test_add() profiler.disable() - + s = StringIO() - ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative') + ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative") ps.print_stats(20) # Print top 20 lines print(s.getvalue()) - -# if __name__ == "__main__": -# x = lstr("hello") -# y = lstr("world") -# z = x + y -# print(z) \ No newline at end of file diff --git a/tests/test_lstr.py b/tests/test_lstr.py index d2f85ae2..9e8d6209 100644 --- a/tests/test_lstr.py +++ b/tests/test_lstr.py @@ -9,96 +9,96 @@ def test_init(self): s = _lstr("hello") assert str(s) == "hello" # assert s.logits is None - assert s._origin_trace == frozenset() + assert s.origin_trace == frozenset() - # Test initialization with logits and _origin_trace + # Test initialization with logits andorigin_trace # logits = np.array([0.1, 0.2]) - _origin_trace = "model1" - s = _lstr("world", _origin_trace=_origin_trace) # Removed logits parameter + origin_trace = "model1" + s = _lstr("world",origin_trace=origin_trace) # Removed logits parameter assert str(s) == "world" # assert np.array_equal(s.logits, logits) - assert s._origin_trace == frozenset({_origin_trace}) + assert s.origin_trace == frozenset({_origin_trace}) def test_add(self): s1 = _lstr("hello") - s2 = _lstr("world", _origin_trace="model2") + s2 = _lstr("world",origin_trace="model2") assert isinstance(s1 + s2, str) result = s1 + s2 assert str(result) 
== "helloworld" # assert result.logits is None - assert result._origin_trace == frozenset({"model2"}) + assert result.origin_trace == frozenset({"model2"}) def test_mod(self): s = _lstr("hello %s") result = s % "world" assert str(result) == "hello world" # assert result.logits is None - assert result._origin_trace == frozenset() + assert result.origin_trace == frozenset() def test_mul(self): - s = _lstr("ha", _origin_trace="model3") + s = _lstr("ha",origin_trace="model3") result = s * 3 assert str(result) == "hahaha" # assert result.logits is None - assert result._origin_trace == frozenset({"model3"}) + assert result.origin_trace == frozenset({"model3"}) def test_getitem(self): s = _lstr( - "hello", _origin_trace="model4" + "hello",origin_trace="model4" ) # Removed logits parameter result = s[1:4] assert str(result) == "ell" # assert result.logits is None - assert result._origin_trace == frozenset({"model4"}) + assert result.origin_trace == frozenset({"model4"}) def test_upper(self): - # Test upper method without _origin_trace and logits + # Test upper method withoutorigin_trace and logits s = _lstr("hello") result = s.upper() assert str(result) == "HELLO" # assert result.logits is None - assert result._origin_trace == frozenset() + assert result.origin_trace == frozenset() - # Test upper method with _origin_trace - s = _lstr("world", _origin_trace="model11") + # Test upper method withorigin_trace + s = _lstr("world",origin_trace="model11") result = s.upper() assert str(result) == "WORLD" # assert result.logits is None - assert result._origin_trace == frozenset({"model11"}) + assert result.origin_trace == frozenset({"model11"}) def test_join(self): - s = _lstr(", ", _origin_trace="model5") - parts = [_lstr("hello"), _lstr("world", _origin_trace="model6")] + s = _lstr(", ",origin_trace="model5") + parts = [_lstr("hello"), _lstr("world",origin_trace="model6")] result = s.join(parts) assert str(result) == "hello, world" # assert result.logits is None - assert result._origin_trace == frozenset({"model5", "model6"}) + assert result.origin_trace == frozenset({"model5", "model6"}) def test_split(self): - s = _lstr("hello world", _origin_trace="model7") + s = _lstr("hello world",origin_trace="model7") parts = s.split() assert [str(p) for p in parts] == ["hello", "world"] # assert all(p.logits is None for p in parts) - assert all(p._origin_trace == frozenset({"model7"}) for p in parts) + assert all(p.origin_trace == frozenset({"model7"}) for p in parts) def test_partition(self): - s = _lstr("hello, world", _origin_trace="model8") + s = _lstr("hello, world",origin_trace="model8") part1, sep, part2 = s.partition(", ") assert (str(part1), str(sep), str(part2)) == ("hello", ", ", "world") # assert all(p.logits is None for p in (part1, sep, part2)) - assert all(p._origin_trace == frozenset({"model8"}) for p in (part1, sep, part2)) + assert all(p.origin_trace == frozenset({"model8"}) for p in (part1, sep, part2)) def test_formatting(self): s = _lstr("Hello {}!") - filled = s.format(_lstr("world", _origin_trace="model9")) + filled = s.format(_lstr("world",origin_trace="model9")) assert str(filled) == "Hello world!" 
# assert filled.logits is None - assert filled._origin_trace == frozenset({"model9"}) + assert filled.origin_trace == frozenset({"model9"}) def test_repr(self): - s = _lstr("test", _origin_trace="model10") # Removed logits parameter + s = _lstr("test",origin_trace="model10") # Removed logits parameter assert "test" in repr(s) - assert "model10" in repr(s._origin_trace) + assert "model10" in repr(s.origin_trace) # Run the tests diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index e8ff2916..2938b445 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -94,7 +94,7 @@ def test_call_model(mock_openai_client): @ell.tool() def dummy_tool(param1: str, param2: int): pass - result = OpenAIProvider.call_model(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool]) + result = OpenAIProvider.call(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool]) assert isinstance(result, APICallResult) assert not "stream" in result.final_call_params From 32336a6cbd990d3adf25ecbf7c596a6612611e9b Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 13:33:11 -0700 Subject: [PATCH 08/17] more redable decorators. --- src/ell/lmp/tool.py | 230 ++++++++++++++++++++++---------------------- 1 file changed, 114 insertions(+), 116 deletions(-) diff --git a/src/ell/lmp/tool.py b/src/ell/lmp/tool.py index fbe5bae5..c40fae7e 100644 --- a/src/ell/lmp/tool.py +++ b/src/ell/lmp/tool.py @@ -17,119 +17,6 @@ def tool(*, exempt_from_tracking: bool = False, **tool_kwargs): - """ - Defines a tool for use in language model programs (LMPs) that support tool use. - - This decorator wraps a function, adding metadata and handling for tool invocations. - It automatically extracts the tool's description and parameters from the function's - docstring and type annotations, creating a structured representation for LMs to use. - - :param exempt_from_tracking: If True, the tool usage won't be tracked. Default is False. - :type exempt_from_tracking: bool - :param tool_kwargs: Additional keyword arguments for tool configuration. - :return: A wrapped version of the original function, usable as a tool by LMs. - :rtype: Callable - - Requirements: - - - Function must have fully typed arguments (Pydantic-serializable). - - Return value must be one of: str, JSON-serializable object, Pydantic model, or List[ContentBlock]. - - All parameters must have type annotations. - - Complex types should be Pydantic models. - - Function should have a descriptive docstring. - - Can only be used in LMPs with @ell.complex decorators - - Functionality: - - 1. Metadata Extraction: - - Uses function docstring as tool description. - - Extracts parameter info from type annotations and docstring. - - Creates a Pydantic model for parameter validation and schema generation. - - 2. Integration with LMs: - - Can be passed to @ell.complex decorators. - - Provides structured tool information to LMs. - - 3. Invocation Handling: - - Manages tracking, logging, and result processing. - - Wraps results in appropriate types (e.g., _lstr) for tracking. - - Usage Modes: - - 1. Normal Function Call: - - Behaves like a regular Python function. - - Example: result = my_tool(arg1="value", arg2=123) - - 2. LMP Tool Call: - - Used within LMPs or with explicit _tool_call_id. - - Returns a ToolResult object. 
- - Example: result = my_tool(arg1="value", arg2=123, _tool_call_id="unique_id") - - Result Coercion: - - - String → ContentBlock(text=result) - - Pydantic BaseModel → ContentBlock(parsed=result) - - List[ContentBlock] → Used as-is - - Other types → ContentBlock(text=json.dumps(result)) - - Example:: - - @ell.tool() - def create_claim_draft( - claim_details: str, - claim_type: str, - claim_amount: float, - claim_date: str = Field(description="Date format: YYYY-MM-DD") - ) -> str: - '''Create a claim draft. Returns the created claim ID.''' - return "12345" - - # For use in a complex LMP: - @ell.complex(model="gpt-4", tools=[create_claim_draft], temperature=0.1) - def insurance_chatbot(message_history: List[Message]) -> List[Message]: - # Chatbot implementation... - - x = insurance_chatbot([ - ell.user("I crashed my car into a tree."), - ell.assistant("I'm sorry to hear that. Can you provide more details?"), - ell.user("The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000") - ]) - print(x) - '''ell.Message(content=[ - ContentBlock(tool_call( - tool_call_id="asdas4e", - tool_fn=create_claim_draft, - input=create_claim_draftParams({ - claim_details="The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000", - claim_type="car", - claim_amount=5000, - claim_date="2024-08-01" - }) - )) - ], role='assistant')''' - - if x.tool_calls: - next_user_message = response_message.call_tools_and_collect_as_message() - # This actually calls create_claim_draft - print(next_user_message) - ''' - ell.Message(content=[ - ContentBlock(tool_result=ToolResult( - tool_call_id="asdas4e", - result=[ContentBlock(text="12345")] - )) - ], role='user') - ''' - y = insurance_chatbot(message_history + [x, next_user_message]) - print(y) - ''' - ell.Message("I've filed that for you!", role='assistant') - ''' - - Note: - - Tools are integrated into LMP calls via the 'tools' parameter in @ell.complex. - - LMs receive structured tool information, enabling understanding and usage within the conversation context. - """ def tool_decorator(fn: Callable[..., Any]) -> InvocableTool: # color = compute_color(fn) _under_fn = fn @@ -145,8 +32,6 @@ def wrapper( #XXX: Post release, we need to wrap all tool arguments in type primitives for tracking I guess or change that tool makes the tool function inoperable. #XXX: Most people are not going to manually try and call the tool without a type primitive and if they do it will most likely be wrapped with l strs. - # assert exempt_from_tracking or _invocation_origin is not None, "Invocation origin is required when using a tracked Tool" - # Do nice logging hooks here. if config.verbose and not exempt_from_tracking: pass @@ -197,7 +82,6 @@ def wrapper( sig = inspect.signature(fn) - # Create a Pydantic model from the function signature # 2. Create a dictionary of field definitions for the Pydantic model fields = {} for param_name, param in sig.parameters.items(): @@ -246,3 +130,117 @@ def get_params_model(): return ret return tool_decorator + + +tool.__doc__ = """Defines a tool for use in language model programs (LMPs) that support tool use. + +This decorator wraps a function, adding metadata and handling for tool invocations. +It automatically extracts the tool's description and parameters from the function's +docstring and type annotations, creating a structured representation for LMs to use. + +:param exempt_from_tracking: If True, the tool usage won't be tracked. Default is False. 
+:type exempt_from_tracking: bool +:param tool_kwargs: Additional keyword arguments for tool configuration. +:return: A wrapped version of the original function, usable as a tool by LMs. +:rtype: Callable + +Requirements: + +- Function must have fully typed arguments (Pydantic-serializable). +- Return value must be one of: str, JSON-serializable object, Pydantic model, or List[ContentBlock]. +- All parameters must have type annotations. +- Complex types should be Pydantic models. +- Function should have a descriptive docstring. +- Can only be used in LMPs with @ell.complex decorators + +Functionality: + +1. Metadata Extraction: + - Uses function docstring as tool description. + - Extracts parameter info from type annotations and docstring. + - Creates a Pydantic model for parameter validation and schema generation. + +2. Integration with LMs: + - Can be passed to @ell.complex decorators. + - Provides structured tool information to LMs. + +3. Invocation Handling: + - Manages tracking, logging, and result processing. + - Wraps results in appropriate types (e.g., _lstr) for tracking. + +Usage Modes: + +1. Normal Function Call: + - Behaves like a regular Python function. + - Example: result = my_tool(arg1="value", arg2=123) + +2. LMP Tool Call: + - Used within LMPs or with explicit _tool_call_id. + - Returns a ToolResult object. + - Example: result = my_tool(arg1="value", arg2=123, _tool_call_id="unique_id") + +Result Coercion: + +- String → ContentBlock(text=result) +- Pydantic BaseModel → ContentBlock(parsed=result) +- List[ContentBlock] → Used as-is +- Other types → ContentBlock(text=json.dumps(result)) + +Example:: + + @ell.tool() + def create_claim_draft( + claim_details: str, + claim_type: str, + claim_amount: float, + claim_date: str = Field(description="Date format: YYYY-MM-DD") + ) -> str: + '''Create a claim draft. Returns the created claim ID.''' + return "12345" + + # For use in a complex LMP: + @ell.complex(model="gpt-4", tools=[create_claim_draft], temperature=0.1) + def insurance_chatbot(message_history: List[Message]) -> List[Message]: + # Chatbot implementation... + + x = insurance_chatbot([ + ell.user("I crashed my car into a tree."), + ell.assistant("I'm sorry to hear that. Can you provide more details?"), + ell.user("The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000") + ]) + print(x) + '''ell.Message(content=[ + ContentBlock(tool_call( + tool_call_id="asdas4e", + tool_fn=create_claim_draft, + input=create_claim_draftParams({ + claim_details="The car is totaled and I need to file a claim. Happened on 2024-08-01. total value is like $5000", + claim_type="car", + claim_amount=5000, + claim_date="2024-08-01" + }) + )) + ], role='assistant')''' + + if x.tool_calls: + next_user_message = response_message.call_tools_and_collect_as_message() + # This actually calls create_claim_draft + print(next_user_message) + ''' + ell.Message(content=[ + ContentBlock(tool_result=ToolResult( + tool_call_id="asdas4e", + result=[ContentBlock(text="12345")] + )) + ], role='user') + ''' + y = insurance_chatbot(message_history + [x, next_user_message]) + print(y) + ''' + ell.Message("I've filed that for you!", role='assistant') + ''' + +Note: +- Tools are integrated into LMP calls via the 'tools' parameter in @ell.complex. +- LMs receive structured tool information, enabling understanding and usage within the conversation context. 
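The result-coercion rules listed in this docstring (string to a text block, Pydantic model to a parsed block, and so on) can be sketched in a few lines. `Block` and `coerce_result` below are hypothetical stand-ins for illustration, not the decorator's actual implementation of `ContentBlock` handling:

```python
import json
from dataclasses import dataclass
from typing import Any, List, Optional

from pydantic import BaseModel


@dataclass
class Block:
    """Stand-in for ell's ContentBlock."""
    text: Optional[str] = None
    parsed: Optional[BaseModel] = None


def coerce_result(result: Any) -> List[Block]:
    if isinstance(result, str):
        return [Block(text=result)]
    if isinstance(result, BaseModel):
        return [Block(parsed=result)]
    if isinstance(result, list) and all(isinstance(b, Block) for b in result):
        return result
    return [Block(text=json.dumps(result))]  # any other JSON-serializable value


class ClaimID(BaseModel):
    id: str


print(coerce_result("12345"))
print(coerce_result(ClaimID(id="12345")))
print(coerce_result({"status": "ok"}))
```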
+ """ \ No newline at end of file From e3f24fcb6ffcb74e6863eb3725168e33b570c46b Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 16:56:57 -0700 Subject: [PATCH 09/17] openai provider --- src/ell/lmp/complex.py | 5 - src/ell/lmp/tool.py | 5 +- src/ell/provider.py | 134 ++++++++----- src/ell/providers/__init__.py | 2 +- src/ell/providers/openai.py | 368 +++++++++++++--------------------- src/ell/types/message.py | 7 + 6 files changed, 242 insertions(+), 279 deletions(-) diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index d5aa3880..49b7f574 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -116,11 +116,6 @@ def _client_for_model( if client is None: raise ValueError(f"No client found for model '{model}'. Ensure the model is registered using 'register_model' in 'config.py' or specify a client directly using the 'client' argument in the decorator or function call.") - - # compatibility with local models necessetates no api key. - # if not client.api_key: - # raise RuntimeError(_no_api_key_warning(model, _name, client, long=True, error=True)) - return client diff --git a/src/ell/lmp/tool.py b/src/ell/lmp/tool.py index c40fae7e..622daece 100644 --- a/src/ell/lmp/tool.py +++ b/src/ell/lmp/tool.py @@ -18,7 +18,6 @@ def tool(*, exempt_from_tracking: bool = False, **tool_kwargs): def tool_decorator(fn: Callable[..., Any]) -> InvocableTool: - # color = compute_color(fn) _under_fn = fn @wraps(fn) @@ -28,11 +27,9 @@ def wrapper( _tool_call_id: str = None, **fn_kwargs ): - #XXX: Post release, we need to wrap all tool arguments in type primitives for tracking I guess or change that tool makes the tool function inoperable. #XXX: Most people are not going to manually try and call the tool without a type primitive and if they do it will most likely be wrapped with l strs. - - + if config.verbose and not exempt_from_tracking: pass # tool_usage_logger_pre(fn, fn_args, fn_kwargs, name, color) diff --git a/src/ell/provider.py b/src/ell/provider.py index 2f0653ac..aabfc8d9 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -3,7 +3,19 @@ from functools import lru_cache import inspect from types import MappingProxyType -from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Type, TypedDict, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + List, + Optional, + Set, + Tuple, + Type, + TypedDict, + Union, +) from pydantic import BaseModel, ConfigDict, Field from ell.types import Message, ContentBlock, ToolCall @@ -13,23 +25,29 @@ from ell.types.message import LMP -# XXX: Might leave this internal to providers so that the complex code is simpler & +# XXX: Might leave this internal to providers so that the complex code is simpler & # we can literally jsut call provider.call like any openai fn. 
class EllCallParams(BaseModel): model: str = Field(..., description="Model identifier") messages: List[Message] = Field(..., description="Conversation context") client: Any = Field(..., description="API client") tools: Optional[List[LMP]] = Field(None, description="Available tools") - api_params: Dict[str, Any] = Field(default_factory=dict, description="API parameters") - + api_params: Dict[str, Any] = Field( + default_factory=dict, description="API parameters" + ) model_config = ConfigDict(arbitrary_types_allowed=True) + def get_tool_by_name(self, name: str) -> Optional[LMP]: + """Get a tool by name.""" + return next( + (tool for tool in (self.tools or []) if tool.__name__ == name), None + ) + -class Metadata(TypedDict): - """First class metadata so that ell studio can work, you can add more stuff here if you want""" - -#XXX: Needs a better name. +Metadata = Dict[str, Any] + +# XXX: Needs a better name. class Provider(ABC): """ Abstract base class for all providers. Providers are API interfaces to language models, not necessarily API providers. @@ -41,94 +59,120 @@ class Provider(ABC): ### API PARAMETERS ############# ################################ @abstractmethod - def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: + def provider_call_function( + self, api_call_params: Optional[Dict[str, Any]] = None + ) -> Callable[..., Any]: """ Implement this method to return the function that makes the API call to the language model. For example, if you're implementing the OpenAI provider, you would return the function that makes the API call to OpenAI's API. """ return NotImplemented - def disallowed_api_params(self) -> FrozenSet[str]: """ Returns a list of disallowed call params that ell will override. """ - return frozenset({"messages", "tools", "model"}) + return frozenset({"messages", "tools", "model", "stream", "stream_options"}) - def available_api_params(self, api_params : Optional[Dict[str, Any]] = None): + def available_api_params(self, api_params: Optional[Dict[str, Any]] = None): params = _call_params(self.provider_call_function(api_params)) return frozenset(params.keys()) - self.disallowed_api_params() - ################################ ### TRANSLATION ############### ################################ @abstractmethod - def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: + def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]: """Converts an ell call to provider call params!""" return NotImplemented - + @abstractmethod - def translate_from_provider(self, provider_response : Any, ell_call : EllCallParams, origin_id : Optional[str] = None, logger : Optional[Callable[[str], None]] = None) -> Tuple[List[Message], Metadata]: - """Converts provider responses to universal format.""" + def translate_from_provider( + self, + provider_response: Any, + ell_call: EllCallParams, + provider_call_params: Dict[str, Any], + origin_id: Optional[str] = None, + logger: Optional[Callable[..., None]] = None, + ) -> Tuple[List[Message], Metadata]: + """Converts provider responses to universal format. with metadata""" return NotImplemented ################################ ### CALL MODEL ################ ################################ # Be careful to override this method in your provider. 
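The `get_tool_by_name` helper added to `EllCallParams` above is a `next(...)` lookup with a `None` default, so an unknown tool name never raises `StopIteration`. A standalone sketch with a hypothetical `get_weather` tool:

```python
from typing import Callable, List, Optional


def get_weather(city: str) -> str:
    """Pretend tool."""
    return f"sunny in {city}"


def get_tool_by_name(tools: Optional[List[Callable]], name: str) -> Optional[Callable]:
    # Returns the first tool whose function name matches, or None.
    return next((t for t in (tools or []) if t.__name__ == name), None)


assert get_tool_by_name([get_weather], "get_weather") is get_weather
assert get_tool_by_name([get_weather], "missing_tool") is None
assert get_tool_by_name(None, "anything") is None
```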
- def call(self, ell_call : EllCallParams, origin_id : Optional[str] = None, logger : Optional[Any] = None) -> Tuple[List[Message], Dict[str, Any], Metadata]: + def call( + self, + ell_call: EllCallParams, + origin_id: Optional[str] = None, + logger: Optional[Any] = None, + ) -> Tuple[List[Message], Dict[str, Any], Metadata]: # Automatic validation of params - assert ell_call.api_params.keys() not in self.disallowed_api_params(), f"Disallowed parameters: {ell_call.api_params}" + assert ( + ell_call.api_params.keys() not in self.disallowed_api_params() + ), f"Disallowed api parameters: {ell_call.api_params}" - # Call - - final_api_call_params = self.translate_to_provider(ell_call) + call = self.provider_call_function(final_api_call_params) _validate_provider_call_params(final_api_call_params, call) provider_resp = call(final_api_call_params)(**final_api_call_params) - messages, metadata = self.translate_from_provider(provider_resp, ell_call, origin_id, logger) + messages, metadata = self.translate_from_provider( + provider_resp, ell_call, final_api_call_params, origin_id, logger + ) + assert "choices" not in metadata, "choices should be in the metadata." _validate_messages_are_tracked(messages, origin_id) - - # TODO: Validate messages are tracked. - return messages, final_api_call_params, metadata - - + return messages, final_api_call_params, metadata # handhold the the implementer, in production mode we can turn these off for speed. @lru_cache(maxsize=None) -def _call_params(call : Callable[..., Any]) -> MappingProxyType[str, inspect.Parameter]: +def _call_params(call: Callable[..., Any]) -> MappingProxyType[str, inspect.Parameter]: return inspect.signature(call).parameters -def _validate_provider_call_params(api_call_params: Dict[str, Any], call : Callable[..., Any]): + +def _validate_provider_call_params( + api_call_params: Dict[str, Any], call: Callable[..., Any] +): provider_call_params = _call_params(call) - + required_params = { - name: param for name, param in provider_call_params.items() + name: param + for name, param in provider_call_params.items() if param.default == param.empty and param.kind != param.VAR_KEYWORD } - + for param_name in required_params: - assert param_name in api_call_params, f"Provider implementation error: Required parameter '{param_name}' is missing in the converted call parameters converted from ell call." - + assert ( + param_name in api_call_params + ), f"Provider implementation error: Required parameter '{param_name}' is missing in the converted call parameters converted from ell call." + for param_name, param_value in api_call_params.items(): - assert param_name in provider_call_params, f"Provider implementation error: Unexpected parameter '{param_name}' in the converted call parameters." - + assert ( + param_name in provider_call_params + ), f"Provider implementation error: Unexpected parameter '{param_name}' in the converted call parameters." + param_type = provider_call_params[param_name].annotation if param_type != inspect.Parameter.empty: - assert isinstance(param_value, param_type), f"Provider implementation error: Parameter '{param_name}' should be of type {param_type}." - + assert isinstance( + param_value, param_type + ), f"Provider implementation error: Parameter '{param_name}' should be of type {param_type}." 
-def _validate_messages_are_tracked(messages: List[Message], origin_id: Optional[str] = None): - if origin_id is None: return - - for message in messages: - assert isinstance(message.text, _lstr), f"Provider implementation error: Message text should be an instance of _lstr, got {type(message.text)}" - assert message.text._or == origin_id, f"Provider implementation error: Message origin_id {message.text.origin_id} does not match the provided origin_id {origin_id}" +def _validate_messages_are_tracked( + messages: List[Message], origin_id: Optional[str] = None +): + if origin_id is None: + return + for message in messages: + assert isinstance( + message.text, _lstr + ), f"Provider implementation error: Message text should be an instance of _lstr, got {type(message.text)}" + assert ( + message.text.origin_id == origin_id + ), f"Provider implementation error: Message origin_id {message.text.origin_id} does not match the provided origin_id {origin_id}" diff --git a/src/ell/providers/__init__.py b/src/ell/providers/__init__.py index 763dfc07..90ec14c3 100644 --- a/src/ell/providers/__init__.py +++ b/src/ell/providers/__init__.py @@ -1,5 +1,5 @@ import ell.providers.openai -import ell.providers.anthropic +# import ell.providers.anthropic # import ell.providers.groq # import ell.providers.mistral # import ell.providers.cohere diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 30797274..62e1aa43 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Type, Union -from ell.provider import APICallResult, Provider +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from ell.provider import APICallResult, EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr import json @@ -10,241 +10,161 @@ from ell.util.serialization import serialize_image try: + # XXX: Could genericize. import openai + from openai._streaming import Stream + from openai.types.chat import ChatCompletion, ParsedChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam class OpenAIProvider(Provider): - - # XXX: This content block conversion etc might need to happen on a per model basis for providers like groq etc. We will think about this at a future date. - @staticmethod - def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]: - if content_block.image: - base64_image = serialize_image(content_block.image) - image_url = {"url": base64_image} - - # add detail only if supplied by user - # OpenAI's default is "auto", we omit the "detail" key entirely if not provided by user - if content_block.image_detail: - image_url["detail"] = content_block.image_detail - - return { - "type": "image_url", - "image_url": image_url - } - elif content_block.text: - return { - "type": "text", - "text": content_block.text - } - elif content_block.parsed: - return { - "type": "text", - "text": content_block.parsed.model_dump_json() - } - # Tool calls handled in message_to_openai_format. - #XXX: Feel free to refactor this. 
+ def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: + if api_call_params and api_call_params.get("response_format"): + return openai.beta.chat.completions.parse else: - return None - - @staticmethod - def message_to_openai_format(message: Message) -> Dict[str, Any]: - openai_message = { - "role": "tool" if message.tool_results else message.role, - "content": list(filter(None, [ - OpenAIProvider.content_block_to_openai_format(c) for c in message.content - ])) - } - if message.tool_calls: - try: - openai_message["tool_calls"] = [ - { - "id": tool_call.tool_call_id, - "type": "function", - "function": { - "name": tool_call.tool.__name__, - "arguments": json.dumps(tool_call.params.model_dump()) - } - } for tool_call in message.tool_calls - ] - except TypeError as e: - print(f"Error serializing tool calls: {e}. Did you fully type your @ell.tool decorated functions?") - raise - openai_message["content"] = None # Set content to null when there are tool calls - - if message.tool_results: - openai_message["tool_call_id"] = message.tool_results[0].tool_call_id - openai_message["content"] = message.tool_results[0].result[0].text - assert len(message.tool_results[0].result) == 1, "Tool result should only have one content block" - assert message.tool_results[0].result[0].type == "text", "Tool result should only have one text content block" - return openai_message - - @classmethod - def call( - cls, - client: Any, - model: str, - messages: List[Message], - api_params: Dict[str, Any], - tools: Optional[list[LMP]] = None, - ) -> APICallResult: - final_call_params = api_params.copy() - openai_messages = [cls.message_to_openai_format(message) for message in messages] - - actual_n = api_params.get("n", 1) - final_call_params["model"] = model - final_call_params["messages"] = openai_messages - - if model == "o1-preview" or model == "o1-mini": - # Ensure no system messages are present - assert all(msg['role'] != 'system' for msg in final_call_params['messages']), "System messages are not allowed for o1-preview or o1-mini models" - - response = client.chat.completions.create(**final_call_params) - final_call_params.pop("stream", None) - final_call_params.pop("stream_options", None) - - - elif final_call_params.get("response_format"): + return openai.chat.completions.create + + def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: + final_call_params = ell_call.api_params.copy() + final_call_params["model"] = ell_call.model + # Stream by default for verbose logging. + final_call_params["stream"] = True + final_call_params["stream_options"] = {"include_usage": True} + + # XXX: Deprecation of config.registry.supports_streaming when streaming is implemented. 
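`provider_call_function` now dispatches on the requested call parameters: when a `response_format` is present the structured-output `parse` endpoint is returned instead of the plain `create`. A toy version of that dispatch, with fake endpoints standing in for the OpenAI SDK:

```python
from typing import Any, Callable, Dict, Optional


def plain_create(**params: Any) -> str:
    return "plain completion"


def parse_structured(**params: Any) -> str:
    return f"parsed into {params['response_format'].__name__}"


def pick_call(api_call_params: Optional[Dict[str, Any]] = None) -> Callable[..., Any]:
    # Structured output requested -> use the parsing endpoint.
    if api_call_params and api_call_params.get("response_format"):
        return parse_structured
    return plain_create


class Weather:  # stand-in for a response_format model
    pass


print(pick_call({"model": "x"})())                                       # plain completion
print(pick_call({"response_format": Weather})(response_format=Weather))  # parsed into Weather
```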
+ if final_call_params.get("response_format") or config.registry[ell_call.model].supports_streaming is False or ell_call.tools: final_call_params.pop("stream", None) final_call_params.pop("stream_options", None) - response = client.beta.chat.completions.parse(**final_call_params) - else: - # Tools not workign with structured API - if tools: - final_call_params["tool_choice"] = "auto" - final_call_params["tools"] = [ - { - "type": "function", - "function": { - "name": tool.__name__, - "description": tool.__doc__, - "parameters": tool.__ell_params_model__.model_json_schema(), - }, - } - for tool in tools + if ell_call.tools: + final_call_params.update( + tool_choice="auto", + tools=[ + dict( + type="function", + function=dict( + name=tool.__name__, + description=tool.__doc__, + parameters=tool.__ell_params_model__.model_json_schema(), #type: ignore + ) + ) for tool in ell_call.tools ] - final_call_params.pop("stream", None) - final_call_params.pop("stream_options", None) + ) + # messages + openai_messages : List[ChatCompletionMessageParam] = [] + for message in ell_call.messages: + if (tool_calls := message.tool_calls): + assert message.role == "assistant", "Tool calls must be from the assistant." + assert all(t.tool_call_id for t in tool_calls), "Tool calls must have tool call ids." + openai_messages.append(dict( + tool_calls=[ + dict( + id=cast(str, tool_call.tool_call_id), + type="function", + function=dict( + name=tool_call.tool.__name__, + arguments=json.dumps(tool_call.params.model_dump()) + ) + ) for tool_call in tool_calls ], + role="assistant", + content=None, + )) + elif (tool_results := message.tool_results): + assert len(tool_results) == 1, "Message should only have one tool result" + assert (tr_content := tool_results[0].result[0]).type == "text", "Tool result should only have one text content block" + openai_messages.append(dict( + role="tool", + tool_call_id=tool_results[0].tool_call_id, + content=cast(str, tr_content.text), + )) else: - final_call_params["stream_options"] = {"include_usage": True} - final_call_params["stream"] = True - - response = client.chat.completions.create(**final_call_params) + openai_messages.append(cast(ChatCompletionMessageParam, dict( + role=message.role, + content=[content_block_to_openai_format(c) for c in message.content] + ))) + final_call_params["messages"] = openai_messages - - return APICallResult( - response=response, - actual_streaming=isinstance(response, openai.Stream), - actual_n=actual_n, - final_call_params=final_call_params, - ) - - @classmethod - def process_response( - cls, call_result: APICallResult, _invocation_origin: str, logger : Optional[Any] = None, tools: Optional[List[LMP]] = None, - ) -> Tuple[List[Message], Dict[str, Any]]: - choices_progress = defaultdict(list) - api_params = call_result.final_call_params - metadata = {} - #XXX: Remove logger and refactor this API - - if not call_result.actual_streaming: - response = [call_result.response] - else: - response = call_result.response + return final_call_params + + def translate_from_provider( + self, + provider_response: Union[ + ChatCompletion, + ParsedChatCompletion, + Stream[ChatCompletionChunk], Any], + ell_call: EllCallParams, + provider_call_params: Dict[str, Any], + origin_id: Optional[str] = None, + logger: Optional[Callable[..., None]] = None, + ) -> Tuple[List[Message], Metadata]: - - for chunk in response: - if hasattr(chunk, "usage") and chunk.usage: - metadata = chunk.to_dict() - - - for choice in chunk.choices: - choices_progress[choice.index].append(choice) 
- - if choice.index == 0 and logger: - # print(choice, streaming) - logger(choice.delta.content if call_result.actual_streaming else - choice.message.content or getattr(choice.message, "refusal", ""), is_refusal=getattr(choice.message, "refusal", False) if not call_result.actual_streaming else False) - - - - tracked_results = [] - for _, choice_deltas in sorted(choices_progress.items(), key=lambda x: x[0]): - content = [] - - if call_result.actual_streaming: - text_content = "".join( - (choice.delta.content or "" for choice in choice_deltas) - ) - if text_content: - content.append( - ContentBlock( - text=_lstr( - content=text_content,origin_trace=_invocation_origin - ) - ) - ) - else: - choice = choice_deltas[0].message - if choice.refusal: - raise ValueError(choice.refusal) - if api_params.get("response_format"): - content.append(ContentBlock(parsed=choice.parsed)) - elif choice.content: - content.append( - ContentBlock( - text=_lstr( - content=choice.content,origin_trace=_invocation_origin - ) - ) - ) - - if not call_result.actual_streaming and hasattr(choice, "tool_calls") and choice.tool_calls: - assert tools is not None, "Tools not provided, yet tool calls in response. Did you manually specify a tool spec without using ell.tool?" - for tool_call in choice.tool_calls: - matching_tool = next( - ( - tool - for tool in tools - if tool.__name__ == tool_call.function.name - ), - None, - ) - if matching_tool: - params = matching_tool.__ell_params_model__( - **json.loads(tool_call.function.arguments) - ) - content.append( + metadata : Metadata = {} + messages : List[Message] = [] + did_stream = provider_call_params.get("stream", False) + + if did_stream: + stream = cast(Stream[ChatCompletionChunk], provider_response) + message_streams = defaultdict(list) + role : Optional[str] = None + for chunk in stream: + if hasattr(chunk, "usage") and chunk.usage: metadata.update(chunk.model_dump(exclude={"choices"})) + for chat_compl_chunk in chunk.choices: + message_streams[chat_compl_chunk.index].append(chat_compl_chunk) + role = role or (delta := chat_compl_chunk.delta).role + if chat_compl_chunk.index == 0 and logger: + logger(delta.content, is_refusal=delta.refusal) + for _, message_stream in sorted(message_streams.items(), key=lambda x: x[0]): + text = "".join((choice.delta.content or "") for choice in message_stream) + messages.append( + Message(role=role, + content=_lstr(content=text,origin_trace=origin_id))) + #XXX: Support streaming other types. + else: + chat_completion = cast(Union[ChatCompletion, ParsedChatCompletion], provider_response) + metadata = chat_completion.model_dump(exclude={"choices"}) + for oai_choice in chat_completion.choices: + content_blocks = [] + if (refusal := (message := oai_choice.message).refusal): + raise ValueError(refusal) + if hasattr(message, "parsed"): + if (parsed := message.parsed): + content_blocks.append(ContentBlock(parsed=parsed)) #XXX: Origin tracing + else: + if (content := message.content): + content_blocks.append( ContentBlock( - tool_call=ToolCall( - tool=matching_tool, - tool_call_id=_lstr( - tool_call.id,origin_trace=_invocation_origin - ), - params=params, + text=_lstr(content=content,origin_trace=origin_id))) + if (tool_calls := message.tool_calls): + for tool_call in tool_calls: + matching_tool = ell_call.get_tool_by_name(tool_call.function.name) + assert matching_tool, "Model called tool not found in provided toolset." 
+ content_blocks.append( + ContentBlock( + tool_call=ToolCall( + tool=matching_tool, + tool_call_id=_lstr( + tool_call.id, origin_trace= origin_id), + params=json.loads(tool_call.function.arguments), + ) ) ) - ) - - tracked_results.append( - Message( - role=( - choice.role - if not call_result.actual_streaming - else choice_deltas[0].delta.role - ), - content=content, - ) - ) - return tracked_results, metadata - - @classmethod - def supports_streaming(cls) -> bool: - return True + messages.append(Message(role=role, content=content_blocks)) + return messages, metadata - @classmethod - def get_client_type(cls) -> Type: - return openai.Client - register_provider(OpenAIProvider) + openai_provider = OpenAIProvider() + register_provider(openai_provider, openai.Client) except ImportError: - pass \ No newline at end of file + pass + + +def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]: + if (image := content_block.image): + image_url = {"url": serialize_image(image)} + # XXX: Solve per content params better + if (image_detail := content_block.image_detail): image_url["detail"] = image_detail + return { + "type": "image_url", + "image_url": image_url + } + elif (text := content_block.text): return dict(type="text", text=text) + elif (parsed := content_block.parsed): return dict(type="text", text=parsed.model_dump_json()) + else: + raise ValueError(f"Unsupported content block type for openai: {content_block}") \ No newline at end of file diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 123189ea..d33e4dae 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -29,6 +29,11 @@ class ToolCall(BaseModel): tool_call_id : Optional[_lstr_generic] = Field(default=None) params : BaseModel + def __init__(self, tool, tool_call_id, params : Union[BaseModel, Dict[str, Any]]): + if not isinstance(params, BaseModel): + params = tool.__ell_params_model__(**params) #convenience. + super().__init__(tool=tool, tool_call_id=tool_call_id, params=params) + def __call__(self, **kwargs): assert not kwargs, "Unexpected arguments provided. Calling a tool uses the params provided in the ToolCall." 
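# A minimal sketch of the convenience constructor added above: `params` may be passed
# either as the tool's params model or as a plain dict, which is coerced through
# `tool.__ell_params_model__`. The tool below is hypothetical; attaching
# `__ell_params_model__` by hand stands in for what the @ell.tool decorator normally does.
from pydantic import BaseModel
from ell.types import ToolCall

class WeatherParams(BaseModel):
    city: str

def get_weather(city: str) -> str:
    """Look up the weather for a city."""
    return f"Sunny in {city}"

get_weather.__ell_params_model__ = WeatherParams  # normally set by @ell.tool

a = ToolCall(tool=get_weather, tool_call_id="call_1", params=WeatherParams(city="Paris"))
b = ToolCall(tool=get_weather, tool_call_id="call_2", params={"city": "Paris"})  # dict is coerced
assert a.params == b.params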
@@ -41,6 +46,8 @@ def call_and_collect_as_message_block(self): def call_and_collect_as_message(self): return Message(role="user", content=[self.call_and_collect_as_message_block()]) + + class ContentBlock(BaseModel): From b994128861faaebbc0dcd674af8fb42bd2d61756 Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 17:52:24 -0700 Subject: [PATCH 10/17] deprecate lm_params --- docs/ramblings/notes_on_adapters.py | 2 +- docs/src/core_concepts/ell_simple.rst | 10 +++---- .../core_concepts/versioning_and_storage.rst | 4 +-- docs/src/getting_started.rst | 4 +-- examples/git_issue.py | 2 +- examples/multilmp.py | 2 +- src/ell/configurator.py | 13 +++++---- src/ell/lmp/complex.py | 16 ++++++----- src/ell/lmp/simple.py | 2 +- src/ell/provider.py | 28 +++++++++---------- src/ell/providers/anthropic.py | 2 +- src/ell/providers/openai.py | 16 +++++++---- src/ell/types/_lstr.py | 7 +++-- src/ell/types/message.py | 2 +- src/ell/util/_warnings.py | 2 +- tests/test_lmp_to_prompt.py | 6 ++-- 16 files changed, 63 insertions(+), 55 deletions(-) diff --git a/docs/ramblings/notes_on_adapters.py b/docs/ramblings/notes_on_adapters.py index 2ba6bbd2..1cf25162 100644 --- a/docs/ramblings/notes_on_adapters.py +++ b/docs/ramblings/notes_on_adapters.py @@ -81,7 +81,7 @@ class OAILikeProvider(abc.ABC): # inherently you just don't want to fuck around with -""blah(lm_params=dict(client=my_openai_client)) +""blah(api_params=dict(client=my_openai_client)) "" # or even diff --git a/docs/src/core_concepts/ell_simple.rst b/docs/src/core_concepts/ell_simple.rst index 82f12eaa..c5ad62d1 100644 --- a/docs/src/core_concepts/ell_simple.rst +++ b/docs/src/core_concepts/ell_simple.rst @@ -100,11 +100,11 @@ One of the most convenient functions of the ``@ell.simple`` decorator is that yo """You are a helpful assistant.""" return f"Hey there {name}!" -Likewise, if you want to modify those parameters for a particular invocation of that prompt, you simply pass them in as ``lm_params`` keyword arguments to the function when calling it. For example: +Likewise, if you want to modify those parameters for a particular invocation of that prompt, you simply pass them in as ``api_params`` keyword arguments to the function when calling it. For example: .. code-block:: python - >>> hello("world", lm_params=dict(temperature=0.7)) + >>> hello("world", api_params=dict(temperature=0.7)) 'Hey there world!' @@ -139,15 +139,15 @@ In the spirit of simplicity, we've designed it to automatically coerce the retur >>> hello("world") ['Hey there world!', 'Hi, world.'] -Similarly, this behavior applies when using runtime ``lm_params`` to specify multiple outputs. +Similarly, this behavior applies when using runtime ``api_params`` to specify multiple outputs. .. code-block:: python - >>> hello("world", lm_params=dict(n=3)) + >>> hello("world", api_params=dict(n=3)) ['Hey there world!', 'Hi, world.', 'Hello, world!'] -.. note:: In the future, we may modify this interface as preserving the ``lm_params`` keyword in its current form could potentially lead to conflicts with user-defined functions. However, during the beta phase, we are closely monitoring for feedback and will make adjustments based on user experiences and needs. +.. note:: In the future, we may modify this interface as preserving the ``api_params`` keyword in its current form could potentially lead to conflicts with user-defined functions. However, during the beta phase, we are closely monitoring for feedback and will make adjustments based on user experiences and needs. 
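Putting the two together, a per-call override composes with the defaults set on the decorator (the decorator arguments here are illustrative):

.. code-block:: python

    import ell

    @ell.simple(model="gpt-4-turbo", temperature=0.1)
    def hello(name: str):
        """You are a helpful assistant."""
        return f"Hey there {name}!"

    hello("world")                                    # -> a single string
    hello("world", api_params=dict(temperature=0.9))  # hotter sampling for this call only
    hello("world", api_params=dict(n=3))              # -> a list of three strings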
Multimodal inputs diff --git a/docs/src/core_concepts/versioning_and_storage.rst b/docs/src/core_concepts/versioning_and_storage.rst index 0ef4ab9a..1b40470c 100644 --- a/docs/src/core_concepts/versioning_and_storage.rst +++ b/docs/src/core_concepts/versioning_and_storage.rst @@ -118,9 +118,9 @@ In addition, when a language model program depends on another prompt (i.e., when @ell.simple(model="gpt-4-turbo", temperature=0.2) def write_a_really_good_story(about : str): """You are an expert novelist that writes in the style of Hemmingway. You write in lowercase.""" - # Note: You can pass in lm_params to control the language model call + # Note: You can pass in api_params to control the language model call # in the case n = 4 tells OpenAI to generate a batch of 4 outputs. - ideas = generate_story_ideas(about, lm_params=(dict(n=4))) + ideas = generate_story_ideas(about, api_params=(dict(n=4))) drafts = [write_a_draft_of_a_story(idea) for idea in ideas] diff --git a/docs/src/getting_started.rst b/docs/src/getting_started.rst index e4ff36cf..31780046 100644 --- a/docs/src/getting_started.rst +++ b/docs/src/getting_started.rst @@ -162,9 +162,9 @@ Taking this concept further, LMPs can call other LMPs, allowing for more complex @ell.simple(model="gpt-4-turbo", temperature=0.2) def write_a_really_good_story(about : str): """You are an expert novelist that writes in the style of Hemmingway. You write in lowercase.""" - # Note: You can pass in lm_params to control the language model call + # Note: You can pass in api_params to control the language model call # in the case n = 4 tells OpenAI to generate a batch of 4 outputs. - ideas = generate_story_ideas(about, lm_params=(dict(n=4))) + ideas = generate_story_ideas(about, api_params=(dict(n=4))) drafts = [write_a_draft_of_a_story(idea) for idea in ideas] diff --git a/examples/git_issue.py b/examples/git_issue.py index fb796df2..3995c9e5 100644 --- a/examples/git_issue.py +++ b/examples/git_issue.py @@ -87,7 +87,7 @@ def generate_issue( res = fn(*fn_args, **fn_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "d:\\dev\\ell\\examples\\multilmp.py", line 32, in write_a_really_good_story - ideas = generate_story_ideas(about, lm_params=(dict(n=4))) + ideas = generate_story_ideas(about, api_params=(dict(n=4))) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\\dev\\ell\\ell\\src\\ell\\decorators.py", line 216, in wrapper fn_closure, _uses = ell.util.closure.lexically_closured_source(func_to_track) diff --git a/examples/multilmp.py b/examples/multilmp.py index 6a36c850..c2399662 100644 --- a/examples/multilmp.py +++ b/examples/multilmp.py @@ -22,7 +22,7 @@ def choose_the_best_draft(drafts : List[str]): @ell.simple(model="gpt-4-turbo", temperature=0.2) def write_a_really_good_story(about : str): - ideas = generate_story_ideas(about, lm_params=(dict(n=4))) + ideas = generate_story_ideas(about, api_params=(dict(n=4))) drafts = [write_a_draft_of_a_story(idea) for idea in ideas] diff --git a/src/ell/configurator.py b/src/ell/configurator.py index b6be398c..7f1ef62b 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -36,7 +36,7 @@ class Config(BaseModel): store: Optional[Store] = Field(default=None, description="An optional Store instance for persistence.") autocommit: bool = Field(default=False, description="If True, enables automatic committing of changes to the store.") lazy_versioning: bool = Field(default=True, description="If True, enables lazy versioning for improved performance.") - default_lm_params: Dict[str, Any] = 
Field(default_factory=dict, description="Default parameters for language models.") + default_api_params: Dict[str, Any] = Field(default_factory=dict, description="Default parameters for language models.") default_client: Optional[openai.Client] = Field(default=None, description="The default OpenAI client used when a specific model client is not found.") providers: Dict[Type, Provider] = Field(default_factory=dict, description="A dictionary mapping client types to provider classes.") def __init__(self, **data): @@ -131,6 +131,7 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]: :return: The provider instance for the specified client, or None if not found. :rtype: Optional[Provider] """ + client_type = type(client) if not isinstance(client, type) else client return self.providers.get(client_type) @@ -143,7 +144,7 @@ def init( verbose: bool = False, autocommit: bool = True, lazy_versioning: bool = True, - default_lm_params: Optional[Dict[str, Any]] = None, + default_api_params: Optional[Dict[str, Any]] = None, default_client: Optional[Any] = None ) -> None: """ @@ -157,8 +158,8 @@ def init( :type autocommit: bool :param lazy_versioning: Enable or disable lazy versioning. :type lazy_versioning: bool - :param default_lm_params: Set default parameters for language models. - :type default_lm_params: Dict[str, Any], optional + :param default_api_params: Set default parameters for language models. + :type default_api_params: Dict[str, Any], optional :param default_openai_client: Set the default OpenAI client. :type default_openai_client: openai.Client, optional """ @@ -173,8 +174,8 @@ def init( config.store = store config.autocommit = autocommit or config.autocommit - if default_lm_params is not None: - config.default_lm_params.update(default_lm_params) + if default_api_params is not None: + config.default_api_params.update(default_api_params) if default_client is not None: config.default_client = default_client diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index 49b7f574..8d5c4771 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -39,11 +39,11 @@ def model_call( # XXX: move should log to a logger. should_log = not exempt_from_tracking and config.verbose # Cute verbose logging. - if should_log: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, model_call.__ell_hash__, messages) #type: ignore + if should_log: model_usage_logger_pre(prompt, prompt_args, prompt_kwargs, "[]", messages) #type: ignore # Call the model. # Merge API params - merged_api_params = {**config.default_lm_params, **default_api_params_from_decorator, **(api_params or {})} + merged_api_params = {**config.default_api_params, **default_api_params_from_decorator, **(api_params or {})} n = merged_api_params.get("n", 1) # Merge client overrides & client registry merged_client = _client_for_model(model, client or default_client_from_decorator) @@ -51,18 +51,19 @@ def model_call( # XXX: Could change behaviour of overriding ell params for dyanmic tool calls. model=merged_api_params.pop("model", default_model_from_decorator), messages=messages, - client = client or default_client_from_decorator, + client = merged_client, api_params=merged_api_params, - tools=tools, + tools=tools or [], ) # Get the provider for the model - provider = config.get_provider_for(ell_call) - assert provider is not None, f"No provider found for model {ell_call.client}." 
+ provider = config.get_provider_for(ell_call.client) + assert provider is not None, f"No provider found for client {ell_call.client}." if should_log: model_usage_logger_post_start(n) with model_usage_logger_post_intermediate(n) as _logger: (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger) - + if isinstance(result, list) and len(result) == 1: + result = result[0] result = post_callback(result) if post_callback else result if should_log: model_usage_logger_post_end() @@ -110,6 +111,7 @@ def _client_for_model( # XXX: Move to config to centralize api keys etc. if not client: client, was_fallback = config.get_client_for(model) + # XXX: Wrong. if not client and not was_fallback: raise RuntimeError(_no_api_key_warning(model, _name, '', long=True, error=True)) diff --git a/src/ell/lmp/simple.py b/src/ell/lmp/simple.py index 6a5ae5a0..086b3ade 100644 --- a/src/ell/lmp/simple.py +++ b/src/ell/lmp/simple.py @@ -86,7 +86,7 @@ def generate_story(prompt: str) -> str: story1 = generate_story("A day in the life of a time traveler") # Overriding parameters during function call - story2 = generate_story("An AI's first day of consciousness", lm_params={"temperature": 0.9, "max_tokens": 500}) + story2 = generate_story("An AI's first day of consciousness", api_params={"temperature": 0.9, "max_tokens": 500}) See Also: diff --git a/src/ell/provider.py b/src/ell/provider.py index aabfc8d9..d33bebc8 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -31,7 +31,7 @@ class EllCallParams(BaseModel): model: str = Field(..., description="Model identifier") messages: List[Message] = Field(..., description="Conversation context") client: Any = Field(..., description="API client") - tools: Optional[List[LMP]] = Field(None, description="Available tools") + tools: List[LMP] = Field(default_factory=list, description="Available tools") api_params: Dict[str, Any] = Field( default_factory=dict, description="API parameters" ) @@ -54,6 +54,7 @@ class Provider(ABC): For example, the OpenAI provider is an API interface to OpenAI's API but also to Ollama and Azure OpenAI. In Ell. We hate abstractions. The only reason this exists is to force implementers to implement their own provider correctly -_-. """ + dangerous_disable_validation = False ################################ ### API PARAMETERS ############# @@ -110,21 +111,22 @@ def call( ) -> Tuple[List[Message], Dict[str, Any], Metadata]: # Automatic validation of params assert ( - ell_call.api_params.keys() not in self.disallowed_api_params() + not set(ell_call.api_params.keys()).intersection(self.disallowed_api_params()) ), f"Disallowed api parameters: {ell_call.api_params}" final_api_call_params = self.translate_to_provider(ell_call) call = self.provider_call_function(final_api_call_params) - _validate_provider_call_params(final_api_call_params, call) - - provider_resp = call(final_api_call_params)(**final_api_call_params) + assert self.dangerous_disable_validation or _validate_provider_call_params(final_api_call_params, call) + + + provider_resp = call(**final_api_call_params) messages, metadata = self.translate_from_provider( provider_resp, ell_call, final_api_call_params, origin_id, logger ) assert "choices" not in metadata, "choices should be in the metadata." 
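        # The choice contents are returned as Messages above; providers are expected to
        # strip `choices` from the metadata they hand back (e.g. via
        # model_dump(exclude={"choices"})), which is what this assertion enforces.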
- _validate_messages_are_tracked(messages, origin_id) + assert self.dangerous_disable_validation or _validate_messages_are_tracked(messages, origin_id) return messages, final_api_call_params, metadata @@ -155,13 +157,8 @@ def _validate_provider_call_params( assert ( param_name in provider_call_params ), f"Provider implementation error: Unexpected parameter '{param_name}' in the converted call parameters." - - param_type = provider_call_params[param_name].annotation - if param_type != inspect.Parameter.empty: - assert isinstance( - param_value, param_type - ), f"Provider implementation error: Parameter '{param_name}' should be of type {param_type}." - + + return True def _validate_messages_are_tracked( messages: List[Message], origin_id: Optional[str] = None @@ -174,5 +171,6 @@ def _validate_messages_are_tracked( message.text, _lstr ), f"Provider implementation error: Message text should be an instance of _lstr, got {type(message.text)}" assert ( - message.text.origin_id == origin_id - ), f"Provider implementation error: Message origin_id {message.text.origin_id} does not match the provided origin_id {origin_id}" + origin_id in message.text.__origin_trace__ + ), f"Provider implementation error: Message origin_id {message.text.__origin_trace__} does not match the provided origin_id {origin_id}" + return True diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index 1505cc05..214b6d25 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -24,7 +24,7 @@ def call( tools: Optional[list[LMP]] = None, ) -> APICallResult: final_call_params = api_params.copy() - assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., lm_params=({{'max_tokens': your_max_tokens}}))." + assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))." 
anthropic_messages = [message_to_anthropic_format(message) for message in messages] system_message = None diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 62e1aa43..064e5e0c 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast -from ell.provider import APICallResult, EllCallParams, Metadata, Provider +from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr import json @@ -14,7 +14,10 @@ import openai from openai._streaming import Stream from openai.types.chat import ChatCompletion, ParsedChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam + class OpenAIProvider(Provider): + dangerous_disable_validation = True + def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: if api_call_params and api_call_params.get("response_format"): return openai.beta.chat.completions.parse @@ -76,7 +79,7 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: else: openai_messages.append(cast(ChatCompletionMessageParam, dict( role=message.role, - content=[content_block_to_openai_format(c) for c in message.content] + content=[_content_block_to_openai_format(c) for c in message.content] ))) final_call_params["messages"] = openai_messages @@ -98,15 +101,19 @@ def translate_from_provider( messages : List[Message] = [] did_stream = provider_call_params.get("stream", False) + if did_stream: stream = cast(Stream[ChatCompletionChunk], provider_response) message_streams = defaultdict(list) role : Optional[str] = None for chunk in stream: + if hasattr(chunk, "usage") and chunk.usage: metadata.update(chunk.model_dump(exclude={"choices"})) + for chat_compl_chunk in chunk.choices: message_streams[chat_compl_chunk.index].append(chat_compl_chunk) - role = role or (delta := chat_compl_chunk.delta).role + delta = chat_compl_chunk.delta + role = role or delta.role if chat_compl_chunk.index == 0 and logger: logger(delta.content, is_refusal=delta.refusal) for _, message_stream in sorted(message_streams.items(), key=lambda x: x[0]): @@ -148,14 +155,13 @@ def translate_from_provider( return messages, metadata - openai_provider = OpenAIProvider() register_provider(openai_provider, openai.Client) except ImportError: pass -def content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]: +def _content_block_to_openai_format(content_block: ContentBlock) -> Dict[str, Any]: if (image := content_block.image): image_url = {"url": serialize_image(image)} # XXX: Solve per content params better diff --git a/src/ell/types/_lstr.py b/src/ell/types/_lstr.py index 93d4d582..aa8fb104 100644 --- a/src/ell/types/_lstr.py +++ b/src/ell/types/_lstr.py @@ -298,7 +298,6 @@ def __getattribute__(self, name: str) -> Union[Callable, Any]: return type(self) if callable(attr) and name not in _lstr.__dict__: - def wrapped(*args: Any, **kwargs: Any) -> Any: result = attr(*args, **kwargs) # If the result is a string, return an lstr instance @@ -329,12 +328,14 @@ def join(self, iterable: Iterable[Union[str, "_lstr"]]) -> "_lstr": Returns: lstr: A new lstr instance containing the joined content, with theorigin_trace(s) updated accordingly. 
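        Example (illustrative origin ids)::

            a = _lstr("foo", origin_trace="A")
            b = _lstr("bar", origin_trace="B")
            joined = _lstr(", ").join([a, b])   # == "foo, bar"
            # joined.__origin_trace__ now contains both "A" and "B"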
""" - parts = [str(item) for item in iterable] - new_content = super(_lstr, self).join(parts) new__origin_trace__ = self.__origin_trace__ + parts = [] for item in iterable: if isinstance(item, _lstr): new__origin_trace__ = new__origin_trace__.union(item.__origin_trace__) + parts.append(item) + new_content = super(_lstr, self).join(parts) + return _lstr(new_content, None, new__origin_trace__) @override diff --git a/src/ell/types/message.py b/src/ell/types/message.py index d33e4dae..0c69d41b 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -212,7 +212,7 @@ def text(self) -> str: >>> message.text 'Hello\\n\\nWorld' """ - return "\n".join(c.text or f"<{c.type}>" for c in self.content) + return _lstr("\n").join(c.text or f"<{c.type}>" for c in self.content) @cached_property def images(self) -> List[PILImage.Image]: diff --git a/src/ell/util/_warnings.py b/src/ell/util/_warnings.py index 9418fa66..6a183e5e 100644 --- a/src/ell/util/_warnings.py +++ b/src/ell/util/_warnings.py @@ -55,7 +55,7 @@ def {fn.__name__}(...): ell.simple(model, client=my_client)(...) ``` {Style.RESET_ALL}""") - elif (client_to_use := config.registry[model]) is None or not client_to_use.api_key: + elif (client_to_use := config.registry[model].default_client) is None or not client_to_use.api_key: logger.warning(_no_api_key_warning(model, fn.__name__, client_to_use, long=False)) diff --git a/tests/test_lmp_to_prompt.py b/tests/test_lmp_to_prompt.py index 4a4ec0a9..8a5650c0 100644 --- a/tests/test_lmp_to_prompt.py +++ b/tests/test_lmp_to_prompt.py @@ -36,7 +36,7 @@ # def test_lm_decorator_with_params(mock_run_lm): -# result = lmp_with_default_system_prompt("input", lm_params=dict(temperature=0.5)) +# result = lmp_with_default_system_prompt("input", api_params=dict(temperature=0.5)) # mock_run_lm.assert_called_once_with( # model="gpt-4-turbo", @@ -55,7 +55,7 @@ # @patch("ell.util.lm._run_lm") # def test_lm_decorator_with_docstring_system_prompt(mock_run_lm): # mock_run_lm.return_value = ("Mocked content", None) -# result = lmp_with_docstring_system_prompt("input", lm_params=dict(temperature=0.5)) +# result = lmp_with_docstring_system_prompt("input", api_params=dict(temperature=0.5)) # mock_run_lm.assert_called_once_with( # model="gpt-4-turbo", @@ -74,7 +74,7 @@ # @patch("ell.util.lm._run_lm") # def test_lm_decorator_with_msg_fmt_system_prompt(mock_run_lm): # mock_run_lm.return_value = ("Mocked content from msg fmt", None) -# result = lmp_with_message_fmt("input", lm_params=dict(temperature=0.5)) +# result = lmp_with_message_fmt("input", api_params=dict(temperature=0.5)) # mock_run_lm.assert_called_once_with( # model="gpt-4-turbo", From c9a293d91052bc6567f98e0b4835ed3fdeb371c6 Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 17:54:21 -0700 Subject: [PATCH 11/17] deprecate lm_params fuly --- src/ell/lmp/complex.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ell/lmp/complex.py b/src/ell/lmp/complex.py index 8d5c4771..2d2c2bfd 100644 --- a/src/ell/lmp/complex.py +++ b/src/ell/lmp/complex.py @@ -29,12 +29,17 @@ def model_call( _invocation_origin : Optional[str] = None, client: Optional[Any] = None, api_params: Optional[Dict[str, Any]] = None, + lm_params: Optional[DeprecationWarning] = None, **prompt_kwargs, ) -> Tuple[Any, Any, Any]: + # XXX: Deprecation in 0.1.0 + if lm_params: + raise DeprecationWarning("lm_params is deprecated. 
Use api_params instead.") + # promt -> str res = prompt(*prompt_args, **prompt_kwargs) # Convert prompt into ell messages - messages = _get_messages(res, prompt) + messages = _get_messages(res, prompt) # XXX: move should log to a logger. should_log = not exempt_from_tracking and config.verbose @@ -64,6 +69,7 @@ def model_call( (result, final_api_params, metadata) = provider.call(ell_call, origin_id=_invocation_origin, logger=_logger) if isinstance(result, list) and len(result) == 1: result = result[0] + result = post_callback(result) if post_callback else result if should_log: model_usage_logger_post_end() From 63ce6e4581802c255136fe3ee11c76c7a3e2a5b6 Mon Sep 17 00:00:00 2001 From: William Guss Date: Fri, 20 Sep 2024 18:25:34 -0700 Subject: [PATCH 12/17] anthropic --- examples/claude.py | 2 +- src/ell/provider.py | 8 +- src/ell/providers/__init__.py | 2 +- src/ell/providers/anthropic.py | 135 ++++++++++----------------------- src/ell/providers/openai.py | 6 +- 5 files changed, 50 insertions(+), 103 deletions(-) diff --git a/examples/claude.py b/examples/claude.py index 28e961f2..a15a4267 100644 --- a/examples/claude.py +++ b/examples/claude.py @@ -8,5 +8,5 @@ def hello_from_claude(): if __name__ == "__main__": ell.init(verbose=True, store="./logdir", autocommit=True) - hello_from_claude() + print(hello_from_claude()) diff --git a/src/ell/provider.py b/src/ell/provider.py index d33bebc8..20dab7d1 100644 --- a/src/ell/provider.py +++ b/src/ell/provider.py @@ -61,7 +61,7 @@ class Provider(ABC): ################################ @abstractmethod def provider_call_function( - self, api_call_params: Optional[Dict[str, Any]] = None + self, client: Any, api_call_params: Optional[Dict[str, Any]] = None ) -> Callable[..., Any]: """ Implement this method to return the function that makes the API call to the language model. 
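# A sketch of what an implementer might return from provider_call_function, for a
# hypothetical SDK whose client exposes a single `client.generate(...)` method; the
# other abstract methods (translate_to_provider, translate_from_provider,
# disallowed_api_params) are omitted for brevity.
from typing import Any, Callable, Dict, Optional

from ell.provider import Provider


class MyBackendProvider(Provider):
    def provider_call_function(
        self, client: Any, api_call_params: Optional[Dict[str, Any]] = None
    ) -> Callable[..., Any]:
        # The returned callable is later invoked as call(**final_api_call_params),
        # so its keyword signature is also what available_api_params() is derived from.
        return client.generate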
@@ -75,8 +75,8 @@ def disallowed_api_params(self) -> FrozenSet[str]: """ return frozenset({"messages", "tools", "model", "stream", "stream_options"}) - def available_api_params(self, api_params: Optional[Dict[str, Any]] = None): - params = _call_params(self.provider_call_function(api_params)) + def available_api_params(self, client: Any, api_params: Optional[Dict[str, Any]] = None): + params = _call_params(self.provider_call_function(client, api_params)) return frozenset(params.keys()) - self.disallowed_api_params() ################################ @@ -116,7 +116,7 @@ def call( final_api_call_params = self.translate_to_provider(ell_call) - call = self.provider_call_function(final_api_call_params) + call = self.provider_call_function(ell_call.client, final_api_call_params) assert self.dangerous_disable_validation or _validate_provider_call_params(final_api_call_params, call) diff --git a/src/ell/providers/__init__.py b/src/ell/providers/__init__.py index 90ec14c3..763dfc07 100644 --- a/src/ell/providers/__init__.py +++ b/src/ell/providers/__init__.py @@ -1,5 +1,5 @@ import ell.providers.openai -# import ell.providers.anthropic +import ell.providers.anthropic # import ell.providers.groq # import ell.providers.mistral # import ell.providers.cohere diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index 214b6d25..975b76c5 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -1,5 +1,5 @@ -from typing import Any, Dict, List, Optional, Tuple, Type -from ell.provider import APICallResult, Provider +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr from ell.types.message import LMP @@ -12,21 +12,19 @@ try: import anthropic from anthropic import Anthropic + from anthropic.types import Message as AnthropicMessage, MessageCreateParams, RawMessageStreamEvent + from anthropic._streaming import Stream class AnthropicProvider(Provider): - @classmethod - def call( - cls, - client: Anthropic, - model: str, - messages: List[Message], - api_params: Dict[str, Any], - tools: Optional[list[LMP]] = None, - ) -> APICallResult: - final_call_params = api_params.copy() + + def provider_call_function(self, client : Anthropic, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: + return client.messages.create + + def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: + final_call_params = ell_call.api_params.copy() assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))." - anthropic_messages = [message_to_anthropic_format(message) for message in messages] + anthropic_messages = [message_to_anthropic_format(message) for message in ell_call.messages] system_message = None if anthropic_messages and anthropic_messages[0]["role"] == "system": system_message = anthropic_messages.pop(0) @@ -34,49 +32,43 @@ def call( if system_message: final_call_params["system"] = system_message["content"][0]["text"] - actual_n = api_params.get("n", 1) - final_call_params["model"] = model + # XXX: untils streaming is implemented. 
+ final_call_params['stream'] = True + + final_call_params["model"] = ell_call.model final_call_params["messages"] = anthropic_messages - if tools: + if ell_call.tools: final_call_params["tools"] = [ { "name": tool.__name__, "description": tool.__doc__, "input_schema": tool.__ell_params_model__.model_json_schema(), } - for tool in tools + for tool in ell_call.tools ] - # Streaming unsupported. - # XXX: Support soon. - stream = True - if stream: - response = client.messages.stream(**final_call_params) - else: - response = client.messages.create(**final_call_params) - - return APICallResult( - response=response, - actual_streaming=stream, - actual_n=actual_n, - final_call_params=final_call_params, - ) - - @classmethod - def process_response( - cls, call_result: APICallResult, _invocation_origin: str, logger: Optional[Any] = None, tools: Optional[List[LMP]] = None, - ) -> Tuple[List[Message], Dict[str, Any]]: + return final_call_params + + def translate_from_provider( + self, + provider_response : Union[Stream[RawMessageStreamEvent], AnthropicMessage], + ell_call: EllCallParams, + provider_call_params: Dict[str, Any], + origin_id: Optional[str] = None, + logger: Optional[Callable[..., None]] = None, + ) -> Tuple[List[Message], Metadata]: + usage = {} tracked_results = [] metadata = {} - if call_result.actual_streaming: + if provider_call_params.get("stream", False): content = [] current_block: Optional[Dict[str, Any]] = None message_metadata = {} - with call_result.response as stream: + with cast(Stream[RawMessageStreamEvent], provider_response) as stream: for chunk in stream: if chunk.type == "message_start": message_metadata = chunk.message.dict() @@ -90,35 +82,26 @@ def process_response( if current_block is not None: if current_block["type"] == "text": current_block["content"] += chunk.delta.text + logger(chunk.delta.text) elif chunk.type == "content_block_stop": if current_block is not None: if current_block["type"] == "text": - content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=_invocation_origin))) + content.append(ContentBlock(text=_lstr(current_block["content"],origin_trace=origin_id))) elif current_block["type"] == "tool_use": try: final_cb = chunk.content_block - matching_tool = next( - ( - tool - for tool in tools - if tool.__name__ == final_cb.name - ), - None, - ) + matching_tool = ell_call.get_tool_by_name(final_cb.name) if matching_tool: - params = matching_tool.__ell_params_model__( - **final_cb.input - ) content.append( ContentBlock( tool_call=ToolCall( tool=matching_tool, tool_call_id=_lstr( - final_cb.id,origin_trace=_invocation_origin + final_cb.id,origin_trace=origin_id ), - params=params, + params=final_cb.input, ) ) ) @@ -139,35 +122,8 @@ def process_response( elif chunk.type == "message_stop": tracked_results.append(Message(role="assistant", content=content)) - if logger and current_block: - if chunk.type == "text" and current_block["type"] == "text": - logger(chunk.text) # print(chunk) - - metadata = message_metadata - else: - # Non-streaming response processing (unchanged) - cbs = [] - for content_block in call_result.response.content: - if content_block.type == "text": - cbs.append(ContentBlock(text=_lstr(content_block.text,origin_trace=_invocation_origin))) - elif content_block.type == "tool_use": - assert tools is not None, "Tools were not provided to the model when calling it and yet anthropic returned a tool use." 
- tool_call = ToolCall( - tool=next((t for t in tools if t.__name__ == content_block.name), None) , - tool_call_id=content_block.id, - params=content_block.input - ) - cbs.append(ContentBlock(tool_call=tool_call)) - tracked_results.append(Message(role="assistant", content=cbs)) - if logger: - logger(tracked_results[0].text) - - - usage = call_result.response.usage.dict() if call_result.response.usage else {} - metadata = call_result.response.model_dump() - del metadata["content"] # process metadata for ell # XXX: Unify an ell metadata format for ell studio. @@ -178,28 +134,19 @@ def process_response( metadata["usage"] = usage return tracked_results, metadata - @classmethod - def supports_streaming(cls) -> bool: - return True - - @classmethod - def get_client_type(cls) -> Type: - return Anthropic - @staticmethod - def serialize_image_for_anthropic(img): - buffer = BytesIO() - img.save(buffer, format="PNG") - return base64.b64encode(buffer.getvalue()).decode() - - register_provider(AnthropicProvider) + register_provider(AnthropicProvider(), Anthropic) except ImportError: pass +def serialize_image_for_anthropic(img): + buffer = BytesIO() + img.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode() def content_block_to_anthropic_format(content_block: ContentBlock) -> Dict[str, Any]: if content_block.image: - base64_image = AnthropicProvider.serialize_image_for_anthropic(content_block.image) + base64_image = serialize_image_for_anthropic(content_block.image) return { "type": "image", "source": { diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 064e5e0c..d7b3b6ef 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -18,11 +18,11 @@ class OpenAIProvider(Provider): dangerous_disable_validation = True - def provider_call_function(self, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: + def provider_call_function(self, client : openai.Client, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: if api_call_params and api_call_params.get("response_format"): - return openai.beta.chat.completions.parse + return client.beta.chat.completions.parse else: - return openai.chat.completions.create + return client.chat.completions.create def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: final_call_params = ell_call.api_params.copy() From 88013e35678bd2bd5c1c52697ad4c7ae1c84d37f Mon Sep 17 00:00:00 2001 From: William Guss Date: Sun, 22 Sep 2024 12:11:47 -0700 Subject: [PATCH 13/17] workign anthropic provider --- src/ell/providers/anthropic.py | 176 ++++++++++++++++----------------- tests/test_lstr.py | 4 +- tests/test_openai_provider.py | 1 - 3 files changed, 88 insertions(+), 93 deletions(-) diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index 975b76c5..e57fc5bc 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, cast from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall from ell.types._lstr import _lstr @@ -12,7 +12,8 @@ try: import anthropic from anthropic import Anthropic - from anthropic.types import Message as AnthropicMessage, MessageCreateParams, RawMessageStreamEvent + from anthropic.types import Message as AnthropicMessage, MessageParam, 
RawMessageStreamEvent + from anthropic.types.message_create_params import MessageCreateParamsStreaming from anthropic._streaming import Stream class AnthropicProvider(Provider): @@ -20,31 +21,41 @@ class AnthropicProvider(Provider): def provider_call_function(self, client : Anthropic, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: return client.messages.create - def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: - final_call_params = ell_call.api_params.copy() + def translate_to_provider(self, ell_call : EllCallParams): + final_call_params = cast(MessageCreateParamsStreaming, ell_call.api_params.copy()) + # XXX: Helper, but should be depreicated due to ssot assert final_call_params.get("max_tokens") is not None, f"max_tokens is required for anthropic calls, pass it to the @ell.simple/complex decorator, e.g. @ell.simple(..., max_tokens=your_max_tokens) or pass it to the model directly as a parameter when calling your LMP: your_lmp(..., api_params=({{'max_tokens': your_max_tokens}}))." - anthropic_messages = [message_to_anthropic_format(message) for message in ell_call.messages] + dirty_msgs = [ + MessageParam( + role=cast(Literal["user", "assistant"], message.role), + content=[_content_block_to_anthropic_format(c) for c in message.content]) for message in ell_call.messages] + role_correct_msgs : List[MessageParam] = [] + for msg in dirty_msgs: + if (not len(role_correct_msgs) or role_correct_msgs[-1]['role'] != msg['role']): + role_correct_msgs.append(msg) + else: cast(List, role_correct_msgs[-1]['content']).extend(msg['content']) + system_message = None - if anthropic_messages and anthropic_messages[0]["role"] == "system": - system_message = anthropic_messages.pop(0) + if role_correct_msgs and role_correct_msgs[0]["role"] == "system": + system_message = role_correct_msgs.pop(0) if system_message: final_call_params["system"] = system_message["content"][0]["text"] + - # XXX: untils streaming is implemented. final_call_params['stream'] = True - final_call_params["model"] = ell_call.model - final_call_params["messages"] = anthropic_messages + final_call_params["messages"] = role_correct_msgs if ell_call.tools: final_call_params["tools"] = [ - { - "name": tool.__name__, - "description": tool.__doc__, - "input_schema": tool.__ell_params_model__.model_json_schema(), - } + #XXX: Cleaner with LMP's as a class. 
+ dict( + name=tool.__name__, + description=tool.__doc__, + input_schema=tool.__ell_params_model__.model_json_schema(), + ) for tool in ell_call.tools ] @@ -63,6 +74,8 @@ def translate_from_provider( tracked_results = [] metadata = {} + #XXX: Support n > 0 + if provider_call_params.get("stream", False): content = [] current_block: Optional[Dict[str, Any]] = None @@ -71,53 +84,53 @@ def translate_from_provider( with cast(Stream[RawMessageStreamEvent], provider_response) as stream: for chunk in stream: if chunk.type == "message_start": - message_metadata = chunk.message.dict() + message_metadata = chunk.message.model_dump() message_metadata.pop("content", None) # Remove content as we'll build it separately elif chunk.type == "content_block_start": - current_block = chunk.content_block.dict() - current_block["content"] = "" - + current_block = chunk.content_block.model_dump() + if current_block["type"] == "tool_use": + if logger: logger(f" ") except json.JSONDecodeError: - # Handle partial JSON if necessary + if logger: logger(f" - FAILED TO PARSE JSON") pass + if logger: logger(f")>") + current_block = None elif chunk.type == "message_delta": - message_metadata.update(chunk.delta.dict()) + message_metadata.update(chunk.delta.model_dump()) if chunk.usage: - usage.update(chunk.usage.dict()) + usage.update(chunk.usage.model_dump()) elif chunk.type == "message_stop": tracked_results.append(Message(role="assistant", content=content)) @@ -134,62 +147,45 @@ def translate_from_provider( metadata["usage"] = usage return tracked_results, metadata - - register_provider(AnthropicProvider(), Anthropic) + # XXX: Could register a true base class. + register_provider(AnthropicProvider(), anthropic.Anthropic) + register_provider(AnthropicProvider(), anthropic.AnthropicBedrock) + register_provider(AnthropicProvider(), anthropic.AnthropicVertex) + except ImportError: + raise pass def serialize_image_for_anthropic(img): - buffer = BytesIO() - img.save(buffer, format="PNG") - return base64.b64encode(buffer.getvalue()).decode() - -def content_block_to_anthropic_format(content_block: ContentBlock) -> Dict[str, Any]: - if content_block.image: - base64_image = serialize_image_for_anthropic(content_block.image) - return { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": base64_image - } - } - elif content_block.text: - return { - "type": "text", - "text": content_block.text - } - elif content_block.parsed: - return { - "type": "text", - "text": json.dumps(content_block.parsed.model_dump()) - } - elif content_block.tool_call: - - return { - "type": "tool_use", - "id": content_block.tool_call.tool_call_id, - "name": content_block.tool_call.tool.__name__, - "input": content_block.tool_call.params.model_dump() - } - elif content_block.tool_result: - return { - "type": "tool_result", - "tool_use_id": content_block.tool_result.tool_call_id, - "content": [content_block_to_anthropic_format(c) for c in content_block.tool_result.result] - } - else: - raise ValueError("Content block is not supported by anthropic") - - - -def message_to_anthropic_format(message: Message) -> Dict[str, Any]: - - anthropic_message = { - "role": message.role, - "content": list(filter(None, [ - content_block_to_anthropic_format(c) for c in message.content - ])) - } - return anthropic_message \ No newline at end of file + buffer = BytesIO() + img.save(buffer, format="PNG") + base64_image = base64.b64encode(buffer.getvalue()).decode() + return dict( + type="image", + source=dict( + type="base64", + 
media_type="image/png", + data=base64_image + ) + ) + +def _content_block_to_anthropic_format(content_block: ContentBlock): + if (image := content_block.image): return serialize_image_for_anthropic(image) + elif (text := content_block.text): return dict(type="text", text=text) + elif (parsed := content_block.parsed): + return dict(type="text", text=json.dumps(parsed.model_dump())) + elif (tool_call := content_block.tool_call): + return dict( + type="tool_use", + id=tool_call.tool_call_id, + name=tool_call.tool.__name__, + input=tool_call.params.model_dump() + ) + elif (tool_result := content_block.tool_result): + return dict( + type="tool_result", + tool_use_id=tool_result.tool_call_id, + content=[_content_block_to_anthropic_format(c) for c in tool_result.result] + ) + else: + raise ValueError("Content block is not supported by anthropic") \ No newline at end of file diff --git a/tests/test_lstr.py b/tests/test_lstr.py index 9e8d6209..b3dfeece 100644 --- a/tests/test_lstr.py +++ b/tests/test_lstr.py @@ -13,11 +13,11 @@ def test_init(self): # Test initialization with logits andorigin_trace # logits = np.array([0.1, 0.2]) - origin_trace = "model1" + origin_trace = "model1" s = _lstr("world",origin_trace=origin_trace) # Removed logits parameter assert str(s) == "world" # assert np.array_equal(s.logits, logits) - assert s.origin_trace == frozenset({_origin_trace}) + assert s.origin_trace == frozenset({origin_trace}) def test_add(self): s1 = _lstr("hello") diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index 2938b445..5822987c 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -1,6 +1,5 @@ import pytest from unittest.mock import Mock, patch -from ell.provider import APICallResult from ell.providers.openai import OpenAIProvider from ell.types import Message, ContentBlock, ToolCall from ell.types.message import LMP, ToolResult From 6a30be1251d975ba35650e4994a03beca08555ec Mon Sep 17 00:00:00 2001 From: William Guss Date: Sun, 22 Sep 2024 14:27:35 -0700 Subject: [PATCH 14/17] better tests. 
--- src/ell/providers/anthropic.py | 1 + src/ell/providers/openai.py | 1 + src/ell/types/message.py | 3 +- tests/test_openai_provider.py | 532 ++++++++++++++++++++++++++------- 4 files changed, 424 insertions(+), 113 deletions(-) diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index e57fc5bc..4d99fa8c 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -17,6 +17,7 @@ from anthropic._streaming import Stream class AnthropicProvider(Provider): + dangerous_disable_validation = True def provider_call_function(self, client : Anthropic, api_call_params : Optional[Dict[str, Any]] = None) -> Callable[..., Any]: return client.messages.create diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index d7b3b6ef..c79b322d 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -126,6 +126,7 @@ def translate_from_provider( chat_completion = cast(Union[ChatCompletion, ParsedChatCompletion], provider_response) metadata = chat_completion.model_dump(exclude={"choices"}) for oai_choice in chat_completion.choices: + role = oai_choice.message.role content_blocks = [] if (refusal := (message := oai_choice.message).refusal): raise ValueError(refusal) diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 0c69d41b..68e0a90d 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -21,7 +21,6 @@ class ToolResult(BaseModel): tool_call_id: _lstr_generic - #XXX: Add a validator to check that the result is a list of ContentBlocks. result: List["ContentBlock"] class ToolCall(BaseModel): @@ -29,7 +28,7 @@ class ToolCall(BaseModel): tool_call_id : Optional[_lstr_generic] = Field(default=None) params : BaseModel - def __init__(self, tool, tool_call_id, params : Union[BaseModel, Dict[str, Any]]): + def __init__(self, tool, params : Union[BaseModel, Dict[str, Any]], tool_call_id=None): if not isinstance(params, BaseModel): params = tool.__ell_params_model__(**params) #convenience. 
super().__init__(tool=tool, tool_call_id=tool_call_id, params=params) diff --git a/tests/test_openai_provider.py b/tests/test_openai_provider.py index 5822987c..2ecd01be 100644 --- a/tests/test_openai_provider.py +++ b/tests/test_openai_provider.py @@ -1,128 +1,438 @@ +import pydantic import pytest -from unittest.mock import Mock, patch -from ell.providers.openai import OpenAIProvider +from unittest.mock import MagicMock, patch +from ell.providers.openai import OpenAIProvider, _content_block_to_openai_format +from ell.provider import EllCallParams from ell.types import Message, ContentBlock, ToolCall -from ell.types.message import LMP, ToolResult -from pydantic import BaseModel -import json -import ell -class DummyParams(BaseModel): - param1: str - param2: int +from openai import Client +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessageParam, + ParsedChatCompletion, + ChatCompletionChunk, +) +from openai._streaming import Stream + + + +@pytest.fixture +def provider(): + return OpenAIProvider() + + +@pytest.fixture +def ell_call_params(openai_client): + return EllCallParams( + client=openai_client, # Added the required 'client' field + api_params={}, + model="gpt-4", + messages=[], + tools=[], + ) + + +@pytest.fixture +def openai_client(): + client = MagicMock(spec=Client) + + # Configure 'beta.chat.completions.parse' + client.beta = MagicMock() + client.beta.chat = MagicMock() + client.beta.chat.completions = MagicMock() + client.beta.chat.completions.parse = MagicMock() + + # Configure 'chat.completions.create' + client.chat = MagicMock() + client.chat.completions = MagicMock() + client.chat.completions.create = MagicMock() + + return client + @pytest.fixture -def mock_openai_client(): - return Mock() -import openai +def mock_tool(): + mock = MagicMock() + mock.__name__ = "mock_tool" + mock.__doc__ = "A mock tool" + # Define the __ell_params_model__ attribute with a mock + params_model = pydantic.create_model("MyModel", param1=(str, "...")) + mock.__ell_params_model__ = params_model + return mock + + +class TestOpenAIProvider: + def test_provider_call_function_with_response_format( + self, provider, openai_client, ell_call_params + ): + api_call_params = {"response_format": "parsed"} + func = provider.provider_call_function(openai_client, api_call_params) + assert func == openai_client.beta.chat.completions.parse + + def test_provider_call_function_without_response_format( + self, provider, openai_client, ell_call_params + ): + api_call_params = {} + func = provider.provider_call_function(openai_client, api_call_params) + assert func == openai_client.chat.completions.create + + def test_translate_to_provider_streaming_enabled(self, provider, ell_call_params): + ell_call_params.api_params = {"some_param": "value"} + ell_call_params.tools = [] + ell_call_params.messages = [ + Message(role="user", content=[ContentBlock(text="Hello")]) + ] + + translated = provider.translate_to_provider(ell_call_params) + assert translated["model"] == "gpt-4" + assert translated["stream"] is True + assert translated["stream_options"] == {"include_usage": True} + assert translated["messages"] == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + def test_translate_to_provider_streaming_disabled_due_to_response_format( + self, provider, ell_call_params + ): + ell_call_params.api_params = {"response_format": "parsed"} + ell_call_params.tools = [] + ell_call_params.messages = [ + Message(role="user", content=[ContentBlock(text="Hello")]) + ] + + translated = 
provider.translate_to_provider(ell_call_params)
+        assert "stream" not in translated
+        assert "stream_options" not in translated
+
+    def test_translate_to_provider_with_tools(
+        self, provider, ell_call_params, mock_tool
+    ):
+        ell_call_params.tools = [mock_tool]
+        ell_call_params.api_params = {}
+        ell_call_params.messages = []
+
+        translated = provider.translate_to_provider(ell_call_params)
+        assert translated["tool_choice"] == "auto"
+        assert translated["tools"] == [
+            {
+                "type": "function",
+                "function": {
+                    "name": "mock_tool",
+                    "description": "A mock tool",
+                    "parameters": {
+                        "properties": {
+                            "param1": {
+                                "default": "...",
+                                "title": "Param1",
+                                "type": "string",
+                            },
+                        },
+                        "title": "MyModel",
+                        "type": "object",
+                    },
+                },
+            }
+        ]
+
+    def test_translate_to_provider_with_tool_calls(
+        self, provider, ell_call_params, mock_tool
+    ):
+        tool_call = ToolCall(
+            tool=mock_tool, tool_call_id="123", params={"param1": "value1"}
+        )
+        message = Message(role="assistant", content=[tool_call])
+        ell_call_params.tools = [mock_tool]
+        ell_call_params.messages = [message]
+
+        translated = provider.translate_to_provider(ell_call_params)
+        assert translated["messages"] == [
+            {
+                "tool_calls": [
+                    {
+                        "id": "123",
+                        "type": "function",
+                        "function": {
+                            "name": "mock_tool",
+                            "arguments": '{"param1": "value1"}',
+                        },
+                    }
+                ],
+                "role": "assistant",
+                "content": None,
+            }
+        ]
+
+    def test_translate_from_provider_streaming(
+        self, provider, ell_call_params, openai_client
+    ):
+        provider_call_params = {"stream": True}
+        stream_chunk = ChatCompletionChunk(
+            id="chatcmpl-123",
+            model="gpt-4",
+            choices=[
+                dict(
+                    index=0,
+                    delta=dict(role="assistant", content="Hello"),
+                )
+            ],
+            created=1234567890,
+            object="chat.completion.chunk",
+            usage=None,
+        )
+        mock_stream = MagicMock(spec=Stream)
+        mock_stream.__iter__.return_value = [stream_chunk]
+
+        with patch("ell.providers.openai.Stream", return_value=mock_stream):
+            messages, metadata = provider.translate_from_provider(
+                provider_response=mock_stream,
+                ell_call=ell_call_params,
+                provider_call_params=provider_call_params,
+            )
+        assert messages == [
+            Message(role="assistant", content=[ContentBlock(text="Hello")])
+        ]
+        assert metadata == {}
+
+    def test_translate_from_provider_non_streaming(self, provider, ell_call_params):
+        provider_call_params = {"stream": False}
+        chat_completion = ChatCompletion(
+            id="chatcmpl-123",
+            model="gpt-4",
+            choices=[
+                dict(
+                    index=0,
+                    message=dict(role="assistant", content="Hello"),
+                    finish_reason="stop",
+                )
+            ],
+            created=1234567890,
+            object="chat.completion",
+        )
+
+        messages, metadata = provider.translate_from_provider(
+            provider_response=chat_completion,
+            ell_call=ell_call_params,
+            provider_call_params=provider_call_params,
+        )
+        assert messages == [
+            Message(role="assistant", content=[ContentBlock(text="Hello")])
+        ]
+
+    def test_translate_from_provider_with_refusal(self, provider, ell_call_params):
+        chat_completion = ChatCompletion(
+            id="chatcmpl-123",
+            model="gpt-4",
+            choices=[
+                dict(
+                    index=0,
+                    message=dict(
+                        role="assistant", content=None, refusal="Refusal message"
+                    ),
+                    finish_reason="stop",
+                )
+            ],
+            created=1234567890,
+            object="chat.completion",
+        )
+
+        with pytest.raises(ValueError) as excinfo:
+            provider.translate_from_provider(
+                provider_response=chat_completion,
+                ell_call=ell_call_params,
+                provider_call_params={"stream": False},
+            )
+        assert "Refusal message" in str(excinfo.value)
+
+        # Additional assertions to improve test coverage
+        assert excinfo.value.args[0] == "Refusal message"
+        assert ell_call_params.client.mock_calls == []  # Ensure no client calls were made
+
+    def test_translate_to_provider_with_multiple_messages(
+        self, provider, ell_call_params
+    ):
+        ell_call_params.messages = [
+            Message(role="user", content=[ContentBlock(text="Hello")]),
+            Message(role="assistant", content=[ContentBlock(text="Hi there!")]),
+        ]
+
+        translated = provider.translate_to_provider(ell_call_params)
+        assert translated["messages"] == [
+            {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
+        ]
+
+    def test_translate_to_provider(self, provider, ell_call_params):
+        import numpy as np
+
+        image_block = ContentBlock(
+            image=np.random.rand(100, 100, 3), image_detail="detail"
+        )  # Truncated valid base64
+        openai_format = _content_block_to_openai_format(image_block)
+        assert isinstance(openai_format, dict)
+        assert "type" in openai_format
+        assert openai_format["type"] == "image_url"
+        assert "image_url" in openai_format
+        assert isinstance(openai_format["image_url"], dict)
+        assert "url" in openai_format["image_url"]
+        assert isinstance(openai_format["image_url"]["url"], str)
+        assert openai_format["image_url"]["url"].startswith("data:image/png;base64,")
+        assert "detail" in openai_format["image_url"]
+        assert openai_format["image_url"]["detail"] == "detail"
+
+    def test_translate_to_provider_with_parsed_message(self, provider, ell_call_params):
+        model = pydantic.create_model("MyModel", field=(str, "..."))
+        parsed_block = ContentBlock(parsed=model(field="value"))
+        ell_call_params.messages = [Message(role="user", content=[parsed_block])]
+
+        translated = provider.translate_to_provider(ell_call_params)
+        assert translated["messages"] == [
+            {"role": "user", "content": [{"type": "text", "text": '{"field":"value"}'}]}
+        ]
+
+    def test_translate_from_provider_with_usage_metadata(
+        self, provider, ell_call_params
+    ):
+        chunk_with_usage = ChatCompletionChunk(
+            id="chunk_123",
+            created=1612288000,
+            object="chat.completion.chunk",
+            model="gpt-4",
+            choices=[
+                dict(
+                    index=0, delta=dict(role="assistant", content="Hello")
+                )  # Added index=0
+            ],
+            usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
+        )
+        mock_stream = MagicMock(spec=Stream)
+        mock_stream.__iter__.return_value = [chunk_with_usage]
+
+        messages, metadata = provider.translate_from_provider(
+            provider_response=mock_stream,
+            ell_call=ell_call_params,
+            provider_call_params={"stream": True},
+        )
+        assert messages == [
+            Message(role="assistant", content=[ContentBlock(text="Hello")])
+        ]
+        assert (
+            "prompt_tokens" in metadata["usage"]
+            and metadata["usage"]["prompt_tokens"] == 10
+        )
+        assert (
+            "completion_tokens" in metadata["usage"]
+            and metadata["usage"]["completion_tokens"] == 5
+        )
+        assert (
+            "total_tokens" in metadata["usage"]
+            and metadata["usage"]["total_tokens"] == 15
+        )
+
+    def test_translate_from_provider_with_multiple_chunks(
+        self, provider, ell_call_params
+    ):
+        chunk1 = ChatCompletionChunk(
+            id="chunk_1",
+            object="chat.completion.chunk",
+            created=1234567890,
+            model="gpt-4",
+            choices=[dict(index=0, delta=dict(role="assistant", content="Hello"))],
+        )
+        chunk2 = ChatCompletionChunk(
+            id="chunk_2",
+            object="chat.completion.chunk",
+            created=1234567891,
+            model="gpt-4",
+            choices=[dict(index=0, delta=dict(content=" World"))],
+        )
+        mock_stream = MagicMock(spec=Stream)
+        mock_stream.__iter__.return_value = [chunk1, chunk2]
+
+        messages, metadata = provider.translate_from_provider(
+            provider_response=mock_stream,
+            ell_call=ell_call_params,
+            provider_call_params={"stream": True},
+        )
+        assert messages == [
+            Message(role="assistant", content=[ContentBlock(text="Hello World")])
+        ]
+        assert metadata == {}
+
+
+# Suggested Test for _content_block_to_openai_format
 def test_content_block_to_openai_format():
+    from ell.providers.openai import _content_block_to_openai_format
+    from ell.types import ContentBlock
+    from ell.util.serialization import serialize_image
+    from PIL import Image
+    import numpy as np
 
     # Test text content
-    text_block = ContentBlock(text="Hello, world!")
-    assert OpenAIProvider.content_block_to_openai_format(text_block) == {
-        "type": "text",
-        "text": "Hello, world!"
-    }
+    text_block = ContentBlock(text="Hello World")
+    expected_text = {"type": "text", "text": "Hello World"}
+    assert _content_block_to_openai_format(text_block) == expected_text
 
     # Test parsed content
-    class DummyParsed(BaseModel):
+    class ParsedModel(pydantic.BaseModel):
         field: str
-    parsed_block = ContentBlock(parsed=DummyParsed(field="value"))
-
-    res = OpenAIProvider.content_block_to_openai_format(parsed_block)
-    assert res["type"] == "text"
-    assert (res["text"]) == '{"field":"value"}'
-
+    parsed_block = ContentBlock(parsed=ParsedModel(field="value"))
+    expected_parsed = {"type": "text", "text": '{"field":"value"}'}
+    assert _content_block_to_openai_format(parsed_block) == expected_parsed
 
-    # Test image content (mocked)
-    with patch('ell.providers.openai.serialize_image', return_value="base64_image_data"):
-        # Test random image content
-        import numpy as np
-        from PIL import Image
-
-        # Generate a random image
-        random_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
-        pil_image = Image.fromarray(random_image)
-
-        with patch('ell.providers.openai.serialize_image', return_value="random_base64_image_data"):
-            random_image_block = ContentBlock(image=pil_image)
-            assert OpenAIProvider.content_block_to_openai_format(random_image_block) == {
-                "type": "image_url",
-                "image_url": {
-                    "url": "random_base64_image_data"
-                }
-            }
-
-
-def test_message_to_openai_format():
-    # Test simple message
-    simple_message = Message(role="user", content=[ContentBlock(text="Hello")])
-    assert OpenAIProvider.message_to_openai_format(simple_message) == {
-        "role": "user",
-        "content": [{"type": "text", "text": "Hello"}]
+    # Test image content with image_detail
+    img = Image.new('RGB', (100, 100))
+    serialized_img = serialize_image(img)
+    image_block = ContentBlock(image=img, image_detail="Sample Image")
+    expected_image = {
+        "type": "image_url",
+        "image_url": {
+            "url": serialized_img,
+            "detail": "Sample Image"
+        }
     }
+    assert _content_block_to_openai_format(image_block) == expected_image
 
-    # Test message with tool calls
-    def dummy_tool(param1: str, param2: int): pass
-    tool_call = ToolCall(tool=dummy_tool, tool_call_id="123", params=DummyParams(param1="test", param2=42))
-    tool_message = Message(role="assistant", content=[tool_call])
-    formatted = OpenAIProvider.message_to_openai_format(tool_message)
-    assert formatted["role"] == "assistant"
-    assert formatted["content"] is None
-    assert len(formatted["tool_calls"]) == 1
-    assert formatted["tool_calls"][0]["id"] == "123"
-    assert formatted["tool_calls"][0]["function"]["name"] == "dummy_tool"
-    assert json.loads(formatted["tool_calls"][0]["function"]["arguments"]) == {"param1": "test", "param2": 42}
-
-    # Test message with tool results
-    tool_result_message = Message(
-        role="user",
-        content=[ToolResult(tool_call_id="123", result=[ContentBlock(text="Tool output")])],
-    )
-    formatted = OpenAIProvider.message_to_openai_format(tool_result_message)
-    assert formatted["role"] == "tool"
-    assert formatted["tool_call_id"] == "123"
-    assert formatted["content"] == "Tool output"
-
-def test_call_model(mock_openai_client):
-    messages = [Message(role="user", content=[ContentBlock(text="Hello")], refusal=None)]
-    api_params = {"temperature": 0.7}
-
-    # Mock the client's chat.completions.create method
-    mock_openai_client.chat.completions.create.return_value = Mock(choices=[Mock(message=Mock(content="Response", refusal=None))])
-
-    @ell.tool()
-    def dummy_tool(param1: str, param2: int): pass
-
-    result = OpenAIProvider.call(mock_openai_client, "gpt-3.5-turbo", messages, api_params, tools=[dummy_tool])
-
-    assert isinstance(result, APICallResult)
-    assert not "stream" in result.final_call_params
-    assert not result.actual_streaming
-    assert result.actual_n == 1
-    assert "messages" in result.final_call_params
-    assert result.final_call_params["model"] == "gpt-3.5-turbo"
-
-
-def test_process_response():
-    # Mock APICallResult
-    mock_response = Mock(
-        choices=[Mock(message=Mock(role="assistant", content="Hello, world!", refusal=None, tool_calls=None))]
-    )
-    call_result = APICallResult(
-        response=mock_response,
-        actual_streaming=False,
-        actual_n=1,
-        final_call_params={}
-    )
+    # Test image content without image_detail
+    image_block_no_detail = ContentBlock(image=img)
+    expected_image_no_detail = {
+        "type": "image_url",
+        "image_url": {
+            "url": serialized_img
+        }
+    }
+    assert _content_block_to_openai_format(image_block_no_detail) == expected_image_no_detail
 
-    processed_messages, metadata = OpenAIProvider.process_response(call_result, "test_origin")
+    # Test unsupported content type
+    with pytest.raises(ValueError):
+        _content_block_to_openai_format(ContentBlock(audio=[0.1, 0.2]))
 
-    assert len(processed_messages) == 1
-    assert processed_messages[0].role == "assistant"
-    assert len(processed_messages[0].content) == 1
-    assert processed_messages[0].content[0].text == "Hello, world!"
 
+def test_translate_to_provider_no_tools_no_streaming():
-def test_supports_streaming():
-    assert OpenAIProvider.supports_streaming() == True
+    provider = OpenAIProvider()
+    ell_call_params = EllCallParams(
+        client=MagicMock(),
+        api_params={"response_format": "parsed"},
+        model="gpt-4",
+        messages=[Message(role="user", content=[ContentBlock(text="Hello")])],
+        tools=[]
+    )
+
+    translated = provider.translate_to_provider(ell_call_params)
+    assert "stream" not in translated
+    assert "stream_options" not in translated
+    assert translated["response_format"] == "parsed"
+    assert translated["model"] == "gpt-4"
+
+def test_translate_to_provider_with_custom_stream_options():
+    provider = OpenAIProvider()
+    ell_call_params = EllCallParams(
+        client=MagicMock(),
+        api_params={"custom_option": True},
+        model="gpt-4",
+        messages=[Message(role="user", content=[ContentBlock(text="Hello")])],
+        tools=[]
+    )
 
-# Add more tests as needed for other methods and edge cases
+    translated = provider.translate_to_provider(ell_call_params)
+    assert translated["custom_option"] is True
+    assert translated["stream"] is True
+    assert translated["stream_options"] == {"include_usage": True}
\ No newline at end of file

From baaaf7153e5744349138d289bbd19cfb2971648a Mon Sep 17 00:00:00 2001
From: William Guss
Date: Sun, 22 Sep 2024 14:40:03 -0700
Subject: [PATCH 15/17] update openai

---
 poetry.lock | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 32f7ff3a..3cf4128c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -907,13 +907,13 @@ files = [
 
 [[package]]
 name = "openai"
-version = "1.42.0"
+version = "1.47.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.42.0-py3-none-any.whl", hash = "sha256:dc91e0307033a4f94931e5d03cc3b29b9717014ad5e73f9f2051b6cb5eda4d80"},
-    {file = "openai-1.42.0.tar.gz", hash = "sha256:c9d31853b4e0bc2dc8bd08003b462a006035655a701471695d0bfdc08529cde3"},
+    {file = "openai-1.47.0-py3-none-any.whl", hash = "sha256:9ccc8737dfa791f7bd903db4758c176b8544a8cd89d3a3d2add3cea02a34c3a0"},
+    {file = "openai-1.47.0.tar.gz", hash = "sha256:6e14d6f77c8cf546646afcd87a2ef752505b3710d2564a2e433e17307dfa86a0"},
 ]
 
 [package.dependencies]

From 5a6f997c9cd5b5ca01f0b67509c4a18b179817ef Mon Sep 17 00:00:00 2001
From: William Guss
Date: Sun, 22 Sep 2024 14:46:39 -0700
Subject: [PATCH 16/17] test dependencies

---
 .github/workflows/pytest.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 45d0af77..2704a5b6 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -29,6 +29,7 @@ jobs:
       - name: Install dependencies
         run: |
           poetry install
+          poetry run pip install anthropic
 
       - name: Run pytest
         run: |

From 7463386d0f46bff0f1e9fb8357314fa747d17c26 Mon Sep 17 00:00:00 2001
From: William Guss
Date: Sun, 22 Sep 2024 14:48:16 -0700
Subject: [PATCH 17/17] comments:

---
 src/ell/provider.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/ell/provider.py b/src/ell/provider.py
index 20dab7d1..f1fc34b4 100644
--- a/src/ell/provider.py
+++ b/src/ell/provider.py
@@ -105,6 +105,8 @@ def translate_from_provider(
     # Be careful to override this method in your provider.
     def call(
         self,
+        #XXX: In future refactors, we can fully enumerate the args and make ell_call's internal to the _provider implementer interface.
+        # This gives us a litellm style interface for free.
        ell_call: EllCallParams,
+        origin_id: Optional[str] = None,
+        logger: Optional[Any] = None,