From aa44632860dad9bd4f1f5b5ce81f74cd05d72c90 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 7 Jun 2024 22:22:12 -0500 Subject: [PATCH] limit discord channels, ping discord players on their turn, improve action JSON errors --- prompts/llama-base.yml | 8 ++++++- taleweave/bot/discord.py | 14 +++++++++++ taleweave/simulate.py | 42 +++++++++++++++++++++++---------- taleweave/utils/conversation.py | 12 ++++++---- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/prompts/llama-base.yml b/prompts/llama-base.yml index 7e7e09f..ebe7b6d 100644 --- a/prompts/llama-base.yml +++ b/prompts/llama-base.yml @@ -353,6 +353,9 @@ prompts: {{notes_prompt}} {{events_prompt}} What will you do next? Reply with a JSON function call, calling one of the actions. You can only perform one action per turn. What is your next action? + world_simulate_character_action_error_json: | + Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being + careful to reply with a valid function call in JSON format. The available actions are: {{actions}}. world_simulate_character_planning: | You are about to start your turn. Plan your next action carefully. Take notes and schedule events to help keep track of your goals. @@ -375,4 +378,7 @@ prompts: world_simulate_character_planning_events_none: | You have no upcoming events. world_simulate_character_planning_events_item: | - {{event.name}} in {{turns}} turns \ No newline at end of file + {{event.name}} in {{turns}} turns + world_simulate_character_planning_error_json: | + Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being + careful to reply with a valid function call in JSON format. The available actions are: {{actions}}. \ No newline at end of file diff --git a/taleweave/bot/discord.py b/taleweave/bot/discord.py index d5047ed..0b31260 100644 --- a/taleweave/bot/discord.py +++ b/taleweave/bot/discord.py @@ -45,6 +45,7 @@ active_tasks = set() event_messages: Dict[int, str | GameEvent] = {} event_queue: Queue[GameEvent] = Queue() +player_mentions: Dict[str, str] = {} def remove_tags(text: str) -> str: @@ -80,6 +81,12 @@ async def on_message(self, message): if message.author == self.user: return + # make sure the message was in a valid channel + active_channels = get_active_channels() + if message.channel not in active_channels: + return + + # get message contents config = get_game_config() author = message.author channel = message.channel @@ -143,6 +150,7 @@ def prompt_player(event: PromptEvent): ) set_character_agent(character_name, character, player) set_player(user_name, player) + player_mentions[user_name] = author.mention logger.info(f"{user_name} has joined the game as {character.name}!") join_event = PlayerEvent("join", character_name, user_name) @@ -179,6 +187,9 @@ def prompt_player(event: PromptEvent): if content.startswith(config.bot.discord.command_prefix + "leave"): remove_player(user_name) + if user_name in player_mentions: + del player_mentions[user_name] + # revert to LLM agent character, _ = get_character_agent_for_name(player.name) if character and player.fallback_agent: @@ -443,6 +454,9 @@ def embed_from_prompt(event: PromptEvent): if user: # TODO: use Discord user.mention to ping the user + if user in player_mentions: + user = player_mentions[user] + prompt_embed.add_field( name="Player", value=user, diff --git a/taleweave/simulate.py b/taleweave/simulate.py index 3e6f8a9..83cd758 100644 --- a/taleweave/simulate.py +++ b/taleweave/simulate.py @@ -7,10 +7,10 @@ from packit.agent import Agent from packit.conditions import condition_or, condition_threshold +from packit.errors import ToolError from packit.loops import loop_retry from packit.results import function_result from packit.toolbox import Toolbox -from packit.utils import could_be_json from taleweave.actions.base import ( action_ask, @@ -80,8 +80,10 @@ def world_result_parser(value, agent, **kwargs): def prompt_character_action( - room, character, agent, action_names, action_toolbox, current_turn + room, character, agent, action_toolbox, current_turn ) -> str: + action_names = action_toolbox.list_tools() + # collect data for the prompt notes_prompt, events_prompt = get_notes_events(character, current_turn) @@ -114,18 +116,21 @@ def result_parser(value, **kwargs): except Exception: pass - if could_be_json(value): - # TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice + try: + result = world_result_parser(value, **kwargs) + + # TODO: try to avoid parsing the JSON twice event = ActionEvent.from_json(value, room, character) - else: + broadcast(event) + + return result + except ToolError: raise ActionError( - "Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format." + format_prompt( + "world_simulate_character_action_error_json", actions=action_names + ) ) - broadcast(event) - - return world_result_parser(value, **kwargs) - # prompt and act logger.info("starting turn for character: %s", character.name) result = loop_retry( @@ -201,9 +206,21 @@ def prompt_character_planning( event_count = len(character.planner.calendar.events) note_count = len(character.planner.notes) + def result_parser(value, **kwargs): + try: + return function_result(value, **kwargs) + except ToolError: + raise ActionError( + format_prompt( + "world_simulate_character_planning_error_json", + actions=planner_toolbox.list_tools(), + ) + ) + logger.info("starting planning for character: %s", character.name) _, condition_end, result_parser = make_keyword_condition( - get_prompt("world_simulate_character_planning_done") + get_prompt("world_simulate_character_planning_done"), + result_parser=result_parser, ) stop_condition = condition_or( condition_end, partial(condition_threshold, max=max_steps) @@ -256,7 +273,6 @@ def simulate_world( *actions, ] ) - action_names = action_tools.list_tools() # build a toolbox for the planners planner_toolbox = Toolbox( @@ -309,7 +325,7 @@ def simulate_world( try: result = prompt_character_action( - room, character, agent, action_names, action_tools, current_turn + room, character, agent, action_tools, current_turn ) result_event = ResultEvent( result=result, room=room, character=character diff --git a/taleweave/utils/conversation.py b/taleweave/utils/conversation.py index c072471..0628661 100644 --- a/taleweave/utils/conversation.py +++ b/taleweave/utils/conversation.py @@ -18,10 +18,14 @@ logger = getLogger(__name__) -def make_keyword_condition(end_message: str, keywords=["end", "stop"]): +def make_keyword_condition( + end_message: str, + keywords=["end", "stop"], + result_parser=multi_function_or_str_result, +): set_end, condition_end = make_flag_condition() - def result_parser(value, **kwargs): + def inner_parser(value, **kwargs): normalized_value = normalize_name(value) if normalized_value in keywords: logger.debug(f"found keyword, setting stop condition: {normalized_value}") @@ -51,9 +55,9 @@ def result_parser(value, **kwargs): set_end() return end_message - return multi_function_or_str_result(value, **kwargs) + return result_parser(value, **kwargs) - return set_end, condition_end, result_parser + return set_end, condition_end, inner_parser def summarize_room(room: Room, player: Character) -> str: