Changes from all commits (19 commits)
a3a2d69
Fix game_segment/weighted_total_loss bugs and refine prompts, compute…
xiongjyu Nov 20, 2025
959a558
Fixed the accumulate_steps bug and added cprofile functionality.
xiongjyu Nov 20, 2025
ecedc5f
Refine the code and fix the bug in data collection.
xiongjyu Nov 22, 2025
2d53d22
Add REINFORCE-style losses and store old_logprob in the buffer.
xiongjyu Nov 23, 2025
c608600
Fix the get_llm_prior bug so that every action receives a logprob
xiongjyu Nov 24, 2025
15e39f6
fixed the history bug in the build_llm_prompt and logs in forward_learn
xiongjyu Nov 24, 2025
7c9acd9
rename advantage_tensor on rft
xiongjyu Nov 24, 2025
738f300
Fixed the action out-of-bounds bug and added a record for forward_col…
xiongjyu Nov 26, 2025
0a166f6
Fixed the misalignment between old_log_prob and log_prob, and correct…
xiongjyu Nov 27, 2025
4f3668e
add some logs for analysing
xiongjyu Nov 27, 2025
2985e60
Polish the code and standardize the format.
xiongjyu Nov 29, 2025
ff98006
Add KL divergence in rft and llm_prior_entropy in collect
xiongjyu Dec 2, 2025
7e43e45
polish config and format
xiongjyu Dec 3, 2025
d6555e5
delete unused files
xiongjyu Dec 3, 2025
b7d42ee
Decouple the training of world_model and LLM.
xiongjyu Dec 9, 2025
95e2347
add cache in the jericho
xiongjyu Dec 10, 2025
9682486
Separate sync and async entry points to simplify the program.
xiongjyu Dec 10, 2025
0a38197
Reference OpenRLHF’s implementation to update vLLM weights in real ti…
xiongjyu Dec 14, 2025
e361039
delete unused orz files
xiongjyu Dec 14, 2025
251 changes: 25 additions & 226 deletions lzero/mcts/buffer/game_buffer_priorzero.py
@@ -18,158 +18,6 @@
from typing import List, Any, Union, Tuple
from lzero.mcts.buffer.game_buffer_unizero import UniZeroGameBuffer


class PriorZeroGameBuffer(UniZeroGameBuffer):
"""
[PRIORZERO-MODIFIED]
Enhanced GameBuffer that provides game_segments for LLM policy training.

Modifications:
1. sample() returns game_segments as the 3rd element of train_data
2. Efficient implementation using existing game_segment_list from _make_batch
3. No additional memory overhead (returns references, not copies)
"""

def __init__(self, cfg):
"""Initialize PriorZero Game Buffer."""
super().__init__(cfg)

# [PRIORZERO-NEW] Cache for the last sampled game segments
# This avoids re-sampling when we need game segments
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None

def sample(
self,
batch_size: int,
policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
"""
[PRIORZERO-MODIFIED]
Sample data and prepare current_batch, target_batch, AND game_segments.

Returns:
train_data: [current_batch, target_batch, game_segments]
- current_batch: [obs, action, target_action, mask, indices, weights, make_time, timestep]
- target_batch: [rewards, values, policies]
- game_segments: List of GameSegment objects used in this batch

Note:
game_segments are returned for LLM training (SFT/RFT).
They contain:
- mcts_policy_segment: MCTS visit distributions (for SFT supervision)
- raw_obs_segment: Raw text observations (for LLM prompts)
- reward_segment: Environment rewards (for RFT)
- search_value_segment: MCTS search values (for analysis)
"""
policy._target_model.to(self._cfg.device)
policy._target_model.eval()

# ======================================================================
# [PRIORZERO-KEY] Sample data and extract game_segments
# ======================================================================
# obtain the current_batch and prepare target context
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
batch_size, self._cfg.reanalyze_ratio
)

# [PRIORZERO-NEW] Extract game_segments from the sampling process
# These were already created in _make_batch, we just need to save them
game_segments = self._last_sampled_game_segments

# Defensive check: ensure game_segments match batch_size
if game_segments is None or len(game_segments) != len(current_batch[4]): # current_batch[4] is batch_index_list
# Fallback: create empty list if something went wrong
import logging
logging.warning(
f"[PriorZeroBuffer] game_segments mismatch: "
f"expected {len(current_batch[4])}, got {len(game_segments) if game_segments else None}. "
f"Falling back to empty list (SFT/RFT will be skipped)."
)
game_segments = []

# ======================================================================
# Standard UniZero processing (unchanged)
# ======================================================================
# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
)

# target policy
batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
) # current_batch[1] is batch_action
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
)

# fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies
if 0 < self._cfg.reanalyze_ratio < 1:
batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
elif self._cfg.reanalyze_ratio == 1:
batch_target_policies = batch_target_policies_re
elif self._cfg.reanalyze_ratio == 0:
batch_target_policies = batch_target_policies_non_re

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

# ======================================================================
# [PRIORZERO-KEY] Return current_batch, target_batch, AND game_segments
# ======================================================================
train_data = [current_batch, target_batch, game_segments]
return train_data

def _sample_orig_data(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def _sample_orig_data_episode(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during episode sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data_episode(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def clear(self):
"""
[PRIORZERO-MODIFIED]
Clear buffer and cached game segments.
"""
super().clear()
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None


# ==============================================================================
# Optimized Alternative (Avoids Double Sampling)
# ==============================================================================

class PriorZeroGameBufferOptimized(UniZeroGameBuffer):
"""
[PRIORZERO-OPTIMIZED]
@@ -195,16 +43,14 @@ def sample(self, batch_size: int, policy) -> List[Any]:
batch_size, self._cfg.reanalyze_ratio
)

# Get cached game segments (set by our overridden _make_batch)
game_segments = self._cached_game_segments or []

obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list = current_batch
# Standard processing
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1]
reward_value_context, policy._target_model, current_batch[2], timestep_list
)

batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
policy_re_context, policy._target_model, current_batch[1], timestep_list
)
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
@@ -219,7 +65,7 @@ def sample(self, batch_size: int, policy) -> List[Any]:

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

return [current_batch, target_batch, game_segments]
return [current_batch, target_batch]

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
@@ -243,6 +89,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# Rest of the code is identical to parent's _make_batch
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
raw_obs_list, history_obs_list = [], []
action_logprob_list = []
timestep_list = []
bootstrap_action_list = []

@@ -272,6 +120,16 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
raw_obs_list.append(game_segment_list[i].get_unroll_raw_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
history_obs_list.append(game_segment_list[i].get_unroll_histroy_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
action_logprob_list.append(game_segment_list[i].get_unroll_action_logprob(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))

action_list.append(actions_tmp)
mask_list.append(mask_tmp)
timestep_list.append(timestep_tmp)
@@ -291,6 +149,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

current_batch.append(raw_obs_list)
current_batch.append(history_obs_list)
current_batch.append(action_logprob_list)

total_transitions = self.get_num_of_transitions()

@@ -319,71 +181,8 @@

return reward_value_context, policy_re_context, policy_non_re_context, current_batch


# ==============================================================================
# Factory Function
# ==============================================================================

def create_priorzero_buffer(cfg, optimized: bool = True):
"""
Factory function to create PriorZero game buffer.

Args:
cfg: Configuration dict
optimized: If True, use optimized version (recommended)

Returns:
buffer: PriorZero game buffer instance
"""
if optimized:
return PriorZeroGameBufferOptimized(cfg)
else:
return PriorZeroGameBuffer(cfg)


if __name__ == "__main__":
print("="*80)
print("PriorZero Game Buffer - Unit Tests")
print("="*80)

# Create mock config
class MockConfig:
def __init__(self):
self.device = 'cpu'
self.env_type = 'not_board_games'
self.game_segment_length = 200
self.num_unroll_steps = 5
self.td_steps = 5
self.batch_size = 32
self.use_priority = False
self.reanalyze_ratio = 0.0
self.sample_type = 'transition'
self.replay_buffer_size = 10000
self.model = type('obj', (object,), {
'model_type': 'mlp',
'action_space_size': 10,
'observation_shape': 128,
})()

cfg = MockConfig()

# Test both versions
for name, buffer_class in [
("Standard", PriorZeroGameBuffer),
("Optimized", PriorZeroGameBufferOptimized)
]:
print(f"\nTesting {name} Buffer:")
print("-" * 40)

buffer = buffer_class(cfg)
print(f"✓ Buffer created: {type(buffer).__name__}")
print(f" - sample_type: {buffer.sample_type}")
print(f" - action_space_size: {buffer.action_space_size}")

# Note: Full testing would require mock GameSegments and Policy
# For now, just verify instantiation
print(f"✓ {name} buffer initialized successfully")

print("\n" + "="*80)
print("✓ All tests passed!")
print("="*80)
def _clear(self):
self.game_pos_priorities = []
self.game_segment_buffer = []
self.game_segment_game_pos_look_up = []

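For orientation, here is a minimal caller-side sketch of how the reshaped output of PriorZeroGameBufferOptimized.sample() can be consumed after this change; replay_buffer, batch_size and policy are assumed names, not taken from the diff.

# Sketch (assumed caller-side names): sample() now returns only
# [current_batch, target_batch]; the LLM-related fields travel inside
# current_batch as the three entries appended in _make_batch above.
current_batch, target_batch = replay_buffer.sample(batch_size, policy)
(obs_list, action_list, bootstrap_action_list, mask_list,
 batch_index_list, weights_list, make_time_list, timestep_list,
 raw_obs_list, history_obs_list, action_logprob_list) = current_batch
batch_rewards, batch_target_values, batch_target_policies = target_batch
# raw_obs_list / history_obs_list can feed LLM prompt construction, while
# action_logprob_list supplies the stored old_logprob for a REINFORCE-style ratio.
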
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -2064,7 +2064,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
value_priority=value_priority,
intermediate_tensor_x=intermediate_tensor_x,
obs_embeddings=detached_obs_embeddings, # <-- newly added
)
), inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, outputs.logits_value.shape[-1])).detach()


# TODO: test correctness
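The one-line change above means compute_loss now returns a second value: the value logits mapped back to scalars through inverse_scalar_transform_handle and detached. Below is a minimal, hedged call-site sketch; the variable names are assumptions rather than code from the repository.

# Sketch (assumed names): unpack the new two-element return of compute_loss.
losses, predicted_scalar_values = world_model.compute_loss(
    batch, target_tokenizer, inverse_scalar_transform_handle
)
# predicted_scalar_values is detached, so it carries no gradient; it can be
# logged or compared against MCTS search values without affecting the
# world-model update driven by `losses`.
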
10 changes: 0 additions & 10 deletions lzero/worker/muzero_segment_collector.py
@@ -477,16 +477,6 @@ def collect(
if self.policy_config.use_ture_chance_label_in_chance_encoder:
append_kwargs['chance'] = self.chance_dict_tmp[env_id]

# [PRIORZERO-NEW] Add raw_obs_text if available in obs (not info!)
# Jericho env puts raw_obs_text in the obs dictionary
if env_id == 0 and collected_step < 5: # Debug first few steps
print(f"[OBS_DEBUG] Step {collected_step} env {env_id}: obs keys = {list(obs.keys())}")
print(f"[OBS_DEBUG] obs type = {type(obs)}")
if 'raw_obs_text' in obs:
print(f"[OBS_DEBUG] Found raw_obs_text: {str(obs['raw_obs_text'])[:100]}...")
else:
print(f"[OBS_DEBUG] NO raw_obs_text in obs!")

if 'raw_obs_text' in obs:
append_kwargs['raw_obs_text'] = obs['raw_obs_text']
elif 'raw_obs_text' in info:
26 changes: 23 additions & 3 deletions zoo/jericho/envs/jericho_env.py
@@ -4,6 +4,7 @@
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from collections import OrderedDict

import gym
import numpy as np
@@ -49,12 +50,13 @@ class JerichoEnv(BaseEnv):
'max_seq_len': 512,
'remove_stuck_actions': False,
'add_location_and_inventory': False,
# 'for_unizero': False,
'for_unizero': True,
'save_replay': False,
'save_replay_path': None,
'env_type': "zork1",
'collect_policy_mode': "agent"
'collect_policy_mode': "agent",
'use_cache': True,
'cache_size': 100000,
}

def __init__(self, cfg: Dict[str, Any]) -> None:
@@ -93,6 +95,12 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
self.for_unizero: bool = self.cfg['for_unizero']

self.use_cache = self.cfg['use_cache']
if self.use_cache:
self.cache_size = self.cfg['cache_size']
self.cache_buffer = OrderedDict()
print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}')

# Initialize the tokenizer once (only in rank 0 process if distributed)
if JerichoEnv.tokenizer is None:
if self.rank == 0:
@@ -138,7 +146,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
raw_obs_text = obs # Save original text BEFORE any modification

if self._action_list is None:
self._action_list = self._env.get_valid_actions()
if self.use_cache:
cache_key = self._env.get_world_state_hash()
if cache_key in self.cache_buffer:
self.cache_buffer.move_to_end(cache_key)
self._action_list = self.cache_buffer[cache_key]
else:
self._action_list = self._env.get_valid_actions()
self.cache_buffer[cache_key] = self._action_list
if len(self.cache_buffer) > self.cache_size:
self.cache_buffer.popitem(last=False)
else:
self._action_list = self._env.get_valid_actions()

# Filter available actions based on whether stuck actions are removed.
if self.remove_stuck_actions:
@@ -344,6 +363,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
previous_obs: Optional[str] = self.last_observation if (self.remove_stuck_actions and self.last_observation is not None) else None

observation, reward, done, info = self._env.step(action_str)
info['action_str'] = action_str

self._timestep += 1
if not self.for_unizero:
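The caching added to JerichoEnv.prepare_obs is an OrderedDict used as an LRU map: keys are Jericho world-state hashes, values are the lists returned by get_valid_actions(), move_to_end marks a hit as most recently used, and popitem(last=False) evicts the oldest entry once cache_size is exceeded. The self-contained sketch below shows the same pattern in isolation; the LRUCache class and get_or_compute name are illustrative, not part of the env's API.

from collections import OrderedDict

class LRUCache:
    """Minimal LRU cache mirroring the OrderedDict pattern used in JerichoEnv."""

    def __init__(self, capacity: int = 100000):
        self.capacity = capacity
        self.buffer = OrderedDict()

    def get_or_compute(self, key, compute_fn):
        if key in self.buffer:
            self.buffer.move_to_end(key)      # mark entry as most recently used
            return self.buffer[key]
        value = compute_fn()                  # e.g. self._env.get_valid_actions()
        self.buffer[key] = value
        if len(self.buffer) > self.capacity:
            self.buffer.popitem(last=False)   # evict the least recently used entry
        return value

# Hypothetical usage, mirroring prepare_obs:
# cache = LRUCache(capacity=100000)
# actions = cache.get_or_compute(env.get_world_state_hash(), env.get_valid_actions)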