Skip to content

Commit cbc539c

Browse files
authored
feat: DIA-1384: add cost estimate endpoint (#225)
1 parent db34d51 commit cbc539c

File tree

5 files changed

+256
-10
lines changed

5 files changed

+256
-10
lines changed

adala/agents/base.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from adala.utils.internal_data import InternalDataFrame
3232
from adala.utils.types import BatchData
33+
3334
logger = logging.getLogger(__name__)
3435

3536

@@ -40,7 +41,7 @@ class Agent(BaseModel, ABC):
4041
4142
Attributes:
4243
environment (Environment): The environment with which the agent interacts.
43-
skills (Union[SkillSet, List[Skill]]): The skills possessed by the agent.
44+
skills (SkillSet): The skills possessed by the agent.
4445
memory (LongTermMemory, optional): The agent's long-term memory. Defaults to None.
4546
runtimes (Dict[str, Runtime], optional): The runtimes available to the agent. Defaults to predefined runtimes.
4647
default_runtime (str): The default runtime used by the agent. Defaults to 'openai'.
@@ -57,7 +58,7 @@ class Agent(BaseModel, ABC):
5758
"""
5859

5960
environment: Optional[SerializeAsAny[Union[Environment, AsyncEnvironment]]] = None
60-
skills: SerializeAsAny[Union[Skill, SkillSet]]
61+
skills: SerializeAsAny[SkillSet]
6162

6263
memory: Memory = Field(default=None)
6364
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
@@ -418,7 +419,7 @@ async def arefine_skill(
418419
skill = self.skills[skill_name]
419420
if not isinstance(skill, TransformSkill):
420421
raise ValueError(f"Skill {skill_name} is not a TransformSkill")
421-
422+
422423
# get default runtimes
423424
runtime = self.get_runtime()
424425
teacher_runtime = self.get_teacher_runtime()

adala/runtimes/_litellm.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import instructor
1414
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
1515
import traceback
16+
from adala.runtimes.base import CostEstimate
1617
from adala.utils.exceptions import ConstrainedGenerationError
1718
from adala.utils.internal_data import InternalDataFrame
1819
from adala.utils.parse import (
@@ -122,7 +123,6 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
122123

123124

124125
class InstructorClientMixin:
125-
126126
def _from_litellm(self, **kwargs):
    """Return an instructor client built around ``litellm.completion``.

    All keyword arguments are forwarded unchanged to
    ``instructor.from_litellm``.
    """
    completion_fn = litellm.completion
    return instructor.from_litellm(completion_fn, **kwargs)
128128

@@ -139,7 +139,6 @@ def is_custom_openai_endpoint(self) -> bool:
139139

140140

141141
class InstructorAsyncClientMixin(InstructorClientMixin):
142-
143142
def _from_litellm(self, **kwargs):
    """Return an instructor client built around ``litellm.acompletion``.

    All keyword arguments are forwarded unchanged to
    ``instructor.from_litellm``.
    """
    completion_fn = litellm.acompletion
    return instructor.from_litellm(completion_fn, **kwargs)
145144

@@ -527,6 +526,73 @@ async def record_to_record(
527526
# Extract the single row from the output DataFrame and convert it to a dictionary
528527
return output_df.iloc[0].to_dict()
529528

529+
@staticmethod
def _get_prompt_tokens(
    string: str, model: str, output_fields: Optional[List[str]]
) -> int:
    """
    Estimate the number of prompt tokens for one rendered user prompt.

    Args:
        string: The fully rendered user prompt.
        model: Model name passed to ``litellm.token_counter``.
        output_fields: Names of the expected output fields. ``None`` is
            accepted (callers forward an optional list) and is treated as a
            single output, mirroring ``_get_completion_tokens``.

    Returns:
        The counted user-prompt tokens plus a hard-coded estimate of the
        system / function-call overhead.
    """
    user_tokens = litellm.token_counter(model=model, text=string)
    # FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
    #       currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
    # Guard against None: len(None) would raise TypeError and turn the whole
    # estimate into an error result.
    n_outputs = len(output_fields) if output_fields is not None else 1
    system_tokens = 56 + (6 * n_outputs)
    return user_tokens + system_tokens
536+
537+
@staticmethod
def _get_completion_tokens(model: str, output_fields: Optional[List[str]]) -> int:
    """
    Estimate the number of completion tokens for one request.

    Args:
        model: Model name looked up via ``litellm.get_model_info``.
        output_fields: Names of the expected output fields; ``None`` or an
            empty list is treated as a single output.

    Returns:
        ``4 * n_outputs``, capped at the model's ``max_tokens``.

    Raises:
        ValueError: If litellm reports no (or a zero) ``max_tokens`` for the
            model.
    """
    # NOTE(review): the provider is pinned to "openai" here; confirm this is
    # intended when the runtime is configured with a non-OpenAI model.
    max_tokens = litellm.get_model_info(
        model=model, custom_llm_provider="openai"
    ).get("max_tokens", None)
    if not max_tokens:
        # Carry a message: callers surface str(e) as the user-facing
        # error_message, which would otherwise be empty.
        raise ValueError(
            f"Could not determine max_tokens for model {model!r}; "
            "unable to estimate completion tokens"
        )
    # extremely rough heuristic, from testing on some anecdotal examples
    n_outputs = len(output_fields) if output_fields else 1
    return min(max_tokens, 4 * n_outputs)
547+
548+
@classmethod
def _estimate_cost(
    cls, user_prompt: str, model: str, output_fields: Optional[List[str]]
):
    """
    Estimate the cost of a single request.

    Token counts come from ``_get_prompt_tokens`` and
    ``_get_completion_tokens``; they are priced with
    ``litellm.cost_per_token``.

    Returns:
        A ``(prompt_cost, completion_cost, total_cost)`` tuple, where the
        total is the sum of the first two.
    """
    n_prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
    n_completion_tokens = cls._get_completion_tokens(model, output_fields)

    priced = litellm.cost_per_token(
        model=model,
        prompt_tokens=n_prompt_tokens,
        completion_tokens=n_completion_tokens,
    )
    prompt_cost, completion_cost = priced
    return prompt_cost, completion_cost, prompt_cost + completion_cost
562+
563+
def get_cost_estimate(
    self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
) -> CostEstimate:
    """
    Estimate the cumulative cost of running ``prompt`` once per substitution.

    Args:
        prompt: Template rendered with ``str.format`` for each substitution.
        substitutions: One mapping of template variables per request.
        output_fields: Names of the expected output fields, if any.

    Returns:
        A ``CostEstimate`` holding the summed prompt, completion and total
        costs. Any exception raised along the way is logged and reported as
        an error ``CostEstimate`` (``is_error=True``) rather than propagated.
    """
    try:
        # Render every prompt up front so a template error surfaces before
        # any cost is computed.
        rendered_prompts = [prompt.format(**mapping) for mapping in substitutions]

        prompt_sum = 0
        completion_sum = 0
        total_sum = 0
        for rendered in rendered_prompts:
            per_prompt, per_completion, per_total = self._estimate_cost(
                user_prompt=rendered,
                model=self.model,
                output_fields=output_fields,
            )
            prompt_sum = prompt_sum + per_prompt
            completion_sum = completion_sum + per_completion
            total_sum = total_sum + per_total

        return CostEstimate(
            prompt_cost_usd=prompt_sum,
            completion_cost_usd=completion_sum,
            total_cost_usd=total_sum,
        )
    except Exception as e:
        logger.error("Failed to estimate cost: %s", e)
        failure = CostEstimate(
            is_error=True,
            error_type=type(e).__name__,
            error_message=str(e),
        )
        return failure
595+
530596

531597
class LiteLLMVisionRuntime(LiteLLMChatRuntime):
532598
"""

adala/runtimes/base.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,51 @@
11
import logging
2+
from abc import abstractmethod
3+
from typing import Any, Dict, List, Optional, Type
24

3-
from tqdm import tqdm
4-
from abc import ABC, abstractmethod
5-
from pydantic import BaseModel, model_validator, Field
6-
from typing import List, Dict, Optional, Tuple, Any, Callable, ClassVar, Type
7-
from adala.utils.internal_data import InternalDataFrame, InternalSeries
5+
from adala.utils.internal_data import InternalDataFrame
86
from adala.utils.registry import BaseModelInRegistry
97
from pandarallel import pandarallel
8+
from pydantic import BaseModel, Field, model_validator
9+
from tqdm import tqdm
1010

1111
logger = logging.getLogger(__name__)
1212
tqdm.pandas()
1313

1414

15+
class CostEstimate(BaseModel):
    # Result of a cost estimation: either the three cost figures (in USD) or
    # an error description. Documented with comments rather than a docstring
    # so the generated model schema is unchanged.

    # Estimated costs in USD; None when not available.
    prompt_cost_usd: Optional[float] = None
    completion_cost_usd: Optional[float] = None
    total_cost_usd: Optional[float] = None
    # Error reporting: when is_error is True the two fields below describe it.
    is_error: bool = False
    error_type: Optional[str] = None
    error_message: Optional[str] = None

    def __add__(self, other: "CostEstimate") -> "CostEstimate":
        """Combine two estimates field by field.

        If either operand is an error it is returned as-is, the left operand
        taking precedence. Otherwise each cost is summed, a missing (None)
        value counting as 0.0 unless both sides are None, in which case the
        result stays None.
        """
        # if either has an error, it takes precedence
        for candidate in (self, other):
            if candidate.is_error:
                return candidate

        def _combine(lhs: Optional[float], rhs: Optional[float]) -> Optional[float]:
            if lhs is None and rhs is None:
                return None
            return (lhs or 0.0) + (rhs or 0.0)

        cost_fields = ("prompt_cost_usd", "completion_cost_usd", "total_cost_usd")
        summed = {
            name: _combine(getattr(self, name), getattr(other, name))
            for name in cost_fields
        }
        return CostEstimate(**summed)
47+
48+
1549
class Runtime(BaseModelInRegistry):
1650
"""
1751
Base class representing a generic runtime environment.
@@ -191,6 +225,11 @@ def record_to_batch(
191225
response_model=response_model,
192226
)
193227

228+
def get_cost_estimate(
    self, prompt: str, substitutions: List[Dict], output_fields: Optional[List[str]]
) -> CostEstimate:
    """
    Estimate the cost of running ``prompt`` over ``substitutions``.

    This base implementation provides no estimate and always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "This runtime does not support cost estimates"
    raise NotImplementedError(message)
232+
194233

195234
class AsyncRuntime(Runtime):
196235
"""Async version of runtime that uses asyncio to process batch of records."""

server/app.py

+60
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from aiokafka.errors import UnknownTopicOrPartitionError
1616
from fastapi import HTTPException, Depends
1717
from fastapi.middleware.cors import CORSMiddleware
18+
import litellm
1819
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
1920
from redis import Redis
2021
import time
@@ -24,6 +25,7 @@
2425
from server.handlers.result_handlers import ResultHandler
2526
from server.log_middleware import LogMiddleware
2627
from adala.skills.collection.prompt_improvement import ImprovedPromptResponse
28+
from adala.runtimes.base import CostEstimate
2729
from server.tasks.stream_inference import streaming_parent_task
2830
from server.utils import (
2931
Settings,
@@ -81,6 +83,12 @@ class BatchSubmitted(BaseModel):
8183
job_id: str
8284

8385

86+
class CostEstimateRequest(BaseModel):
    # Request body for the POST /estimate-cost endpoint. Documented with
    # comments rather than a docstring so the generated schema is unchanged.

    # Agent definition; the endpoint uses its runtime and skills.
    agent: Agent
    # Prompt template applied to every substitution.
    prompt: str
    # One mapping of template variables per task to be estimated.
    substitutions: List[Dict]
90+
91+
8492
class Status(Enum):
8593
PENDING = "Pending"
8694
INPROGRESS = "InProgress"
@@ -210,6 +218,58 @@ async def submit_batch(batch: BatchData):
210218
return Response[BatchSubmitted](data=BatchSubmitted(job_id=batch.job_id))
211219

212220

221+
@app.post("/estimate-cost", response_model=Response[CostEstimate])
async def estimate_cost(
    request: CostEstimateRequest,
):
    """
    Estimates what it would cost to run inference on the batch of data in
    `request` (using the run params from `request`)

    Args:
        request (CostEstimateRequest): Specification for the inference run to
            make an estimate for, includes:
                agent (adala.agent.Agent): The agent definition, used to get the model
                    and any other params necessary to estimate cost
                prompt (str): The prompt template that will be used for each task
                substitutions (List[Dict]): Mappings to substitute (simply using str.format)

    Returns:
        Response[CostEstimate]: The cost estimate, including the prompt/completion/total costs (in USD)
    """
    agent = request.agent
    runtime = agent.get_runtime()

    def _output_fields_of(skill):
        # A skill without a field schema contributes no named output fields.
        schema = skill.field_schema
        return list(schema.keys()) if schema else None

    try:
        # One estimate per skill, all computed from the same prompt and
        # substitutions.
        per_skill_estimates = [
            runtime.get_cost_estimate(
                prompt=request.prompt,
                substitutions=request.substitutions,
                output_fields=_output_fields_of(skill),
            )
            for skill in agent.skills.skills.values()
        ]
        # sum() needs an explicit CostEstimate start value; the default start
        # of 0 does not support "+" with CostEstimate.
        empty_estimate = CostEstimate(
            prompt_cost_usd=None, completion_cost_usd=None, total_cost_usd=None
        )
        total_cost_estimate = sum(per_skill_estimates, empty_estimate)
    except NotImplementedError as e:
        # The runtime does not support estimates: report it in the payload.
        unsupported = CostEstimate(
            is_error=True,
            error_type=type(e).__name__,
            error_message=str(e),
        )
        return Response[CostEstimate](data=unsupported)

    return Response[CostEstimate](data=total_cost_estimate)
271+
272+
213273
@app.get("/jobs/{job_id}", response_model=Response[JobStatusResponse])
214274
def get_status(job_id):
215275
"""

tests/test_cost_estimation.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python3
2+
import pytest
3+
from adala.runtimes._litellm import AsyncLiteLLMChatRuntime
4+
from adala.runtimes.base import CostEstimate
5+
from adala.agents import Agent
6+
from adala.skills import ClassificationSkill
7+
import numpy as np
8+
import os
9+
from fastapi.testclient import TestClient
10+
from server.app import app, CostEstimateRequest
11+
12+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
13+
14+
15+
@pytest.mark.use_openai
def test_simple_estimate_cost():
    """The runtime returns float costs whose total equals prompt + completion."""
    estimate = AsyncLiteLLMChatRuntime(
        model="gpt-4o-mini", api_key=OPENAI_API_KEY
    ).get_cost_estimate(
        prompt="testing, {text}",
        substitutions=[{"text": "knock knock, who's there"}],
        output_fields=["text"],
    )

    assert isinstance(estimate, CostEstimate)

    prompt_cost = estimate.prompt_cost_usd
    completion_cost = estimate.completion_cost_usd
    total_cost = estimate.total_cost_usd
    assert isinstance(prompt_cost, float)
    assert isinstance(completion_cost, float)
    assert isinstance(total_cost, float)
    assert np.isclose(total_cost, prompt_cost + completion_cost)
33+
34+
35+
@pytest.mark.use_openai
def test_estimate_cost_endpoint(client):
    """POST /estimate-cost returns a successful, internally consistent estimate.

    Args:
        client: Test-client fixture used to call the app (defined outside
            this module).
    """
    req = {
        "agent": {
            "skills": [
                {
                    "type": "ClassificationSkill",
                    "name": "text_classifier",
                    "instructions": "Always return the answer 'Feature Lack'.",
                    "input_template": "{text}",
                    "output_template": "{output}",
                    "labels": [
                        "Feature Lack",
                        "Price",
                        "Integration Issues",
                        "Usability Concerns",
                        "Competitor Advantage",
                    ],
                }
            ],
            "runtimes": {
                "default": {
                    "type": "AsyncLiteLLMChatRuntime",
                    "model": "gpt-4o-mini",
                    "api_key": OPENAI_API_KEY,
                }
            },
        },
        "prompt": "test {text}",
        "substitutions": [{"text": "test"}],
    }
    resp = client.post(
        "/estimate-cost",
        json=req,
    )
    # Check the status before parsing, so a validation or server error shows
    # its response body instead of an opaque KeyError on ["data"].
    assert resp.status_code == 200, resp.text

    resp_data = resp.json()["data"]
    cost_estimate = CostEstimate(**resp_data)

    assert isinstance(cost_estimate, CostEstimate)
    # An error estimate carries no costs; show its message rather than
    # failing on the isinstance checks below.
    assert not cost_estimate.is_error, cost_estimate.error_message
    assert isinstance(cost_estimate.prompt_cost_usd, float)
    assert isinstance(cost_estimate.completion_cost_usd, float)
    assert isinstance(cost_estimate.total_cost_usd, float)
    assert np.isclose(
        cost_estimate.total_cost_usd,
        cost_estimate.prompt_cost_usd + cost_estimate.completion_cost_usd,
    )

0 commit comments

Comments
 (0)