Skip to content

Commit

Permalink
add function calling to aws model
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobweiss2305 committed Sep 19, 2024
1 parent 814903f commit dd064cd
Show file tree
Hide file tree
Showing 8 changed files with 549 additions and 0 deletions.
Empty file.
9 changes: 9 additions & 0 deletions cookbook/providers/bedrock/anthropic/basic_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from phi.agent import Agent
from phi.model.aws.anthropic import Claude

agent = Agent(
model=Claude(model="anthropic.claude-3-5-sonnet-20240620-v1:0"),
description="You help people with their health and fitness goals.",
)

agent.print_response("Share a healthy breakfast recipe", stream=True)
11 changes: 11 additions & 0 deletions cookbook/providers/bedrock/anthropic/basic_stream_off.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from phi.agent import Agent, RunResponse
from phi.model.aws.anthropic import Claude

agent = Agent(
model=Claude(model="anthropic.claude-3-5-sonnet-20240620-v1:0"),
description="You help people with their health and fitness goals.",
)

run: RunResponse = agent.run("Share a quick healthy breakfast recipe.")

print(run.content)
21 changes: 21 additions & 0 deletions cookbook/providers/bedrock/anthropic/duckduckgo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from phi.agent import Agent, RunResponse
from phi.model.aws.anthropic import Claude
from phi.tools.duckduckgo import DuckDuckGo

agent = Agent(
model=Claude(model="anthropic.claude-3-5-sonnet-20240620-v1:0"),
tools=[DuckDuckGo()],
instructions=["use your tools to search internet"],
show_tool_calls=True,
debug_mode=True,
)

run: RunResponse = agent.run(
"you need to preform multiple searches. first list top 5 college football teams. then search for the mascot of the team with the most wins",
)

print(f"""
Agent Response: {run.content}
""")
5 changes: 5 additions & 0 deletions cookbook/providers/openai/basic_stream_off.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from phi.agent import Agent, RunResponse
from phi.model.openai import OpenAIChat

import time

agent = Agent(
model=OpenAIChat(model="gpt-4o"),
)

start_time = time.time()

run: RunResponse = agent.run("Share a healthy breakfast recipe") # type: ignore

print(run.content)

Empty file added phi/model/aws/__init__.py
Empty file.
127 changes: 127 additions & 0 deletions phi/model/aws/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Optional, List, Dict, Any

from phi.model.message import Message
from phi.model.aws.bedrock import AwsBedrock


class Claude(AwsBedrock):
name: str = "AwsBedrockAnthropicClaude"
model: str = "anthropic.claude-3-sonnet-20240229-v1:0"
# -*- Request parameters
max_tokens: int = 8192
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
anthropic_version: str = "bedrock-2023-05-31"
request_params: Optional[Dict[str, Any]] = None
# -*- Client parameters
client_params: Optional[Dict[str, Any]] = None

@property
def api_kwargs(self) -> Dict[str, Any]:
_request_params: Dict[str, Any] = {}
if self.anthropic_version:
_request_params["anthropic_version"] = self.anthropic_version
if self.max_tokens:
_request_params["max_tokens"] = self.max_tokens
if self.temperature:
_request_params["temperature"] = self.temperature
if self.stop_sequences:
_request_params["stop_sequences"] = self.stop_sequences
if self.tools is not None:
if _request_params.get("stop_sequences") is None:
_request_params["stop_sequences"] = ["</function_calls>"]
elif "</function_calls>" not in _request_params["stop_sequences"]:
_request_params["stop_sequences"].append("</function_calls>")
if self.top_p:
_request_params["top_p"] = self.top_p
if self.top_k:
_request_params["top_k"] = self.top_k
if self.request_params:
_request_params.update(self.request_params)
return _request_params

def get_tools(self):
"""
Refactors the tools in a format accepted by the Anthropic API.
"""
if not self.functions:
return None

tools: List = []
for f_name, function in self.functions.items():
required_params = [
param_name
for param_name, param_info in function.parameters.get("properties", {}).items()
if "null"
not in (
param_info.get("type") if isinstance(param_info.get("type"), list) else [param_info.get("type")]
)
]
tools.append(
{
"toolSpec": {
"name": f_name,
"description": function.description or "",
"inputSchema": {
"json": {
"type": function.parameters.get("type") or "object",
"properties": {
param_name: {
"type": param_info.get("type") or "",
"description": param_info.get("description") or "",
}
for param_name, param_info in function.parameters.get("properties", {}).items()
},
"required": required_params,
}
},
}
}
)
return tools

def get_request_body(self, messages: List[Message]) -> Dict[str, Any]:
system_prompt = None
messages_for_api = []
for m in messages:
if m.role == "system":
system_prompt = m.content
else:
messages_for_api.append({"role": m.role, "content": [{"text": m.content}]})

# -*- Build request body
request_body = {
"messages": messages_for_api,
**self.api_kwargs,
}
if self.tools:
request_body["tools"] = self.get_tools()

if system_prompt:
request_body["system"] = system_prompt
return request_body

def parse_response_message(self, response: Dict[str, Any]) -> Message:
output = response.get("output", {})
message = output.get("message", {})

role = message.get("role", "assistant")
content = message.get("content", [])

if isinstance(content, list):
text_content = "\n".join([item.get("text", "") for item in content if isinstance(item, dict)])
elif isinstance(content, dict):
text_content = content.get("text", "")
elif isinstance(content, str):
text_content = content
else:
text_content = ""

return Message(role=role, content=text_content)

def parse_response_delta(self, response: Dict[str, Any]) -> Optional[str]:
if "delta" in response:
return response.get("delta", {}).get("text")
return response.get("completion")
Loading

0 comments on commit dd064cd

Please sign in to comment.