diff --git a/docs/docs/examples/agent/react_agent.ipynb b/docs/docs/examples/agent/react_agent.ipynb index 337db56d8aff2..0eba247bc912e 100644 --- a/docs/docs/examples/agent/react_agent.ipynb +++ b/docs/docs/examples/agent/react_agent.ipynb @@ -53,22 +53,7 @@ "execution_count": null, "id": "e8ac1778-0585-43c9-9dad-014d13d7460d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[nltk_data] Downloading package stopwords to /Users/jerryliu/Programmi\n", - "[nltk_data] ng/gpt_index/.venv/lib/python3.10/site-\n", - "[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n", - "[nltk_data] Unzipping corpora/stopwords.zip.\n", - "[nltk_data] Downloading package punkt to /Users/jerryliu/Programming/g\n", - "[nltk_data] pt_index/.venv/lib/python3.10/site-\n", - "[nltk_data] packages/llama_index/legacy/_static/nltk_cache...\n", - "[nltk_data] Unzipping tokenizers/punkt.zip.\n" - ] - } - ], + "outputs": [], "source": [ "from llama_index.core.agent import ReActAgent\n", "from llama_index.llms.openai import OpenAI\n", @@ -468,13 +453,44 @@ "response = agent.chat(\"What is 5+3+2\")\n", "print(response)" ] + }, + { + "cell_type": "markdown", + "id": "76190511-692c-4642-9b86-adac88c98550", + "metadata": {}, + "source": [ + "### Customizing the Message Role of Observation\n", + "\n", + "If the LLM you use supports function/tool calling, you may set the message role of observations to `MessageRole.TOOL`. \n", + "Doing this will prevent the tool outputs from being misinterpreted as new user messages for some models." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6d5e8c1-c40e-4a96-8d2e-84127f066265", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.agent import ReActChatFormatter\n", + "from llama_index.core.llms import MessageRole\n", + "\n", + "agent = ReActAgent.from_tools(\n", + " [multiply_tool, add_tool],\n", + " llm=llm,\n", + " react_chat_formatter=ReActChatFormatter.from_defaults(\n", + " observation_role=MessageRole.TOOL\n", + " ),\n", + " verbose=True,\n", + ")" + ] } ], "metadata": { "kernelspec": { - "display_name": "llama_index_v2", + "display_name": "LlamaIndex Development", "language": "python", - "name": "llama_index_v2" + "name": "llama-index-dev" }, "language_info": { "codemirror_mode": { diff --git a/llama-index-core/llama_index/core/agent/react/formatter.py b/llama-index-core/llama_index/core/agent/react/formatter.py index 1de4fa744e05b..e1985932a6a79 100644 --- a/llama-index-core/llama_index/core/agent/react/formatter.py +++ b/llama-index-core/llama_index/core/agent/react/formatter.py @@ -13,7 +13,7 @@ ObservationReasoningStep, ) from llama_index.core.base.llms.types import ChatMessage, MessageRole -from llama_index.core.bridge.pydantic import BaseModel, ConfigDict +from llama_index.core.bridge.pydantic import BaseModel, ConfigDict, Field from llama_index.core.tools import BaseTool logger = logging.getLogger(__name__) @@ -53,6 +53,14 @@ class ReActChatFormatter(BaseAgentChatFormatter): system_header: str = REACT_CHAT_SYSTEM_HEADER # default context: str = "" # not needed w/ default + observation_role: MessageRole = Field( + default=MessageRole.USER, + description=( + "Message role of tool outputs. If the LLM you use supports function/tool " + "calling, you may set it to `MessageRole.TOOL` to avoid the tool outputs " + "being misinterpreted as new user messages." 
+ ), + ) def format( self, @@ -73,13 +81,13 @@ def format( fmt_sys_header = self.system_header.format(**format_args) # format reasoning history as alternating user and assistant messages - # where the assistant messages are thoughts and actions and the user + # where the assistant messages are thoughts and actions and the user/tool # messages are observations reasoning_history = [] for reasoning_step in current_reasoning: if isinstance(reasoning_step, ObservationReasoningStep): message = ChatMessage( - role=MessageRole.USER, + role=self.observation_role, content=reasoning_step.get_content(), ) else: @@ -100,6 +108,7 @@ def from_defaults( cls, system_header: Optional[str] = None, context: Optional[str] = None, + observation_role: MessageRole = MessageRole.USER, ) -> "ReActChatFormatter": """Create ReActChatFormatter from defaults.""" if not system_header: @@ -112,6 +121,7 @@ def from_defaults( return ReActChatFormatter( system_header=system_header, context=context or "", + observation_role=observation_role, ) @classmethod