O3 Mini support (#1709)
wintonzheng authored Feb 3, 2025
1 parent e0e8684 commit 59756cb
Showing 3 changed files with 39 additions and 13 deletions.
20 changes: 15 additions & 5 deletions skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -155,7 +155,12 @@ async def llm_api_handler_with_router_and_fallback(
     LOG.exception("Failed to calculate LLM cost", error=str(e))
     llm_cost = 0
 prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
-completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
+
+# TODO (suchintan): Properly support reasoning tokens
+reasoning_tokens = response.get("usage", {}).get("reasoning_tokens", 0)
+LOG.info("Reasoning tokens", reasoning_tokens=reasoning_tokens)
+
+completion_tokens = response.get("usage", {}).get("completion_tokens", 0) + reasoning_tokens
 
 if step:
     await app.DATABASE.update_step(
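
A minimal, self-contained sketch of the accounting in the hunk above (extract_token_counts is a hypothetical helper name, not a function in this codebase; the flat usage.reasoning_tokens key follows this diff, while the TODO acknowledges that providers may report reasoning tokens elsewhere):

    from typing import Any

    def extract_token_counts(response: dict[str, Any]) -> tuple[int, int, int]:
        """Return (prompt_tokens, completion_tokens, reasoning_tokens) from a usage dict."""
        usage = response.get("usage", {})
        prompt_tokens = usage.get("prompt_tokens", 0)
        # Mirrors the diff: reasoning tokens are read from a flat key and folded
        # into the completion-token count used for cost and step bookkeeping.
        reasoning_tokens = usage.get("reasoning_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0) + reasoning_tokens
        return prompt_tokens, completion_tokens, reasoning_tokens

    print(extract_token_counts({"usage": {"prompt_tokens": 900, "completion_tokens": 150, "reasoning_tokens": 400}}))
    # -> (900, 550, 400)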
@@ -388,10 +393,15 @@ async def llm_api_handler(

     @staticmethod
     def get_api_parameters(llm_config: LLMConfig | LLMRouterConfig) -> dict[str, Any]:
-        return {
-            "max_tokens": llm_config.max_output_tokens,
-            "temperature": settings.LLM_CONFIG_TEMPERATURE,
-        }
+        params: dict[str, Any] = {"max_completion_tokens": llm_config.max_completion_tokens}
+
+        if llm_config.temperature is not None:
+            params["temperature"] = llm_config.temperature
+
+        if llm_config.reasoning_effort is not None:
+            params["reasoning_effort"] = llm_config.reasoning_effort
+
+        return params
 
     @classmethod
     def register_custom_handler(cls, llm_key: str, handler: LLMAPIHandler) -> None:
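To make the new parameter-building behavior concrete, here is a minimal, self-contained sketch with simplified stand-in types (MiniLLMConfig is not the Skyvern class): temperature is sent only when the config defines one, and reasoning_effort only when the model sets it.

    from dataclasses import dataclass
    from typing import Any

    @dataclass(frozen=True)
    class MiniLLMConfig:
        model_name: str
        max_completion_tokens: int = 4096
        temperature: float | None = 0.0
        reasoning_effort: str | None = None

    def get_api_parameters(llm_config: MiniLLMConfig) -> dict[str, Any]:
        # Same shape as the method above: always send max_completion_tokens,
        # and add temperature / reasoning_effort only when they are not None.
        params: dict[str, Any] = {"max_completion_tokens": llm_config.max_completion_tokens}
        if llm_config.temperature is not None:
            params["temperature"] = llm_config.temperature
        if llm_config.reasoning_effort is not None:
            params["reasoning_effort"] = llm_config.reasoning_effort
        return params

    print(get_api_parameters(MiniLLMConfig("gpt-4o", max_completion_tokens=16384)))
    # -> {'max_completion_tokens': 16384, 'temperature': 0.0}
    print(get_api_parameters(MiniLLMConfig("o3-mini", max_completion_tokens=16384, temperature=None, reasoning_effort="high")))
    # -> {'max_completion_tokens': 16384, 'reasoning_effort': 'high'}
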
24 changes: 18 additions & 6 deletions skyvern/forge/sdk/api/llm/config_registry.py
@@ -80,7 +80,19 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
 LLMConfigRegistry.register_config(
     "OPENAI_GPT4O",
     LLMConfig(
-        "gpt-4o", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False, max_output_tokens=16384
+        "gpt-4o", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False, max_completion_tokens=16384
     ),
 )
+LLMConfigRegistry.register_config(
+    "OPENAI_O3_MINI",
+    LLMConfig(
+        "o3-mini",
+        ["OPENAI_API_KEY"],
+        supports_vision=False,
+        add_assistant_prefix=False,
+        max_completion_tokens=16384,
+        temperature=None,  # Temperature isn't supported in the O-model series
+        reasoning_effort="high",
+    ),
+)
 LLMConfigRegistry.register_config(
@@ -90,7 +102,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
["OPENAI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_output_tokens=16384,
max_completion_tokens=16384,
),
)
LLMConfigRegistry.register_config(
@@ -100,7 +112,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
["OPENAI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_output_tokens=16384,
max_completion_tokens=16384,
),
)

@@ -149,7 +161,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
["ANTHROPIC_API_KEY"],
supports_vision=True,
add_assistant_prefix=True,
max_output_tokens=8192,
max_completion_tokens=8192,
),
)

@@ -275,7 +287,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
["GEMINI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_output_tokens=8192,
max_completion_tokens=8192,
),
)
LLMConfigRegistry.register_config(
@@ -285,7 +297,7 @@ def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
["GEMINI_API_KEY"],
supports_vision=True,
add_assistant_prefix=False,
max_output_tokens=8192,
max_completion_tokens=8192,
),
)

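The net effect of the new OPENAI_O3_MINI entry is a request roughly like the sketch below. This is an illustration, not code from this commit: it assumes LiteLLM forwards max_completion_tokens and reasoning_effort to the OpenAI o-series API and that OPENAI_API_KEY is set; Skyvern itself builds these keyword arguments via get_api_parameters rather than hard-coding them.

    import asyncio
    import litellm

    async def main() -> None:
        response = await litellm.acompletion(
            model="o3-mini",
            messages=[{"role": "user", "content": "Plan the next browser action."}],
            max_completion_tokens=16384,
            reasoning_effort="high",
            # No temperature: the o-series rejects it, which is why the registry
            # sets temperature=None for OPENAI_O3_MINI.
        )
        print(response.choices[0].message.content)

    asyncio.run(main())
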
8 changes: 6 additions & 2 deletions skyvern/forge/sdk/api/llm/models.py
@@ -36,7 +36,9 @@ def get_missing_env_vars(self) -> list[str]:
 @dataclass(frozen=True)
 class LLMConfig(LLMConfigBase):
     litellm_params: Optional[LiteLLMParams] = field(default=None)
-    max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
+    max_completion_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
+    temperature: float | None = SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE
+    reasoning_effort: str | None = None
 
 
 @dataclass(frozen=True)
@@ -72,7 +74,9 @@ class LLMRouterConfig(LLMConfigBase):
     allowed_fails: int | None = None
     allowed_fails_policy: AllowedFailsPolicy | None = None
     cooldown_time: float | None = None
-    max_output_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
+    max_completion_tokens: int = SettingsManager.get_settings().LLM_CONFIG_MAX_TOKENS
+    reasoning_effort: str | None = None
+    temperature: float | None = SettingsManager.get_settings().LLM_CONFIG_TEMPERATURE
 
 
 class LLMAPIHandler(Protocol):
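One detail of the models.py changes worth calling out: the settings-derived defaults on these frozen dataclasses are evaluated once, when the class is defined, not per instance. A small stand-in sketch of the pattern (FakeSettings and MiniConfig replace SettingsManager and the Skyvern config classes for illustration):

    from dataclasses import dataclass

    class FakeSettings:
        LLM_CONFIG_MAX_TOKENS = 4096
        LLM_CONFIG_TEMPERATURE = 0.0

    def get_settings() -> FakeSettings:
        return FakeSettings()

    @dataclass(frozen=True)
    class MiniConfig:
        # Defaults below are read from settings at class-definition (import) time.
        max_completion_tokens: int = get_settings().LLM_CONFIG_MAX_TOKENS
        temperature: float | None = get_settings().LLM_CONFIG_TEMPERATURE
        reasoning_effort: str | None = None

    print(MiniConfig())  # defaults pulled from settings
    print(MiniConfig(temperature=None, reasoning_effort="high"))  # explicit opt-out, as for o3-mini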
