Skip to content

Commit

Permalink
Made the message role of ReAct observation configurable (#17521)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesljlster authored Jan 23, 2025
1 parent f29cd17 commit 9cea48b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
52 changes: 34 additions & 18 deletions docs/docs/examples/agent/react_agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,7 @@
"execution_count": null,
"id": "e8ac1778-0585-43c9-9dad-014d13d7460d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package stopwords to /Users/jerryliu/Programmi\n",
"[nltk_data] ng/gpt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n",
"[nltk_data] Downloading package punkt to /Users/jerryliu/Programming/g\n",
"[nltk_data] pt_index/.venv/lib/python3.10/site-\n",
"[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
}
],
"outputs": [],
"source": [
"from llama_index.core.agent import ReActAgent\n",
"from llama_index.llms.openai import OpenAI\n",
Expand Down Expand Up @@ -468,13 +453,44 @@
"response = agent.chat(\"What is 5+3+2\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"id": "76190511-692c-4642-9b86-adac88c98550",
"metadata": {},
"source": [
"### Customizing the Message Role of Observation\n",
"\n",
"If the LLM you use supports function/tool calling, you may set the message role of observations to `MessageRole.TOOL`. \n",
"Doing this will prevent the tool outputs from being misinterpreted as new user messages for some models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6d5e8c1-c40e-4a96-8d2e-84127f066265",
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.agent import ReActChatFormatter\n",
"from llama_index.core.llms import MessageRole\n",
"\n",
"agent = ReActAgent.from_tools(\n",
" [multiply_tool, add_tool],\n",
" llm=llm,\n",
" react_chat_formatter=ReActChatFormatter.from_defaults(\n",
" observation_role=MessageRole.TOOL\n",
" ),\n",
" verbose=True,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llama_index_v2",
"display_name": "LlamaIndex Development",
"language": "python",
"name": "llama_index_v2"
"name": "llama-index-dev"
},
"language_info": {
"codemirror_mode": {
Expand Down
16 changes: 13 additions & 3 deletions llama-index-core/llama_index/core/agent/react/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ObservationReasoningStep,
)
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.bridge.pydantic import BaseModel, ConfigDict
from llama_index.core.bridge.pydantic import BaseModel, ConfigDict, Field
from llama_index.core.tools import BaseTool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,6 +53,14 @@ class ReActChatFormatter(BaseAgentChatFormatter):

system_header: str = REACT_CHAT_SYSTEM_HEADER # default
context: str = "" # not needed w/ default
observation_role: MessageRole = Field(
default=MessageRole.USER,
description=(
"Message role of tool outputs. If the LLM you use supports function/tool "
"calling, you may set it to `MessageRole.TOOL` to avoid the tool outputs "
"being misinterpreted as new user messages."
),
)

def format(
self,
Expand All @@ -73,13 +81,13 @@ def format(
fmt_sys_header = self.system_header.format(**format_args)

# format reasoning history as alternating user and assistant messages
# where the assistant messages are thoughts and actions and the user
# where the assistant messages are thoughts and actions and the tool
# messages are observations
reasoning_history = []
for reasoning_step in current_reasoning:
if isinstance(reasoning_step, ObservationReasoningStep):
message = ChatMessage(
role=MessageRole.USER,
role=self.observation_role,
content=reasoning_step.get_content(),
)
else:
Expand All @@ -100,6 +108,7 @@ def from_defaults(
cls,
system_header: Optional[str] = None,
context: Optional[str] = None,
observation_role: MessageRole = MessageRole.USER,
) -> "ReActChatFormatter":
"""Create ReActChatFormatter from defaults."""
if not system_header:
Expand All @@ -112,6 +121,7 @@ def from_defaults(
return ReActChatFormatter(
system_header=system_header,
context=context or "",
observation_role=observation_role,
)

@classmethod
Expand Down

0 comments on commit 9cea48b

Please sign in to comment.