feat: DIA-1715: VertexAI Gemini model support #298

Merged · 17 commits · Jan 17, 2025
78 changes: 56 additions & 22 deletions adala/runtimes/_litellm.py
@@ -153,20 +153,35 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
return data


class InstructorClientMixin:
def normalize_litellm_model_and_provider(model_name: str, provider: str):
"""
When using litellm.get_model_info() some models are accessed with their provider prefix
while others are not.

This helper function contains logic which normalizes this for supported providers
"""
if "/" in model_name:
model_name = model_name.split('/', 1)[1]
provider = provider.lower()
if provider == "vertexai":
provider = "vertex_ai"

return model_name, provider
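A quick sketch of the helper's behaviour (illustrative only, not part of the diff): prefixed VertexAI model names are stripped and the provider string is normalized to the form litellm.get_model_info() expects.

# Illustrative usage of normalize_litellm_model_and_provider (values are examples):
# normalize_litellm_model_and_provider("vertex_ai/gemini-1.5-flash", "VertexAI")
#   -> ("gemini-1.5-flash", "vertex_ai")
# normalize_litellm_model_and_provider("gpt-4o-mini", "openai")
#   -> ("gpt-4o-mini", "openai")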


class InstructorClientMixin(BaseModel):

# Note: most models work better with json mode; this is set only for backwards compatibility
# instructor_mode: str = "json_mode"
instructor_mode: str = "tool_call"

# Note: this separate function shouldn't be necessary, but combining it directly with @cached_property raises errors
def _from_litellm(self, **kwargs):
return instructor.from_litellm(litellm.completion, **kwargs)

@cached_property
def client(self):
kwargs = {}
if self.is_custom_openai_endpoint:
kwargs["mode"] = instructor.Mode.JSON
return self._from_litellm(**kwargs)

@property
def is_custom_openai_endpoint(self) -> bool:
return self.model.startswith("openai/") and self.model_extra.get("base_url")
return self._from_litellm(mode=instructor.Mode(self.instructor_mode))
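For context, a minimal sketch of how a client built this way is typically used (the Answer model and credentials are assumptions for illustration; instructor.Mode("tool_call") resolves to Mode.TOOLS and "json_mode" to Mode.JSON):

import instructor
import litellm
from pydantic import BaseModel

class Answer(BaseModel):  # hypothetical response model, not part of the PR
    text: str

# Mirrors what the mixin's client property does with the default instructor_mode
client = instructor.from_litellm(litellm.completion, mode=instructor.Mode("tool_call"))
answer = client.chat.completions.create(
    model="vertex_ai/gemini-1.5-flash",
    messages=[{"role": "user", "content": "Say hi"}],
    response_model=Answer,  # instructor parses the tool call into this Pydantic model
)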


class InstructorAsyncClientMixin(InstructorClientMixin):
@@ -241,7 +256,6 @@ class LiteLLMChatRuntime(InstructorClientMixin, Runtime):
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model: str = "gpt-4o-mini"
Expand Down Expand Up @@ -382,7 +396,6 @@ class AsyncLiteLLMChatRuntime(InstructorAsyncClientMixin, AsyncRuntime):
with the provider of your specified model.
base_url (Optional[str]): Base URL, optional. If provided, will be used to talk to an OpenAI-compatible API provider besides OpenAI.
api_version (Optional[str]): API version, optional except for Azure.
timeout: Timeout in seconds.
"""

model: str = "gpt-4o-mini"
@@ -553,9 +566,12 @@ def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int
return user_tokens + system_tokens

@staticmethod
def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
def _get_completion_tokens(
model: str, output_fields: Optional[List[str]], provider: str
) -> int:
model, provider = normalize_litellm_model_and_provider(model, provider)
max_tokens = litellm.get_model_info(
model=model, custom_llm_provider="openai"
model=model, custom_llm_provider=provider
).get("max_tokens", None)
if not max_tokens:
raise ValueError
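The provider normalization matters because litellm.get_model_info() keys some models by their bare name plus custom_llm_provider. A hedged sketch of the lookup performed above (model and provider values are illustrative):

# Illustrative lookup (not part of the diff): after normalization a VertexAI Gemini model resolves as
# litellm.get_model_info(model="gemini-1.5-flash", custom_llm_provider="vertex_ai")
# and the returned dict's "max_tokens" entry is used as the completion-token ceiling.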
@@ -565,10 +581,14 @@ def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int

@classmethod
def _estimate_cost(
cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
cls,
user_prompt: str,
model: str,
output_fields: Optional[List[str]],
provider: str,
):
prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
completion_tokens = cls._get_completion_tokens(model, output_fields)
completion_tokens = cls._get_completion_tokens(model, output_fields, provider)
prompt_cost, completion_cost = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
@@ -579,7 +599,11 @@ def _estimate_cost(
return prompt_cost, completion_cost, total_cost

def get_cost_estimate(
self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
self,
prompt: str,
substitutions: List[Dict],
output_fields: Optional[List[str]],
provider: str,
) -> CostEstimate:
try:
user_prompts = [
@@ -594,6 +618,7 @@ def get_cost_estimate(
user_prompt=user_prompt,
model=self.model,
output_fields=output_fields,
provider=provider,
)
cumulative_prompt_cost += prompt_cost
cumulative_completion_cost += completion_cost
@@ -729,8 +754,12 @@ class AsyncLiteLLMVisionRuntime(AsyncLiteLLMChatRuntime):

def init_runtime(self) -> "Runtime":
super().init_runtime()
if not litellm.supports_vision(self.model):
raise ValueError(f"Model {self.model} does not support vision")
# Only run this supports_vision check for non-Vertex models: it relies on a static JSON file in
# litellm that is not yet up to date. A fix is expected in the next litellm release; revisit this check then.
if not self.model.startswith("vertex_ai"):
model_name = self.model
if not litellm.supports_vision(model_name):
raise ValueError(f"Model {self.model} does not support vision")
return self

async def batch_to_batch(
@@ -816,7 +845,10 @@ async def batch_to_batch(

# TODO: cost estimate

def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:

def get_model_info(
provider: str, model_name: str, auth_info: Optional[dict] = None
) -> dict:
if auth_info is None:
auth_info = {}
try:
@@ -826,11 +858,13 @@ def get_model_info(provider: str, model_name: str, auth_info: Optional[dict]=None) -> dict:
model=f"azure/{model_name}",
messages=[{"role": "user", "content": ""}],
max_tokens=1,
**auth_info
**auth_info,
)
model_name = dummy_completion.model
full_name = f"{provider}/{model_name}"
return litellm.get_model_info(full_name)
model_name, provider = normalize_litellm_model_and_provider(
model_name, provider
)
return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
except Exception as err:
logger.error("Hit error when trying to get model metadata: %s", err)
return {}
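A sketch of how the updated helper behaves for a VertexAI model (values illustrative; the returned dict comes straight from litellm's model map):

# Illustrative call (not part of the diff):
# get_model_info("vertexai", "gemini-1.5-flash")
#   -> {"max_tokens": ..., "input_cost_per_token": ..., "litellm_provider": "vertex_ai", ...}
# The Azure branch above still issues a dummy completion first to resolve the deployment's real model name.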
144 changes: 137 additions & 7 deletions server/app.py
@@ -16,6 +16,8 @@
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import litellm
from litellm.exceptions import AuthenticationError
from litellm.utils import check_valid_key, get_valid_models
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
@@ -37,7 +39,6 @@

logger = init_logger(__name__)


settings = Settings()

app = fastapi.FastAPI()
@@ -83,10 +84,36 @@ class BatchSubmitted(BaseModel):
job_id: str


class ModelsListRequest(BaseModel):
provider: str


class ModelsListResponse(BaseModel):
models_list: List[str]


class CostEstimateRequest(BaseModel):
agent: Agent
prompt: str
substitutions: List[Dict]
provider: str


class ValidateConnectionRequest(BaseModel):
provider: str
api_key: Optional[str] = None
vertex_credentials: Optional[str] = None
vertex_location: Optional[str] = None
vertex_project: Optional[str] = None
api_version: Optional[str] = None
deployment_name: Optional[str] = None
endpoint: Optional[str] = None
auth_token: Optional[str] = None


class ValidateConnectionResponse(BaseModel):
model: str
success: bool


class Status(Enum):
@@ -216,6 +243,99 @@ async def submit_batch(batch: BatchData):
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))


@app.post("/validate-connection", response_model=Response[ValidateConnectionResponse])
async def validate_connection(request: ValidateConnectionRequest):
multi_model_provider_test_models = {
"openai": "gpt-4o-mini",
"vertexai": "vertex_ai/gemini-1.5-flash",
}
provider = request.provider.lower()
messages = [{"role": "user", "content": "Hey, how's it going?"}]

# For multi-model providers use a model that every account should have access to
if provider in multi_model_provider_test_models.keys():
model = multi_model_provider_test_models[provider]
if provider == "openai":
model_extra = {"api_key": request.api_key}
elif provider == "vertexai":
model_extra = {"vertex_credentials": request.vertex_credentials}
if request.vertex_location:
model_extra["vertex_location"] = request.vertex_location
if request.vertex_project:
model_extra["vertex_project"] = request.vertex_project
try:
response = litellm.completion(
messages=messages,
model=model,
max_tokens=10,
temperature=0.0,
**model_extra,
)
except AuthenticationError:
raise HTTPException(
status_code=400,
detail=f"Requested model '{model}' is not available with your api_key / credentials",
)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Error validating credentials for provider {provider}: {e}",
)

# For single-model connections use the provided model
else:
if provider.lower() == "azureopenai":
model = "azure/" + request.deployment_name
model_extra = {"base_url": request.endpoint}
elif provider.lower() == "custom":
model = "openai/" + request.deployment_name
model_extra = (
{"extra_headers": {"Authorization": request.auth_token}}
if request.auth_token
else {}
)
model_extra["api_key"] = request.api_key
try:
response = litellm.completion(
messages=messages,
model=model,
max_tokens=10,
temperature=0.0,
**model_extra,
)
except AuthenticationError:
raise HTTPException(
status_code=400,
detail=f"Requested model '{model}' is not available with your api_key and settings.",
)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to check availability of requested model '{model}': {e}",
)

return Response[ValidateConnectionResponse](
data=ValidateConnectionResponse(success=True, model=response.model)
)
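For reference, a minimal client-side sketch of the new endpoint (the base URL and credential file are assumptions, not part of the PR):

import requests  # illustrative client, not part of the diff

with open("service-account.json") as f:  # hypothetical Vertex service-account key file
    vertex_credentials = f.read()

resp = requests.post(
    "http://localhost:8000/validate-connection",  # assumed local server address
    json={
        "provider": "vertexai",
        "vertex_credentials": vertex_credentials,
        "vertex_project": "my-gcp-project",  # hypothetical project id
        "vertex_location": "us-central1",
    },
)
print(resp.json())  # expected to carry {"model": "vertex_ai/gemini-1.5-flash", "success": true} in its data field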


@app.post("/models-list", response_model=Response[ModelsListResponse])
async def models_list(request: ModelsListRequest):
# get_valid_models uses the API key set in the environment, but the list is not retrieved dynamically
# https://docs.litellm.ai/docs/set_keys#get_valid_models
# https://github.com/BerriAI/litellm/blob/b9280528d368aced49cb4d287c57cd0b46168cb6/litellm/utils.py#L5705
# Ultimately just uses litellm.models_by_provider - setting API key is not needed
lse_provider_to_litellm_provider = {"openai": "openai", "vertexai": "vertex_ai"}
provider = request.provider.lower()
valid_models = litellm.models_by_provider[
lse_provider_to_litellm_provider[provider]
]

return Response[ModelsListResponse](
data=ModelsListResponse(models_list=valid_models)
)
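A matching sketch for listing models (same assumed local base URL; the response mirrors litellm.models_by_provider["vertex_ai"]):

import requests  # illustrative client, not part of the diff

resp = requests.post(
    "http://localhost:8000/models-list",  # assumed local server address
    json={"provider": "vertexai"},
)
print(resp.json())  # data.models_list holds the Gemini model names litellm knows about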


@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
request: CostEstimateRequest,
@@ -238,6 +358,7 @@ async def estimate_cost(
prompt = request.prompt
substitutions = request.substitutions
agent = request.agent
provider = request.provider
runtime = agent.get_runtime()

try:
@@ -247,7 +368,10 @@
list(skill.field_schema.keys()) if skill.field_schema else None
)
cost_estimate = runtime.get_cost_estimate(
prompt=prompt, substitutions=substitutions, output_fields=output_fields
prompt=prompt,
substitutions=substitutions,
output_fields=output_fields,
provider=provider,
)
cost_estimates.append(cost_estimate)
total_cost_estimate = sum(
@@ -429,21 +553,27 @@ class ModelMetadataRequestItem(BaseModel):
model_name: str
auth_info: Optional[Dict[str, str]] = None


class ModelMetadataRequest(BaseModel):
models: List[ModelMetadataRequestItem]


class ModelMetadataResponse(BaseModel):
model_metadata: Dict[str, Dict]


@app.post("/model-metadata", response_model=Response[ModelMetadataResponse])
async def model_metadata(request: ModelMetadataRequest):
from adala.runtimes._litellm import get_model_info

resp = {'model_metadata': {item.model_name: get_model_info(**item.model_dump()) for item in request.models}}
return Response[ModelMetadataResponse](
success=True,
data=resp
)
resp = {
"model_metadata": {
item.model_name: get_model_info(**item.model_dump())
for item in request.models
}
}
return Response[ModelMetadataResponse](success=True, data=resp)
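And a hedged request sketch for this endpoint (base URL assumed; the item fields are inferred from get_model_info's signature):

import requests  # illustrative client, not part of the diff

resp = requests.post(
    "http://localhost:8000/model-metadata",  # assumed local server address
    json={"models": [{"provider": "vertexai", "model_name": "gemini-1.5-flash"}]},
)
print(resp.json())  # data.model_metadata maps each model name to its litellm metadata dict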


if __name__ == "__main__":
# for debugging
1 change: 1 addition & 0 deletions tests/test_serialization.py
@@ -101,6 +101,7 @@ def test_agent_is_serializable():
"verbose": False,
"batch_size": 100,
"concurrency": 1,
"instructor_mode": "tool_call",
"model": "gpt-4o-mini",
"max_tokens": 200,
"temperature": 0.0,