enable prompt caching, move everything to be param shapes (#63)
nsmccandlish authored Oct 24, 2024
1 parent 0367e43 commit be847c4
Showing 4 changed files with 101 additions and 42 deletions.
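In short, the commit does two things: it converts the SDK's typed response objects into plain param dicts (the `*Param` TypedDict shapes) everywhere they are stored or rendered, and it turns on prompt caching when the Anthropic provider is used. A minimal sketch of the request shape the loop ends up with, assuming an `ANTHROPIC_API_KEY` in the environment; the model name, token budget, and user message are illustrative placeholders, not values from this commit:

    # Sketch only: model, max_tokens, and the message are placeholders.
    from anthropic import Anthropic
    from anthropic.types.beta import BetaTextBlockParam

    client = Anthropic()

    # The system prompt is now a param dict, so a cache_control marker can be
    # attached to it directly when caching is enabled.
    system: BetaTextBlockParam = {
        "type": "text",
        "text": "You are a helpful assistant.",
        "cache_control": {"type": "ephemeral"},
    }

    raw_response = client.beta.messages.with_raw_response.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Hello"}],
        model="claude-3-5-sonnet-20241022",
        system=[system],
        betas=["computer-use-2024-10-22", "prompt-caching-2024-07-31"],
    )
    response = raw_response.parse()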
98 changes: 76 additions & 22 deletions computer-use-demo/computer_use_demo/loop.py
@@ -17,21 +17,22 @@
APIResponseValidationError,
APIStatusError,
)
from anthropic.types import (
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaCacheControlEphemeralParam,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolResultBlockParam,
BetaToolUseBlockParam,
)

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult

BETA_FLAG = "computer-use-2024-10-22"
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"


class APIProvider(StrEnum):
@@ -75,7 +76,7 @@ async def sampling_loop(
provider: APIProvider,
system_prompt_suffix: str,
messages: list[BetaMessageParam],
output_callback: Callable[[BetaContentBlock], None],
output_callback: Callable[[BetaContentBlockParam], None],
tool_output_callback: Callable[[ToolResult, str], None],
api_response_callback: Callable[
[httpx.Request, httpx.Response | object | None, Exception | None], None
@@ -92,21 +93,37 @@ async def sampling_loop(
BashTool(),
EditTool(),
)
system = (
f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
system = BetaTextBlockParam(
type="text",
text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}",
)

while True:
if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images)

enable_prompt_caching = False
betas = [COMPUTER_USE_BETA_FLAG]
image_truncation_threshold = 10
if provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=api_key)
enable_prompt_caching = True
elif provider == APIProvider.VERTEX:
client = AnthropicVertex()
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

if enable_prompt_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
_inject_prompt_caching(messages)
# Is it ever worth it to bust the cache with prompt caching?
image_truncation_threshold = 50
system["cache_control"] = {"type": "ephemeral"}

if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(
messages,
only_n_most_recent_images,
min_removal_threshold=image_truncation_threshold,
)

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able to call the SDK directly with:
@@ -116,9 +133,9 @@
max_tokens=max_tokens,
messages=messages,
model=model,
system=system,
system=[system],
tools=tool_collection.to_params(),
betas=[BETA_FLAG],
betas=betas,
)
except (APIStatusError, APIResponseValidationError) as e:
api_response_callback(e.request, e.response, e)
@@ -133,25 +150,26 @@

response = raw_response.parse()

response_params = _response_to_params(response)
messages.append(
{
"role": "assistant",
"content": cast(list[BetaContentBlockParam], response.content),
"content": response_params,
}
)

tool_result_content: list[BetaToolResultBlockParam] = []
for content_block in cast(list[BetaContentBlock], response.content):
for content_block in response_params:
output_callback(content_block)
if content_block.type == "tool_use":
if content_block["type"] == "tool_use":
result = await tool_collection.run(
name=content_block.name,
tool_input=cast(dict[str, Any], content_block.input),
name=content_block["name"],
tool_input=cast(dict[str, Any], content_block["input"]),
)
tool_result_content.append(
_make_api_tool_result(result, content_block.id)
_make_api_tool_result(result, content_block["id"])
)
tool_output_callback(result, content_block.id)
tool_output_callback(result, content_block["id"])

if not tool_result_content:
return messages
@@ -162,7 +180,7 @@
def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int = 10,
min_removal_threshold: int,
):
"""
With the assumption that images are screenshots that are of diminishing value as
@@ -174,7 +192,7 @@ def _maybe_filter_to_n_most_recent_images(
return messages

tool_result_blocks = cast(
list[ToolResultBlockParam],
list[BetaToolResultBlockParam],
[
item
for message in messages
@@ -208,6 +226,42 @@ def _maybe_filter_to_n_most_recent_images(
tool_result["content"] = new_content


def _response_to_params(
response: BetaMessage,
) -> list[BetaTextBlockParam | BetaToolUseBlockParam]:
res: list[BetaTextBlockParam | BetaToolUseBlockParam] = []
for block in response.content:
if isinstance(block, BetaTextBlock):
res.append({"type": "text", "text": block.text})
else:
res.append(cast(BetaToolUseBlockParam, block.model_dump()))
return res


def _inject_prompt_caching(
messages: list[BetaMessageParam],
):
"""
Set cache breakpoints for the 3 most recent turns;
one cache breakpoint is left for the tools/system prompt, to be shared across sessions.
"""

breakpoints_remaining = 3
for message in reversed(messages):
if message["role"] == "user" and isinstance(
content := message["content"], list
):
if breakpoints_remaining:
breakpoints_remaining -= 1
content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
{"type": "ephemeral"}
)
else:
content[-1].pop("cache_control", None)
# we'll only ever have one extra turn per loop
break


def _make_api_tool_result(
result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
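Two details in loop.py above are easy to miss. `_inject_prompt_caching` marks the last content block of up to three of the most recent user turns with an ephemeral `cache_control` and strips the marker from the turn that falls out of the window; together with the cached system prompt, that stays within the four cache breakpoints the prompt-caching beta allows per request. And with caching enabled, the image-pruning threshold rises from 10 to 50, presumably so that pruning (which rewrites earlier turns) busts the cached prefix less often. A standalone sketch of the breakpoint strategy on a hypothetical history:

    # Hypothetical three-turn history; only the shape matters here.
    messages = [
        {"role": "user", "content": [{"type": "text", "text": "turn 1"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "ok"}]},
        {"role": "user", "content": [{"type": "text", "text": "turn 2"}]},
    ]

    breakpoints_remaining = 3
    for message in reversed(messages):
        if message["role"] == "user" and isinstance(message["content"], list):
            if breakpoints_remaining:
                breakpoints_remaining -= 1
                # Mark this user turn as a cache breakpoint.
                message["content"][-1]["cache_control"] = {"type": "ephemeral"}
            else:
                # Clear the stale marker from the turn that left the window;
                # only one turn is added per loop, so one removal suffices.
                message["content"][-1].pop("cache_control", None)
                break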
31 changes: 15 additions & 16 deletions computer-use-demo/computer_use_demo/streamlit.py
@@ -16,11 +16,10 @@
import httpx
import streamlit as st
from anthropic import RateLimitError
from anthropic.types import (
TextBlock,
from anthropic.types.beta import (
BetaContentBlockParam,
BetaTextBlockParam,
)
from anthropic.types.beta import BetaTextBlock, BetaToolUseBlock
from anthropic.types.tool_use_block import ToolUseBlock
from streamlit.delta_generator import DeltaGenerator

from computer_use_demo.loop import (
@@ -184,7 +183,7 @@ def _reset_api_provider():
else:
_render_message(
message["role"],
cast(BetaTextBlock | BetaToolUseBlock, block),
cast(BetaContentBlockParam | ToolResult, block),
)

# render past http exchanges
@@ -196,7 +195,7 @@ def _reset_api_provider():
st.session_state.messages.append(
{
"role": Sender.USER,
"content": [TextBlock(type="text", text=new_message)],
"content": [BetaTextBlockParam(type="text", text=new_message)],
}
)
_render_message(Sender.USER, new_message)
@@ -345,15 +344,11 @@ def _render_error(error: Exception):

def _render_message(
sender: Sender,
message: str | BetaTextBlock | BetaToolUseBlock | ToolResult,
message: str | BetaContentBlockParam | ToolResult,
):
"""Convert input from the user or output from the agent to a streamlit message."""
# streamlit's hot-reloading breaks isinstance checks, so we need to check for class names
is_tool_result = not isinstance(message, str) and (
isinstance(message, ToolResult)
or message.__class__.__name__ == "ToolResult"
or message.__class__.__name__ == "CLIResult"
)
is_tool_result = not isinstance(message, str | dict)
if not message or (
is_tool_result
and st.session_state.hide_images
Expand All @@ -373,10 +368,14 @@ def _render_message(
st.error(message.error)
if message.base64_image and not st.session_state.hide_images:
st.image(base64.b64decode(message.base64_image))
elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
st.write(message.text)
elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
st.code(f"Tool Use: {message.name}\nInput: {message.input}")
elif isinstance(message, dict):
if message["type"] == "text":
st.write(message["text"])
elif message["type"] == "tool_use":
st.code(f'Tool Use: {message["name"]}\nInput: {message["input"]}')
else:
# only expected return types are text and tool_use
raise Exception(f'Unexpected response type {message["type"]}')
else:
st.markdown(message)

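The streamlit.py changes above follow from the same shift: once blocks are plain dicts, rendering dispatches on `block["type"]` instead of isinstance checks against SDK classes, which the old code had to special-case because hot-reloading breaks isinstance. A minimal sketch of that dispatch, with illustrative blocks:

    def describe(block: dict) -> str:
        # Dispatch on the "type" key rather than on the block's class.
        if block["type"] == "text":
            return block["text"]
        if block["type"] == "tool_use":
            return f"Tool Use: {block['name']}\nInput: {block['input']}"
        # only text and tool_use blocks are expected from the model
        raise ValueError(f"Unexpected response type {block['type']}")

    print(describe({"type": "text", "text": "Done!"}))
    print(describe(
        {"type": "tool_use", "name": "computer", "input": {"action": "screenshot"}}
    ))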
6 changes: 4 additions & 2 deletions computer-use-demo/tests/loop_test.py
@@ -1,7 +1,7 @@
from unittest import mock

from anthropic.types import TextBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage, BetaMessageParam
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlockParam

from computer_use_demo.loop import APIProvider, sampling_loop

@@ -58,7 +58,9 @@ async def test_loop():
tool_collection.run.assert_called_once_with(
name="computer", tool_input={"action": "test"}
)
output_callback.assert_called_with(TextBlock(text="Done!", type="text"))
output_callback.assert_called_with(
BetaTextBlockParam(text="Done!", type="text")
)
assert output_callback.call_count == 3
assert tool_output_callback.call_count == 1
assert api_response_callback.call_count == 2
8 changes: 6 additions & 2 deletions computer-use-demo/tests/streamlit_test.py
@@ -1,9 +1,10 @@
from unittest import mock

import pytest
from anthropic.types import TextBlockParam
from streamlit.testing.v1 import AppTest

from computer_use_demo.streamlit import Sender, TextBlock
from computer_use_demo.streamlit import Sender


@pytest.fixture
@@ -18,6 +19,9 @@ def test_streamlit(streamlit_app: AppTest):
streamlit_app.chat_input[0].set_value("Hello").run()
assert patch.called
assert patch.call_args.kwargs["messages"] == [
{"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]}
{
"role": Sender.USER,
"content": [TextBlockParam(text="Hello", type="text")],
}
]
assert not streamlit_app.exception
