From d83ff6ff8cdc9540856ed7213244a3488d5df361 Mon Sep 17 00:00:00 2001
From: hakan458
Date: Tue, 7 Jan 2025 17:23:34 -0800
Subject: [PATCH 1/3] feat: DIA-1715: VertexAI Gemini model support

---
 adala/runtimes/_litellm.py |  60 +++++++++++++----
 server/app.py              | 132 +++++++++++++++++++++++++++++++++++--
 2 files changed, 172 insertions(+), 20 deletions(-)

diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index c682ed3a..d2a02b8b 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -50,6 +50,8 @@
 
 logger = logging.getLogger(__name__)
 
+# TODO remove
+litellm.drop_params = True
 
 # basically only retrying on timeout, incomplete output, or rate limit
 # https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
@@ -153,6 +155,22 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
     return data
 
 
+def resolve_litellm_model_and_provider(model_name: str, provider: str):
+    """
+    When using litellm.get_model_info() some models are accessed with their provider prefix
+    while others are not.
+
+    This helper function contains logic which resolves this for supported providers
+    """
+    if "/" in model_name:  # TODO handle models like vertex_ai/meta/llama ...
+        model_name = model_name.split("/")[1]
+    provider = provider.lower()
+    if provider == "vertexai":
+        provider = "vertex_ai"
+
+    return model_name, provider
+
+
 class InstructorClientMixin:
     def _from_litellm(self, **kwargs):
         return instructor.from_litellm(litellm.completion, **kwargs)
@@ -160,7 +178,7 @@ def _from_litellm(self, **kwargs):
     @cached_property
     def client(self):
         kwargs = {}
-        if self.is_custom_openai_endpoint:
+        if self.is_custom_openai_endpoint or self.model.startswith("vertex"):
             kwargs["mode"] = instructor.Mode.JSON
         return self._from_litellm(**kwargs)
 
@@ -402,7 +420,7 @@ def init_runtime(self) -> "Runtime":
             model=self.model,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
-            seed=self.seed,
+            # seed=self.seed,
             # extra inference params passed to this runtime
             **self.model_extra,
         )
@@ -553,9 +571,12 @@ def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int
         return user_tokens + system_tokens
 
     @staticmethod
-    def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
+    def _get_completion_tokens(
+        model: str, output_fields: Optional[List[str]], provider: str
+    ) -> int:
+        model, provider = resolve_litellm_model_and_provider(model, provider)
         max_tokens = litellm.get_model_info(
-            model=model, custom_llm_provider="openai"
+            model=model, custom_llm_provider=provider
         ).get("max_tokens", None)
         if not max_tokens:
             raise ValueError
@@ -565,10 +586,14 @@ def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> in
 
     @classmethod
     def _estimate_cost(
-        cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
+        cls,
+        user_prompt: str,
+        model: str,
+        output_fields: Optional[List[str]],
+        provider: str,
     ):
         prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
-        completion_tokens = cls._get_completion_tokens(model, output_fields)
+        completion_tokens = cls._get_completion_tokens(model, output_fields, provider)
         prompt_cost, completion_cost = litellm.cost_per_token(
             model=model,
             prompt_tokens=prompt_tokens,
@@ -579,7 +604,11 @@ def _estimate_cost(
         return prompt_cost, completion_cost, total_cost
 
     def get_cost_estimate(
-        self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
+        self,
+        prompt: str,
+        substitutions: List[Dict],
+        output_fields: Optional[List[str]],
+        provider: str,
     ) -> CostEstimate:
         try:
             user_prompts = [
@@ -594,6 +623,7 @@ def get_cost_estimate(
                     user_prompt=user_prompt,
                     model=self.model,
                     output_fields=output_fields,
+                    provider=provider,
                 )
                 cumulative_prompt_cost += prompt_cost
                 cumulative_completion_cost += completion_cost
@@ -729,8 +759,9 @@ class AsyncLiteLLMVisionRuntime(AsyncLiteLLMChatRuntime):
 
     def init_runtime(self) -> "Runtime":
         super().init_runtime()
-        if not litellm.supports_vision(self.model):
-            raise ValueError(f"Model {self.model} does not support vision")
+        # model_name = self.model
+        # if not litellm.supports_vision(model_name):
+        #     raise ValueError(f"Model {self.model} does not support vision")
         return self
 
     async def batch_to_batch(
@@ -816,7 +847,10 @@ async def batch_to_batch(
 
 
 # TODO: cost estimate
-def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:
+
+def get_model_info(
+    provider: str, model_name: str, auth_info: Optional[dict] = None
+) -> dict:
     if auth_info is None:
         auth_info = {}
     try:
@@ -826,11 +860,11 @@ def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=Non
             model=f"azure/{model_name}",
             messages=[{"role": "user", "content": ""}],
             max_tokens=1,
-            **auth_info
+            **auth_info,
         )
         model_name = dummy_completion.model
-        full_name = f"{provider}/{model_name}"
-        return litellm.get_model_info(full_name)
+        model_name, provider = resolve_litellm_model_and_provider(model_name, provider)
+        return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
     except Exception as err:
         logger.error("Hit error when trying to get model metadata: %s", err)
         return {}
diff --git a/server/app.py b/server/app.py
index ad21e3d0..e539ed98 100644
--- a/server/app.py
+++ b/server/app.py
@@ -16,6 +16,8 @@
 from fastapi import HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 import litellm
+from litellm.exceptions import AuthenticationError
+from litellm.utils import check_valid_key, get_valid_models
 from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
 from redis import Redis
 import time
@@ -37,7 +39,6 @@
 
 logger = init_logger(__name__)
 
-
 settings = Settings()
 
 app = fastapi.FastAPI()
@@ -83,10 +84,29 @@ class BatchSubmitted(BaseModel):
     job_id: str
 
 
+class ModelsListRequest(BaseModel):
+    provider: str
+
+
+class ModelsListResponse(BaseModel):
+    models_list: List[str]
+
+
 class CostEstimateRequest(BaseModel):
     agent: Agent
     prompt: str
     substitutions: List[Dict]
+    provider: str
+
+
+class ValidateConnectionRequest(BaseModel):
+    provider: str
+    api_key: Optional[str] = None
+    vertex_credentials: Optional[str] = None
+    api_version: Optional[str] = None
+    deployment_name: Optional[str] = None
+    endpoint: Optional[str] = None
+    auth_token: Optional[str] = None
 
 
 class Status(Enum):
@@ -216,6 +236,94 @@ async def submit_batch(batch: BatchData):
     return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))
 
 
+@app.post("/validate-connection", response_model=Response)
+async def validate_connection(request: ValidateConnectionRequest):
+    multi_model_provider_models = {
+        "openai": "gpt-4o-mini",
+        "vertexai": "vertex_ai/gemini-1.5-flash",
+    }
+    provider = request.provider.lower()
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+    # For multi-model providers use a model that every account should have access to
+    if provider in multi_model_provider_models.keys():
+        model = multi_model_provider_models[provider]
+        if provider == "openai":
+            model_extra = {"api_key": request.api_key}
+        elif provider == "vertexai":
model_extra = {"vertex_credentials": request.vertex_credentials} + try: + litellm.completion( + messages=messages, + model=model, + max_tokens=10, + temperature=0.0, + **model_extra, + ) + except AuthenticationError: + raise HTTPException( + status_code=400, + detail=f"Requested model '{model}' is not available with your api_key / credentials", + ) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Error validating credentials for provider {provider}: {e}", + ) + + # For single-model connections use the provided model + else: + if provider.lower() == "azureopenai": + model = "azure/" + request.deployment_name + model_extra = {"base_url": request.endpoint} + elif provider.lower() == "custom": + model = "openai/" + request.deployment_name + model_extra = ( + {"extra_headers": {"Authorization": request.auth_token}} + if request.auth_token + else {} + ) + model_extra["api_key"] = request.api_key + try: + litellm.completion( + messages=messages, + model=model, + max_tokens=1000, + temperature=0.0, + seed=47, + **model_extra, + ) + except AuthenticationError: + raise HTTPException( + status_code=400, + detail=f"Requested model '{model}' is not available with your api_key and settings.", + ) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to check availability of requested model '{model}': {e}", + ) + + return Response(success=True, data=None) + + +@app.post("/models-list", response_model=Response[ModelsListResponse]) +async def models_list(request: ModelsListRequest): + # get_valid_models uses api key set in env, however the list is not dynamically retrieved + # https://docs.litellm.ai/docs/set_keys#get_valid_models + # https://github.com/BerriAI/litellm/blob/b9280528d368aced49cb4d287c57cd0b46168cb6/litellm/utils.py#L5705 + # Ultimately just uses litellm.models_by_provider - setting API key is not needed + lse_provider_to_litellm_provider = {"openai": "openai", "vertexai": "vertex_ai"} + provider = request.provider.lower() + valid_models = litellm.models_by_provider[ + lse_provider_to_litellm_provider[provider] + ] + + return Response[ModelsListResponse]( + data=ModelsListResponse(models_list=valid_models) + ) + + @app.post("/estimate-cost", response_model=Response[CostEstimate]) async def estimate_cost( request: CostEstimateRequest, @@ -238,6 +346,7 @@ async def estimate_cost( prompt = request.prompt substitutions = request.substitutions agent = request.agent + provider = request.provider runtime = agent.get_runtime() try: @@ -247,7 +356,10 @@ async def estimate_cost( list(skill.field_schema.keys()) if skill.field_schema else None ) cost_estimate = runtime.get_cost_estimate( - prompt=prompt, substitutions=substitutions, output_fields=output_fields + prompt=prompt, + substitutions=substitutions, + output_fields=output_fields, + provider=provider, ) cost_estimates.append(cost_estimate) total_cost_estimate = sum( @@ -429,21 +541,27 @@ class ModelMetadataRequestItem(BaseModel): model_name: str auth_info: Optional[Dict[str, str]] = None + class ModelMetadataRequest(BaseModel): models: List[ModelMetadataRequestItem] + class ModelMetadataResponse(BaseModel): model_metadata: Dict[str, Dict] + @app.post("/model-metadata", response_model=Response[ModelMetadataResponse]) async def model_metadata(request: ModelMetadataRequest): from adala.runtimes._litellm import get_model_info - resp = {'model_metadata': {item.model_name: get_model_info(**item.model_dump()) for item in request.models}} - return Response[ModelMetadataResponse]( - success=True, 
-        data=resp
-    )
+    resp = {
+        "model_metadata": {
+            item.model_name: get_model_info(**item.model_dump())
+            for item in request.models
+        }
+    }
+    return Response[ModelMetadataResponse](success=True, data=resp)
+
 
 if __name__ == "__main__":
     # for debugging
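
As a rough illustration (not part of the patches themselves), the snippet below exercises the resolve_litellm_model_and_provider helper added in PATCH 1/3. It assumes the adala package from this branch is importable; the model names are examples only, and the expected tuples follow directly from the function body above.

# Illustrative only: a sanity check of the helper introduced in adala/runtimes/_litellm.py.
from adala.runtimes._litellm import resolve_litellm_model_and_provider

# A prefixed VertexAI model: the "vertex_ai/" prefix is stripped and the
# LSE-style provider name "VertexAI" is normalized to litellm's "vertex_ai".
assert resolve_litellm_model_and_provider("vertex_ai/gemini-1.5-flash", "VertexAI") == (
    "gemini-1.5-flash",
    "vertex_ai",
)

# An unprefixed OpenAI model passes through untouched, with the provider lower-cased.
assert resolve_litellm_model_and_provider("gpt-4o-mini", "OpenAI") == (
    "gpt-4o-mini",
    "openai",
)
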
@app.post("/models-list", response_model=Response[ModelsListResponse])