Skip to content

Commit

Permalink
limit discord channels, ping discord players on their turn, improve a…
Browse files Browse the repository at this point in the history
…ction JSON errors
  • Loading branch information
ssube committed Jun 8, 2024
1 parent e1c72c3 commit aa44632
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 18 deletions.
8 changes: 7 additions & 1 deletion prompts/llama-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ prompts:
{{notes_prompt}} {{events_prompt}}
What will you do next? Reply with a JSON function call, calling one of the actions.
You can only perform one action per turn. What is your next action?
world_simulate_character_action_error_json: |
Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being
careful to reply with a valid function call in JSON format. The available actions are: {{actions}}.
world_simulate_character_planning: |
You are about to start your turn. Plan your next action carefully. Take notes and schedule events to help keep track of your goals.
Expand All @@ -375,4 +378,7 @@ prompts:
world_simulate_character_planning_events_none: |
You have no upcoming events.
world_simulate_character_planning_events_item: |
{{event.name}} in {{turns}} turns
{{event.name}} in {{turns}} turns
world_simulate_character_planning_error_json: |
Your last reply was not a valid action or the action you tried to use does not exist. Please try again, being
careful to reply with a valid function call in JSON format. The available actions are: {{actions}}.
14 changes: 14 additions & 0 deletions taleweave/bot/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
active_tasks = set()
event_messages: Dict[int, str | GameEvent] = {}
event_queue: Queue[GameEvent] = Queue()
player_mentions: Dict[str, str] = {}


def remove_tags(text: str) -> str:
Expand Down Expand Up @@ -80,6 +81,12 @@ async def on_message(self, message):
if message.author == self.user:
return

# make sure the message was in a valid channel
active_channels = get_active_channels()
if message.channel not in active_channels:
return

# get message contents
config = get_game_config()
author = message.author
channel = message.channel
Expand Down Expand Up @@ -143,6 +150,7 @@ def prompt_player(event: PromptEvent):
)
set_character_agent(character_name, character, player)
set_player(user_name, player)
player_mentions[user_name] = author.mention

logger.info(f"{user_name} has joined the game as {character.name}!")
join_event = PlayerEvent("join", character_name, user_name)
Expand Down Expand Up @@ -179,6 +187,9 @@ def prompt_player(event: PromptEvent):
if content.startswith(config.bot.discord.command_prefix + "leave"):
remove_player(user_name)

if user_name in player_mentions:
del player_mentions[user_name]

# revert to LLM agent
character, _ = get_character_agent_for_name(player.name)
if character and player.fallback_agent:
Expand Down Expand Up @@ -443,6 +454,9 @@ def embed_from_prompt(event: PromptEvent):

if user:
# TODO: use Discord user.mention to ping the user
if user in player_mentions:
user = player_mentions[user]

prompt_embed.add_field(
name="Player",
value=user,
Expand Down
42 changes: 29 additions & 13 deletions taleweave/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from packit.agent import Agent
from packit.conditions import condition_or, condition_threshold
from packit.errors import ToolError
from packit.loops import loop_retry
from packit.results import function_result
from packit.toolbox import Toolbox
from packit.utils import could_be_json

from taleweave.actions.base import (
action_ask,
Expand Down Expand Up @@ -80,8 +80,10 @@ def world_result_parser(value, agent, **kwargs):


def prompt_character_action(
room, character, agent, action_names, action_toolbox, current_turn
room, character, agent, action_toolbox, current_turn
) -> str:
action_names = action_toolbox.list_tools()

# collect data for the prompt
notes_prompt, events_prompt = get_notes_events(character, current_turn)

Expand Down Expand Up @@ -114,18 +116,21 @@ def result_parser(value, **kwargs):
except Exception:
pass

if could_be_json(value):
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
try:
result = world_result_parser(value, **kwargs)

# TODO: try to avoid parsing the JSON twice
event = ActionEvent.from_json(value, room, character)
else:
broadcast(event)

return result
except ToolError:
raise ActionError(
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
format_prompt(
"world_simulate_character_action_error_json", actions=action_names
)
)

broadcast(event)

return world_result_parser(value, **kwargs)

# prompt and act
logger.info("starting turn for character: %s", character.name)
result = loop_retry(
Expand Down Expand Up @@ -201,9 +206,21 @@ def prompt_character_planning(
event_count = len(character.planner.calendar.events)
note_count = len(character.planner.notes)

def result_parser(value, **kwargs):
try:
return function_result(value, **kwargs)
except ToolError:
raise ActionError(
format_prompt(
"world_simulate_character_planning_error_json",
actions=planner_toolbox.list_tools(),
)
)

logger.info("starting planning for character: %s", character.name)
_, condition_end, result_parser = make_keyword_condition(
get_prompt("world_simulate_character_planning_done")
get_prompt("world_simulate_character_planning_done"),
result_parser=result_parser,
)
stop_condition = condition_or(
condition_end, partial(condition_threshold, max=max_steps)
Expand Down Expand Up @@ -256,7 +273,6 @@ def simulate_world(
*actions,
]
)
action_names = action_tools.list_tools()

# build a toolbox for the planners
planner_toolbox = Toolbox(
Expand Down Expand Up @@ -309,7 +325,7 @@ def simulate_world(

try:
result = prompt_character_action(
room, character, agent, action_names, action_tools, current_turn
room, character, agent, action_tools, current_turn
)
result_event = ResultEvent(
result=result, room=room, character=character
Expand Down
12 changes: 8 additions & 4 deletions taleweave/utils/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
logger = getLogger(__name__)


def make_keyword_condition(end_message: str, keywords=["end", "stop"]):
def make_keyword_condition(
end_message: str,
keywords=["end", "stop"],
result_parser=multi_function_or_str_result,
):
set_end, condition_end = make_flag_condition()

def result_parser(value, **kwargs):
def inner_parser(value, **kwargs):
normalized_value = normalize_name(value)
if normalized_value in keywords:
logger.debug(f"found keyword, setting stop condition: {normalized_value}")
Expand Down Expand Up @@ -51,9 +55,9 @@ def result_parser(value, **kwargs):
set_end()
return end_message

return multi_function_or_str_result(value, **kwargs)
return result_parser(value, **kwargs)

return set_end, condition_end, result_parser
return set_end, condition_end, inner_parser


def summarize_room(room: Room, player: Character) -> str:
Expand Down

0 comments on commit aa44632

Please sign in to comment.