Skip to content

Commit f25dd57

Browse files
committed
make memory configurable, consistently truncate discord messages, fix action prompt
1 parent a970572 commit f25dd57

File tree

9 files changed

+64
-45
lines changed

9 files changed

+64
-45
lines changed

client/src/models.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ export interface StringParameter {
6969
export interface NumberParameter {
7070
type: 'number';
7171
default?: number;
72-
enum?: Array<string>;
72+
enum?: Array<number>;
7373
}
7474

7575
export type Parameter = BooleanParameter | NumberParameter | StringParameter;

client/src/prompt.tsx

+16-4
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ export function enumerateSignificantParameterValues(name: string, world: World)
129129
}
130130
}
131131

132-
export function convertSignificantParameter(name: string, parameter: Parameter, world: Maybe<World>): Parameter {
132+
export function convertSignificantParameter<T extends Parameter>(name: string, parameter: T, world: Maybe<World>): T {
133133
if (parameter.type === 'boolean') {
134134
return parameter;
135135
}
@@ -154,15 +154,27 @@ export function formatAction(action: string, parameters: Record<string, boolean
154154
return `~${action}:${Object.entries(parameters).map(([name, value]) => `${name}=${value}`).join(',')}`;
155155
}
156156

157+
export function getEnumOrDefault<T>(defaultValue: Maybe<T>, enumValues: Maybe<Array<T>>, evenMoreDefault: T): T {
158+
if (doesExist(defaultValue)) {
159+
return defaultValue;
160+
}
161+
162+
if (doesExist(enumValues)) {
163+
return enumValues[0];
164+
}
165+
166+
return evenMoreDefault;
167+
}
168+
157169
export function makeDefaultParameterValues(parameters: Record<string, Parameter>) {
158170
return Object.entries(parameters).reduce((acc, [name, parameter]) => {
159171
switch (parameter.type) {
160172
case 'boolean':
161-
return { ...acc, [name]: mustDefault(parameter.default, false) };
173+
return { ...acc, [name]: getEnumOrDefault(parameter.default, [], false) };
162174
case 'number':
163-
return { ...acc, [name]: mustDefault(parameter.default, 0) };
175+
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, 0) };
164176
case 'string':
165-
return { ...acc, [name]: mustDefault(parameter.default, '') };
177+
return { ...acc, [name]: getEnumOrDefault(parameter.default, parameter.enum, '') };
166178
default:
167179
return acc;
168180
}

taleweave/actions/base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
broadcast,
66
get_agent_for_character,
77
get_character_agent_for_name,
8+
get_game_config,
89
get_prompt,
910
world_context,
1011
)
@@ -22,8 +23,6 @@
2223

2324
logger = getLogger(__name__)
2425

25-
MAX_CONVERSATION_STEPS = 2
26-
2726

2827
def action_examine(target: str) -> str:
2928
"""
@@ -173,7 +172,8 @@ def action_ask(character: str, question: str) -> str:
173172
character: The name of the character to ask. You cannot ask yourself questions.
174173
question: The question to ask them.
175174
"""
176-
# capture references to the current character and room, because they will be overwritten
175+
config = get_game_config()
176+
177177
with action_context() as (action_room, action_character):
178178
# sanity checks
179179
question_character, question_agent = get_character_agent_for_name(character)
@@ -216,7 +216,7 @@ def action_ask(character: str, question: str) -> str:
216216
end_prompt,
217217
echo_function=action_tell.__name__,
218218
echo_parameter="message",
219-
max_length=MAX_CONVERSATION_STEPS,
219+
max_length=config.world.character.conversation_limit,
220220
)
221221

222222
if result:
@@ -233,7 +233,7 @@ def action_tell(character: str, message: str) -> str:
233233
character: The name of the character to tell. You cannot talk to yourself.
234234
message: The message to tell them.
235235
"""
236-
# capture references to the current character and room, because they will be overwritten
236+
config = get_game_config()
237237

238238
with action_context() as (action_room, action_character):
239239
# sanity checks
@@ -268,7 +268,7 @@ def action_tell(character: str, message: str) -> str:
268268
end_prompt,
269269
echo_function=action_tell.__name__,
270270
echo_parameter="message",
271-
max_length=MAX_CONVERSATION_STEPS,
271+
max_length=config.world.character.conversation_limit,
272272
)
273273

274274
if result:

taleweave/bot/discord.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,12 @@ async def broadcast_event(message: str | GameEvent):
323323
event_messages[event_message.id] = message
324324

325325

326+
def truncate(text: str, length: int = 1000) -> str:
327+
if len(text) > length:
328+
return text[:length] + "..."
329+
return text
330+
331+
326332
def embed_from_event(event: GameEvent) -> Embed | None:
327333
if isinstance(event, GenerateEvent):
328334
return embed_from_generate(event)
@@ -357,7 +363,7 @@ def embed_from_action(event: ActionEvent):
357363

358364
def embed_from_reply(event: ReplyEvent):
359365
reply_embed = Embed(title=event.room.name, description=event.speaker.name)
360-
reply_embed.add_field(name="Reply", value=event.text)
366+
reply_embed.add_field(name="Reply", value=truncate(event.text))
361367
return reply_embed
362368

363369

@@ -367,12 +373,8 @@ def embed_from_generate(event: GenerateEvent) -> Embed:
367373

368374

369375
def embed_from_result(event: ResultEvent):
370-
text = event.result
371-
if len(text) > 1000:
372-
text = text[:1000] + "..."
373-
374376
result_embed = Embed(title=event.room.name, description=event.character.name)
375-
result_embed.add_field(name="Result", value=text)
377+
result_embed.add_field(name="Result", value=truncate(event.result))
376378
return result_embed
377379

378380

@@ -384,14 +386,14 @@ def embed_from_player(event: PlayerEvent):
384386
title = format_prompt("discord_leave_title", event=event)
385387
description = format_prompt("discord_leave_result", event=event)
386388

387-
player_embed = Embed(title=title, description=description)
389+
player_embed = Embed(title=title, description=truncate(description))
388390
return player_embed
389391

390392

391393
def embed_from_prompt(event: PromptEvent):
392394
# TODO: ping the player
393395
prompt_embed = Embed(title=event.room.name, description=event.character.name)
394-
prompt_embed.add_field(name="Prompt", value=event.prompt)
396+
prompt_embed.add_field(name="Prompt", value=truncate(event.prompt))
395397
return prompt_embed
396398

397399

@@ -400,5 +402,5 @@ def embed_from_status(event: StatusEvent):
400402
title=event.room.name if event.room else "",
401403
description=event.character.name if event.character else "",
402404
)
403-
status_embed.add_field(name="Status", value=event.text)
405+
status_embed.add_field(name="Status", value=truncate(event.text))
404406
return status_embed

taleweave/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def snapshot_system(world: World, turn: int, data: None = None) -> None:
416416
set_dungeon_master(world_builder)
417417

418418
# start the sim
419-
logger.debug("simulating world: %s", world)
419+
logger.debug("simulating world: %s", world.name)
420420
simulate_world(
421421
world,
422422
turns=args.turns,

taleweave/simulate.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
set_current_world,
4545
set_game_systems,
4646
)
47+
from taleweave.errors import ActionError
4748
from taleweave.game_system import GameSystem
4849
from taleweave.models.entity import Character, Room, World
4950
from taleweave.models.event import ActionEvent, ResultEvent
@@ -117,12 +118,9 @@ def result_parser(value, **kwargs):
117118
# TODO: only emit valid actions that parse and run correctly, and try to avoid parsing the JSON twice
118119
event = ActionEvent.from_json(value, room, character)
119120
else:
120-
# TODO: this path should be removed and throw
121-
# logger.warning(
122-
# "invalid action, emitting as result event - this is a bug somewhere"
123-
# )
124-
# event = ResultEvent(value, room, character)
125-
raise ValueError("invalid non-JSON action")
121+
raise ActionError(
122+
"Your last reply was not valid JSON. Please try again and reply with a valid function call in JSON format."
123+
)
126124

127125
broadcast(event)
128126

@@ -216,14 +214,14 @@ def prompt_character_planning(
216214
while not stop_condition(current=i):
217215
result = loop_retry(
218216
agent,
219-
get_prompt("world_simulate_character_planning"),
220-
context={
221-
"event_count": event_count,
222-
"events_prompt": events_prompt,
223-
"note_count": note_count,
224-
"notes_prompt": notes_prompt,
225-
"room_summary": summarize_room(room, character),
226-
},
217+
format_prompt(
218+
"world_simulate_character_planning",
219+
event_count=event_count,
220+
events_prompt=events_prompt,
221+
note_count=note_count,
222+
notes_prompt=notes_prompt,
223+
room_summary=summarize_room(room, character),
224+
),
227225
result_parser=result_parser,
228226
stop_condition=stop_condition,
229227
toolbox=planner_toolbox,

taleweave/state.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from packit.agent import Agent, agent_easy_connect
88
from pydantic import RootModel
99

10-
from taleweave.context import get_all_character_agents, set_character_agent
10+
from taleweave.context import (
11+
get_all_character_agents,
12+
get_game_config,
13+
set_character_agent,
14+
)
1115
from taleweave.models.entity import World
1216
from taleweave.player import LocalPlayer
1317

14-
MEMORY_LIMIT = 25 # 10
15-
1618

1719
def create_agents(
1820
world: World,
@@ -69,6 +71,7 @@ def snapshot_world(world: World, turn: int):
6971
def restore_memory(
7072
data: Sequence[str | Dict[str, str]]
7173
) -> deque[str | AIMessage | HumanMessage | SystemMessage]:
74+
config = get_game_config()
7275
memories = []
7376

7477
for memory in data:
@@ -85,7 +88,7 @@ def restore_memory(
8588
elif memory_type == "ai":
8689
memories.append(AIMessage(content=memory_content))
8790

88-
return deque(memories, maxlen=MEMORY_LIMIT)
91+
return deque(memories, maxlen=config.world.character.memory_limit)
8992

9093

9194
def save_world(world, filename):

taleweave/systems/digest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from taleweave.game_system import FormatPerspective, GameSystem
66
from taleweave.models.entity import Character, Room, World, WorldEntity
77
from taleweave.models.event import ActionEvent, GameEvent
8+
from taleweave.utils.prompt import format_str
89
from taleweave.utils.search import find_containing_room
910

1011
logger = getLogger(__name__)
@@ -22,7 +23,7 @@ def create_turn_digest(
2223
if prompt_key in library.prompts:
2324
try:
2425
template = library.prompts[prompt_key]
25-
message = template.format(event=event)
26+
message = format_str(template, event=event)
2627
messages.append(message)
2728
except Exception:
2829
logger.exception("error formatting digest event: %s", event)

taleweave/utils/prompt.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33
from jinja2 import Environment
44

55
from taleweave.context import get_prompt_library
6+
from taleweave.utils.string import and_list, or_list
67
from taleweave.utils.world import describe_entity, name_entity
78

89
logger = getLogger(__name__)
910

11+
jinja_env = Environment()
12+
jinja_env.filters["describe"] = describe_entity
13+
jinja_env.filters["name"] = name_entity
14+
jinja_env.filters["and_list"] = and_list
15+
jinja_env.filters["or_list"] = or_list
16+
1017

1118
def format_prompt(prompt_key: str, **kwargs) -> str:
1219
try:
@@ -19,9 +26,5 @@ def format_prompt(prompt_key: str, **kwargs) -> str:
1926

2027

2128
def format_str(template_str: str, **kwargs) -> str:
22-
env = Environment()
23-
env.filters["describe"] = describe_entity
24-
env.filters["name"] = name_entity
25-
26-
template = env.from_string(template_str)
29+
template = jinja_env.from_string(template_str)
2730
return template.render(**kwargs)

0 commit comments

Comments
 (0)