Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new SDK; fix state warnings #51

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 15 additions & 32 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types import (
MessageParam,
ToolParam,
ToolResultBlockParam,
)
from anthropic.types.beta import (
Expand All @@ -21,7 +19,6 @@
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
BetaToolParam,
BetaToolResultBlockParam,
)

Expand Down Expand Up @@ -95,39 +92,25 @@ async def sampling_loop(
if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images)

if provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=api_key)
elif provider == APIProvider.VERTEX:
client = AnthropicVertex()
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
# `response = client.messages.create(...)` instead.
if provider == APIProvider.ANTHROPIC:
raw_response = Anthropic(
api_key=api_key
).beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=model,
system=system,
tools=cast(list[BetaToolParam], tool_collection.to_params()),
extra_headers={"anthropic-beta": BETA_FLAG},
)
elif provider == APIProvider.VERTEX:
raw_response = AnthropicVertex().messages.with_raw_response.create(
max_tokens=max_tokens,
messages=cast(list[MessageParam], messages),
model=model,
system=system,
tools=cast(list[ToolParam], tool_collection.to_params()),
extra_headers={"anthropic-beta": BETA_FLAG},
)
elif provider == APIProvider.BEDROCK:
raw_response = AnthropicBedrock().messages.with_raw_response.create(
max_tokens=max_tokens,
messages=cast(list[MessageParam], messages),
model=model,
system=system,
tools=cast(list[ToolParam], tool_collection.to_params()),
extra_body={"anthropic_beta": [BETA_FLAG]},
)
raw_response = client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=model,
system=system,
tools=tool_collection.to_params(),
betas=["computer-use-2024-10-22"],
)

api_response_callback(cast(APIResponse[BetaMessage], raw_response))

