Skip to content

Commit

Permalink
Add step / task / workflow run / observer metrics as logs (#1698)
Browse files Browse the repository at this point in the history
Co-authored-by: Suchintan <[email protected]>
  • Loading branch information
wintonzheng and suchintan authored Feb 1, 2025
1 parent 41e8d8b commit 204972e
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 219 deletions.
48 changes: 40 additions & 8 deletions skyvern/forge/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
import string
from asyncio.exceptions import CancelledError
from datetime import datetime
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Tuple

Expand Down Expand Up @@ -755,6 +755,7 @@ async def agent_step(
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
json_response = await app.LLM_API_HANDLER(
prompt=extract_action_prompt,
prompt_name="extract-actions",
step=step,
screenshots=scraped_page.screenshots,
)
Expand Down Expand Up @@ -1126,7 +1127,10 @@ async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, ste

# this prompt is critical to our agent so let's use the primary LLM API handler
verification_result = await app.LLM_API_HANDLER(
prompt=verification_prompt, step=step, screenshots=scraped_page_refreshed.screenshots
prompt=verification_prompt,
step=step,
screenshots=scraped_page_refreshed.screenshots,
prompt_name="check-user-goal",
)
return CompleteVerifyResult.model_validate(verification_result)

Expand Down Expand Up @@ -1411,8 +1415,10 @@ async def _build_extract_action_prompt(
elif task_type == TaskType.validation:
template = "decisive-criterion-validate"
elif task_type == TaskType.action:
prompt = prompt_engine.load_prompt("infer-action-type", navigation_goal=navigation_goal)
json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step)
prompt = prompt_engine.load_prompt(
"infer-action-type", navigation_goal=navigation_goal, prompt_name="infer-action-type"
)
json_response = await app.LLM_API_HANDLER(prompt=prompt, step=step, prompt_name="infer-action-type")
if json_response.get("error"):
raise FailedToParseActionInstruction(
reason=json_response.get("thought"), error_type=json_response.get("error")
Expand Down Expand Up @@ -1914,6 +1920,18 @@ async def update_step(
diff=update_comparison,
)

# Track step duration when step is completed or failed
if status in [StepStatus.completed, StepStatus.failed]:
duration_seconds = (datetime.now(UTC) - step.created_at.replace(tzinfo=UTC)).total_seconds()
LOG.info(
"Step duration metrics",
task_id=step.task_id,
step_id=step.step_id,
duration_seconds=duration_seconds,
status=status,
organization_id=step.organization_id,
)

await save_step_logs(step.step_id)

return await app.DATABASE.update_step(
Expand Down Expand Up @@ -1948,6 +1966,19 @@ async def update_task(
for key, value in updates.items()
if getattr(task, key) != value
}

# Track task duration when task is completed, failed, or terminated
if status in [TaskStatus.completed, TaskStatus.failed, TaskStatus.terminated]:
duration_seconds = (datetime.now(UTC) - task.created_at.replace(tzinfo=UTC)).total_seconds()
LOG.info(
"Task duration metrics",
task_id=task.task_id,
workflow_run_id=task.workflow_run_id,
duration_seconds=duration_seconds,
status=status,
organization_id=task.organization_id,
)

await save_task_logs(task.task_id)
LOG.info("Updating task in db", task_id=task.task_id, diff=update_comparison)
return await app.DATABASE.update_task(
Expand Down Expand Up @@ -2040,7 +2071,9 @@ async def summary_failure_reason_for_max_steps(
navigation_payload=task.navigation_payload,
steps=steps_results,
)
json_response = await app.LLM_API_HANDLER(prompt=prompt, screenshots=screenshots, step=step)
json_response = await app.LLM_API_HANDLER(
prompt=prompt, screenshots=screenshots, step=step, prompt_name="summarize-max-steps-reason"
)
return json_response.get("reasoning", "")
except Exception:
LOG.warning("Failed to summary the failure reason", task_id=task.task_id, step_id=step.step_id)
Expand Down Expand Up @@ -2198,6 +2231,7 @@ async def handle_potential_verification_code(
prompt=extract_action_prompt,
step=step,
screenshots=scraped_page.screenshots,
prompt_name="extract-actions",
)
return json_response

Expand Down Expand Up @@ -2238,9 +2272,7 @@ async def create_extract_action(task: Task, step: Step, scraped_page: ScrapedPag
)

data_extraction_summary_resp = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt,
step=step,
screenshots=scraped_page.screenshots,
prompt=prompt, step=step, screenshots=scraped_page.screenshots, prompt_name="data-extraction-summary"
)
return ExtractAction(
reasoning=data_extraction_summary_resp.get("summary", "Extracting information from the page"),
Expand Down
6 changes: 4 additions & 2 deletions skyvern/forge/agent_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ async def _convert_svg_to_string(

for retry in range(SVG_SHAPE_CONVERTION_ATTEMPTS):
try:
json_response = await app.SECONDARY_LLM_API_HANDLER(prompt=svg_convert_prompt, step=step)
json_response = await app.SECONDARY_LLM_API_HANDLER(
prompt=svg_convert_prompt, step=step, prompt_name="svg-convert"
)
svg_shape = json_response.get("shape", "")
recognized = json_response.get("recognized", False)
if not svg_shape or not recognized:
Expand Down Expand Up @@ -316,7 +318,7 @@ async def _convert_css_shape_to_string(
for retry in range(CSS_SHAPE_CONVERTION_ATTEMPTS):
try:
json_response = await app.SECONDARY_LLM_API_HANDLER(
prompt=prompt, screenshots=[screenshot], step=step
prompt=prompt, screenshots=[screenshot], step=step, prompt_name="css-shape-convert"
)
css_shape = json_response.get("shape", "")
recognized = json_response.get("recognized", False)
Expand Down
38 changes: 35 additions & 3 deletions skyvern/forge/sdk/api/llm/api_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:

async def llm_api_handler_with_router_and_fallback(
prompt: str,
prompt_name: str,
step: Step | None = None,
observer_cruise: ObserverTask | None = None,
observer_thought: ObserverThought | None = None,
Expand All @@ -80,6 +81,8 @@ async def llm_api_handler_with_router_and_fallback(
Returns:
The response from the LLM router.
"""
start_time = time.time()

if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)

Expand Down Expand Up @@ -120,7 +123,6 @@ async def llm_api_handler_with_router_and_fallback(
)
try:
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e
except ValueError as e:
Expand Down Expand Up @@ -195,6 +197,21 @@ async def llm_api_handler_with_router_and_fallback(
ai_suggestion=ai_suggestion,
)

# Track LLM API handler duration
duration_seconds = time.time() - start_time
LOG.info(
"LLM API handler duration metrics",
llm_key=llm_key,
model=main_model_group,
prompt_name=prompt_name,
duration_seconds=duration_seconds,
step_id=step.step_id if step else None,
observer_thought_id=observer_thought.observer_thought_id if observer_thought else None,
organization_id=step.organization_id
if step
else (observer_thought.organization_id if observer_thought else None),
)

return parsed_response

return llm_api_handler_with_router_and_fallback
Expand All @@ -210,13 +227,15 @@ def get_llm_api_handler(llm_key: str, base_parameters: dict[str, Any] | None = N

async def llm_api_handler(
prompt: str,
prompt_name: str,
step: Step | None = None,
observer_cruise: ObserverTask | None = None,
observer_thought: ObserverThought | None = None,
ai_suggestion: AISuggestion | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
start_time = time.time()
active_parameters = base_parameters or {}
if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
Expand Down Expand Up @@ -270,14 +289,12 @@ async def llm_api_handler(
# TODO (kerem): add a timeout to this call
# TODO (kerem): add a retry mechanism to this call (acompletion_with_retries)
# TODO (kerem): use litellm fallbacks? https://litellm.vercel.app/docs/tutorials/fallbacks#how-does-completion_with_fallbacks-work
LOG.info("Calling LLM API", llm_key=llm_key, model=llm_config.model_name)
response = await litellm.acompletion(
model=llm_config.model_name,
messages=messages,
timeout=settings.LLM_CONFIG_TIMEOUT,
**active_parameters,
)
LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
except litellm.exceptions.APIError as e:
raise LLMProviderErrorRetryableTask(llm_key) from e
except CancelledError:
Expand Down Expand Up @@ -350,6 +367,21 @@ async def llm_api_handler(
ai_suggestion=ai_suggestion,
)

# Track LLM API handler duration
duration_seconds = time.time() - start_time
LOG.info(
"LLM API handler duration metrics",
llm_key=llm_key,
prompt_name=prompt_name,
model=llm_config.model_name,
duration_seconds=duration_seconds,
step_id=step.step_id if step else None,
observer_thought_id=observer_thought.observer_thought_id if observer_thought else None,
organization_id=step.organization_id
if step
else (observer_thought.organization_id if observer_thought else None),
)

return parsed_response

return llm_api_handler
Expand Down
1 change: 1 addition & 0 deletions skyvern/forge/sdk/api/llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class LLMAPIHandler(Protocol):
def __call__(
self,
prompt: str,
prompt_name: str,
step: Step | None = None,
observer_cruise: ObserverTask | None = None,
observer_thought: ObserverThought | None = None,
Expand Down
6 changes: 4 additions & 2 deletions skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,9 @@ async def make_ai_suggestion(
ai_suggestion_type=ai_suggestion_type,
)

llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, ai_suggestion=new_ai_suggestion)
llm_response = await app.LLM_API_HANDLER(
prompt=llm_prompt, ai_suggestion=new_ai_suggestion, prompt_name="suggest-data-schema"
)
parsed_ai_suggestion = AISuggestionBase.model_validate(llm_response)

return parsed_ai_suggestion
Expand Down Expand Up @@ -1045,7 +1047,7 @@ async def generate_task(

llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt, prompt_name="generate-task")
parsed_task_generation_obj = TaskGenerationBase.model_validate(llm_response)

# generate a TaskGenerationModel
Expand Down
2 changes: 1 addition & 1 deletion skyvern/forge/sdk/routes/totp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ async def save_totp_code(

async def parse_totp_code(content: str) -> str | None:
prompt = prompt_engine.load_prompt("parse-verification-code", content=content)
code_resp = await app.SECONDARY_LLM_API_HANDLER(prompt=prompt)
code_resp = await app.SECONDARY_LLM_API_HANDLER(prompt=prompt, prompt_name="parse-verification-code")
return code_resp.get("code", None)
Loading

0 comments on commit 204972e

Please sign in to comment.