Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made the message role of ReAct observation configurable #17521

Merged
merged 7 commits into from
Jan 23, 2025
52 changes: 34 additions & 18 deletions docs/docs/examples/agent/react_agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,7 @@
"execution_count": null,
"id": "e8ac1778-0585-43c9-9dad-014d13d7460d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /Users/jerryliu/Programmi\n",
"[nltk_data] ng/gpt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n",
"[nltk_data] Downloading package punkt to /Users/jerryliu/Programming/g\n",
"[nltk_data] pt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
}
],
"outputs": [],
"source": [
"from llama_index.core.agent import ReActAgent\n",
"from llama_index.llms.openai import OpenAI\n",
Expand Down Expand Up @@ -468,13 +453,44 @@
"response = agent.chat(\"What is 5+3+2\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"id": "76190511-692c-4642-9b86-adac88c98550",
"metadata": {},
"source": [
"### Customizing the Message Role of Observation\n",
"\n",
"If the LLM you use supports function/tool calling, you may set the message role of observations to `MessageRole.TOOL`. \n",
"Doing this will prevent the tool outputs from being misinterpreted as new user messages for some models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6d5e8c1-c40e-4a96-8d2e-84127f066265",
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.agent import ReActChatFormatter\n",
"from llama_index.core.llms import MessageRole\n",
"\n",
"agent = ReActAgent.from_tools(\n",
" [multiply_tool, add_tool],\n",
" llm=llm,\n",
" react_chat_formatter=ReActChatFormatter.from_defaults(\n",
" observation_role=MessageRole.TOOL\n",
" ),\n",
" verbose=True,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llama_index_v2",
"display_name": "LlamaIndex Development",
"language": "python",
"name": "llama_index_v2"
"name": "llama-index-dev"
},
"language_info": {
"codemirror_mode": {
Expand Down
16 changes: 13 additions & 3 deletions llama-index-core/llama_index/core/agent/react/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ObservationReasoningStep,
)
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.bridge.pydantic import BaseModel, ConfigDict
from llama_index.core.bridge.pydantic import BaseModel, ConfigDict, Field
from llama_index.core.tools import BaseTool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,6 +53,14 @@ class ReActChatFormatter(BaseAgentChatFormatter):

system_header: str = REACT_CHAT_SYSTEM_HEADER # default
context: str = "" # not needed w/ default
observation_role: MessageRole = Field(
default=MessageRole.USER,
description=(
"Message role of tool outputs. If the LLM you use supports function/tool "
"calling, you may set it to `MessageRole.TOOL` to avoid the tool outputs "
"being misinterpreted as new user messages."
),
)

def format(
self,
Expand All @@ -73,13 +81,13 @@ def format(
fmt_sys_header = self.system_header.format(**format_args)

# format reasoning history as alternating assistant and observation
# messages, where the assistant messages are thoughts and actions and the
# observation messages carry tool outputs using the configured
# `observation_role` (defaults to user)
reasoning_history = []
for reasoning_step in current_reasoning:
if isinstance(reasoning_step, ObservationReasoningStep):
message = ChatMessage(
role=MessageRole.USER,
role=self.observation_role,
content=reasoning_step.get_content(),
)
else:
Expand All @@ -100,6 +108,7 @@ def from_defaults(
cls,
system_header: Optional[str] = None,
context: Optional[str] = None,
observation_role: MessageRole = MessageRole.USER,
) -> "ReActChatFormatter":
"""Create ReActChatFormatter from defaults."""
if not system_header:
Expand All @@ -112,6 +121,7 @@ def from_defaults(
return ReActChatFormatter(
system_header=system_header,
context=context or "",
observation_role=observation_role,
)

@classmethod
Expand Down
Loading