Expand Down
2 changes: 1 addition & 1 deletion computer-use-demo/computer_use_demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
streamlit>=1.38.0
anthropic[bedrock,vertex]>=0.36.2
anthropic[bedrock,vertex]>=0.37.1
jsonschema==4.22.0
boto3>=1.28.57
google-auth<3,>=2
11 changes: 2 additions & 9 deletions computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def setup_state():
st.session_state.api_key = load_from_storage("api_key") or os.getenv(
"ANTHROPIC_API_KEY", ""
)
if "api_key_input" not in st.session_state:
st.session_state.api_key_input = st.session_state.api_key
if "provider" not in st.session_state:
st.session_state.provider = (
os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
Expand Down Expand Up @@ -114,7 +112,6 @@ def _reset_api_provider():
st.radio(
"API Provider",
options=provider_options,
index=provider_options.index(st.session_state.provider),
key="provider_radio",
format_func=lambda x: x.title(),
on_change=_reset_api_provider,
Expand All @@ -125,14 +122,10 @@ def _reset_api_provider():
if st.session_state.provider == APIProvider.ANTHROPIC:
st.text_input(
"Anthropic API Key",
value=st.session_state.api_key,
type="password",
key="api_key_input",
on_change=lambda: save_to_storage(
"api_key", st.session_state.api_key_input
),
key="api_key",
on_change=lambda: save_to_storage("api_key", st.session_state.api_key),
)
st.session_state.api_key = st.session_state.api_key_input

st.number_input(
"Only send N most recent images",
Expand Down
35 changes: 5 additions & 30 deletions computer-use-demo/computer_use_demo/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,23 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, ClassVar, Literal, Optional, Required, TypedDict
from typing import Any

APIToolType = Literal["computer_20241022", "text_editor_20241022", "bash_20241022"]
APIToolName = Literal["computer", "str_replace_editor", "bash"]


class AnthropicAPIToolParam(TypedDict):
"""API shape for Anthropic-defined tools."""

name: Required[APIToolName]
type: Required[APIToolType]


class ComputerToolOptions(TypedDict):
display_height_px: Required[int]
display_width_px: Required[int]
display_number: Optional[int]
from anthropic.types.beta import BetaToolUnionParam


class BaseAnthropicTool(metaclass=ABCMeta):
"""Abstract base class for Anthropic-defined tools."""

name: ClassVar[APIToolName]
api_type: ClassVar[APIToolType]

@property
def options(self) -> ComputerToolOptions | None:
return None

@abstractmethod
def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...

@abstractmethod
def to_params(
self,
) -> dict: # -> AnthropicToolParam & Optional[ComputerToolOptions]
"""Creates the shape necessary to this tool to the Anthropic API."""
return {
"name": self.name,
"type": self.api_type,
**(self.options or {}),
}
) -> BetaToolUnionParam:
raise NotImplementedError


@dataclass(kw_only=True, frozen=True)
Expand Down
13 changes: 11 additions & 2 deletions computer-use-demo/computer_use_demo/tools/bash.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import os
from typing import ClassVar, Literal

from anthropic.types.beta import BetaToolBash20241022Param

from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult

Expand Down Expand Up @@ -107,8 +110,8 @@ class BashTool(BaseAnthropicTool):
"""

_session: _BashSession | None
name = "bash"
api_type = "bash_20241022"
name: ClassVar[Literal["bash"]] = "bash"
api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"

def __init__(self):
self._session = None
Expand All @@ -133,3 +136,9 @@ async def __call__(
return await self._session.run(command)

raise ToolError("no command provided.")

def to_params(self) -> BetaToolBash20241022Param:
return {
"type": self.api_type,
"name": self.name,
}
6 changes: 4 additions & 2 deletions computer-use-demo/computer_use_demo/tools/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

from anthropic.types.beta import BetaToolUnionParam

from .base import (
BaseAnthropicTool,
ToolError,
Expand All @@ -15,11 +17,11 @@ class ToolCollection:

def __init__(self, *tools: BaseAnthropicTool):
self.tools = tools
self.tool_map = {tool.name: tool for tool in tools}
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}

def to_params(
self,
) -> list[dict]: # -> List[AnthropicToolParam & Optional[ComputerToolOptions]]
) -> list[BetaToolUnionParam]:
return [tool.to_params() for tool in self.tools]

async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
Expand Down
17 changes: 14 additions & 3 deletions computer-use-demo/computer_use_demo/tools/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from typing import Literal, TypedDict
from uuid import uuid4

from .base import BaseAnthropicTool, ComputerToolOptions, ToolError, ToolResult
from anthropic.types.beta import BetaToolComputerUse20241022Param

from .base import BaseAnthropicTool, ToolError, ToolResult
from .run import run

OUTPUT_DIR = "/tmp/outputs"
Expand Down Expand Up @@ -49,6 +51,12 @@ class ScalingSource(StrEnum):
API = "api"


class ComputerToolOptions(TypedDict):
display_height_px: int
display_width_px: int
display_number: int | None


def chunks(s: str, chunk_size: int) -> list[str]:
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]

Expand All @@ -59,8 +67,8 @@ class ComputerTool(BaseAnthropicTool):
The tool parameters are defined by Anthropic and are not editable.
"""

name = "computer"
api_type = "computer_20241022"
name: Literal["computer"] = "computer"
api_type: Literal["computer_20241022"] = "computer_20241022"
width: int
height: int
display_num: int | None
Expand All @@ -79,6 +87,9 @@ def options(self) -> ComputerToolOptions:
"display_number": self.display_num,
}

def to_params(self) -> BetaToolComputerUse20241022Param:
return {"name": self.name, "type": self.api_type, **self.options}

def __init__(self):
super().__init__()

Expand Down
12 changes: 10 additions & 2 deletions computer-use-demo/computer_use_demo/tools/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from typing import Literal, get_args

from anthropic.types.beta import BetaToolTextEditor20241022Param

from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from .run import maybe_truncate, run

Expand All @@ -21,15 +23,21 @@ class EditTool(BaseAnthropicTool):
The tool parameters are defined by Anthropic and are not editable.
"""

api_type = "text_editor_20241022"
name = "str_replace_editor"
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
name: Literal["str_replace_editor"] = "str_replace_editor"

_file_history: dict[Path, list[str]]

def __init__(self):
self._file_history = defaultdict(list)
super().__init__()

def to_params(self) -> BetaToolTextEditor20241022Param:
return {
"name": self.name,
"type": self.api_type,
}

async def __call__(
self,
*,
Expand Down
Loading