diff --git a/lzero/mcts/buffer/game_buffer_priorzero.py b/lzero/mcts/buffer/game_buffer_priorzero.py index c9dda2cf5..7a5adbde8 100644 --- a/lzero/mcts/buffer/game_buffer_priorzero.py +++ b/lzero/mcts/buffer/game_buffer_priorzero.py @@ -18,158 +18,6 @@ from typing import List, Any, Union, Tuple from lzero.mcts.buffer.game_buffer_unizero import UniZeroGameBuffer - -class PriorZeroGameBuffer(UniZeroGameBuffer): - """ - [PRIORZERO-MODIFIED] - Enhanced GameBuffer that provides game_segments for LLM policy training. - - Modifications: - 1. sample() returns game_segments as 4th element - 2. Efficient implementation using existing game_segment_list from _make_batch - 3. No additional memory overhead (returns references, not copies) - """ - - def __init__(self, cfg): - """Initialize PriorZero Game Buffer.""" - super().__init__(cfg) - - # [PRIORZERO-NEW] Cache for the last sampled game segments - # This avoids re-sampling when we need game segments - self._last_sampled_game_segments = None - self._last_sampled_batch_indices = None - - def sample( - self, - batch_size: int, - policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: - """ - [PRIORZERO-MODIFIED] - Sample data and prepare current_batch, target_batch, AND game_segments. - - Returns: - train_data: [current_batch, target_batch, game_segments] - - current_batch: [obs, action, target_action, mask, indices, weights, make_time, timestep] - - target_batch: [rewards, values, policies] - - game_segments: List of GameSegment objects used in this batch - - Note: - game_segments are returned for LLM training (SFT/RFT). - They contain: - - mcts_policy_segment: MCTS visit distributions (for SFT supervision) - - raw_obs_segment: Raw text observations (for LLM prompts) - - reward_segment: Environment rewards (for RFT) - - search_value_segment: MCTS search values (for analysis) - """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - - # ====================================================================== - # [PRIORZERO-KEY] Sample data and extract game_segments - # ====================================================================== - # obtain the current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - - # [PRIORZERO-NEW] Extract game_segments from the sampling process - # These were already created in _make_batch, we just need to save them - game_segments = self._last_sampled_game_segments - - # Defensive check: ensure game_segments match batch_size - if game_segments is None or len(game_segments) != len(current_batch[4]): # current_batch[4] is batch_index_list - # Fallback: create empty list if something went wrong - import logging - logging.warning( - f"[PriorZeroBuffer] game_segments mismatch: " - f"expected {len(current_batch[4])}, got {len(game_segments) if game_segments else None}. " - f"Falling back to empty list (SFT/RFT will be skipped)." 
- ) - game_segments = [] - - # ====================================================================== - # Standard UniZero processing (unchanged) - # ====================================================================== - # current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] - - # target reward, target value - batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action - ) - - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed( - policy_re_context, policy._target_model, current_batch[1], current_batch[-1] - ) # current_batch[1] is batch_action - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self.action_space_size - ) - - # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_rewards, batch_target_values, batch_target_policies] - - # ====================================================================== - # [PRIORZERO-KEY] Return current_batch, target_batch, AND game_segments - # ====================================================================== - train_data = [current_batch, target_batch, game_segments] - return train_data - - def _sample_orig_data(self, batch_size: int) -> Tuple[Any]: - """ - [PRIORZERO-MODIFIED] - Override to cache game_segments during sampling. - - This avoids double sampling by caching the result for sample() to use. - """ - # Call parent implementation - result = super()._sample_orig_data(batch_size) - - # Cache the game_segment_list (first element of result tuple) - game_segment_list = result[0] - self._last_sampled_game_segments = game_segment_list - self._last_sampled_batch_indices = result[2] # batch_index_list - - return result - - def _sample_orig_data_episode(self, batch_size: int) -> Tuple[Any]: - """ - [PRIORZERO-MODIFIED] - Override to cache game_segments during episode sampling. - - This avoids double sampling by caching the result for sample() to use. - """ - # Call parent implementation - result = super()._sample_orig_data_episode(batch_size) - - # Cache the game_segment_list (first element of result tuple) - game_segment_list = result[0] - self._last_sampled_game_segments = game_segment_list - self._last_sampled_batch_indices = result[2] # batch_index_list - - return result - - def clear(self): - """ - [PRIORZERO-MODIFIED] - Clear buffer and cached game segments. 
- """ - super().clear() - self._last_sampled_game_segments = None - self._last_sampled_batch_indices = None - - -# ============================================================================== -# Optimized Alternative (Avoids Double Sampling) -# ============================================================================== - class PriorZeroGameBufferOptimized(UniZeroGameBuffer): """ [PRIORZERO-OPTIMIZED] @@ -195,16 +43,14 @@ def sample(self, batch_size: int, policy) -> List[Any]: batch_size, self._cfg.reanalyze_ratio ) - # Get cached game segments (set by our overridden _make_batch) - game_segments = self._cached_game_segments or [] - + obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list = current_batch # Standard processing batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[2], current_batch[-1] + reward_value_context, policy._target_model, current_batch[2], timestep_list ) batch_target_policies_re = self._compute_target_policy_reanalyzed( - policy_re_context, policy._target_model, current_batch[1], current_batch[-1] + policy_re_context, policy._target_model, current_batch[1], timestep_list ) batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( policy_non_re_context, self.action_space_size @@ -219,7 +65,7 @@ def sample(self, batch_size: int, policy) -> List[Any]: target_batch = [batch_rewards, batch_target_values, batch_target_policies] - return [current_batch, target_batch, game_segments] + return [current_batch, target_batch] def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ @@ -243,6 +89,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # Rest of the code is identical to parent's _make_batch batch_size = len(batch_index_list) obs_list, action_list, mask_list = [], [], [] + raw_obs_list, history_obs_list = [], [] + action_logprob_list = [] timestep_list = [] bootstrap_action_list = [] @@ -272,6 +120,16 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True ) ) + raw_obs_list.append(game_segment_list[i].get_unroll_raw_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + history_obs_list.append(game_segment_list[i].get_unroll_histroy_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + action_logprob_list.append(game_segment_list[i].get_unroll_action_logprob( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + )) + action_list.append(actions_tmp) mask_list.append(mask_tmp) timestep_list.append(timestep_tmp) @@ -291,6 +149,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) + + current_batch.append(raw_obs_list) + current_batch.append(history_obs_list) + current_batch.append(action_logprob_list) total_transitions = self.get_num_of_transitions() @@ -319,71 +181,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: return reward_value_context, policy_re_context, policy_non_re_context, 
current_batch - -# ============================================================================== -# Factory Function -# ============================================================================== - -def create_priorzero_buffer(cfg, optimized: bool = True): - """ - Factory function to create PriorZero game buffer. - - Args: - cfg: Configuration dict - optimized: If True, use optimized version (recommended) - - Returns: - buffer: PriorZero game buffer instance - """ - if optimized: - return PriorZeroGameBufferOptimized(cfg) - else: - return PriorZeroGameBuffer(cfg) - - -if __name__ == "__main__": - print("="*80) - print("PriorZero Game Buffer - Unit Tests") - print("="*80) - - # Create mock config - class MockConfig: - def __init__(self): - self.device = 'cpu' - self.env_type = 'not_board_games' - self.game_segment_length = 200 - self.num_unroll_steps = 5 - self.td_steps = 5 - self.batch_size = 32 - self.use_priority = False - self.reanalyze_ratio = 0.0 - self.sample_type = 'transition' - self.replay_buffer_size = 10000 - self.model = type('obj', (object,), { - 'model_type': 'mlp', - 'action_space_size': 10, - 'observation_shape': 128, - })() - - cfg = MockConfig() - - # Test both versions - for name, buffer_class in [ - ("Standard", PriorZeroGameBuffer), - ("Optimized", PriorZeroGameBufferOptimized) - ]: - print(f"\nTesting {name} Buffer:") - print("-" * 40) - - buffer = buffer_class(cfg) - print(f"✓ Buffer created: {type(buffer).__name__}") - print(f" - sample_type: {buffer.sample_type}") - print(f" - action_space_size: {buffer.action_space_size}") - - # Note: Full testing would require mock GameSegments and Policy - # For now, just verify instantiation - print(f"✓ {name} buffer initialized successfully") - - print("\n" + "="*80) - print("✓ All tests passed!") - print("="*80) + def _clear(self): + self.game_pos_priorities = [] + self.game_segment_buffer = [] + self.game_segment_game_pos_look_up = [] + \ No newline at end of file diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index d69671ac5..b2a9d7f5a 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -2064,7 +2064,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar value_priority=value_priority, intermediate_tensor_x=intermediate_tensor_x, obs_embeddings=detached_obs_embeddings, # <-- 新增 - ) + ), inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, outputs.logits_value.shape[-1])).detach() # TODO: test correctness diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 7c265630b..319fa4b15 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -477,16 +477,6 @@ def collect( if self.policy_config.use_ture_chance_label_in_chance_encoder: append_kwargs['chance'] = self.chance_dict_tmp[env_id] - # [PRIORZERO-NEW] Add raw_obs_text if available in obs (not info!) 
- # Jericho env puts raw_obs_text in the obs dictionary - if env_id == 0 and collected_step < 5: # Debug first few steps - print(f"[OBS_DEBUG] Step {collected_step} env {env_id}: obs keys = {list(obs.keys())}") - print(f"[OBS_DEBUG] obs type = {type(obs)}") - if 'raw_obs_text' in obs: - print(f"[OBS_DEBUG] Found raw_obs_text: {str(obs['raw_obs_text'])[:100]}...") - else: - print(f"[OBS_DEBUG] NO raw_obs_text in obs!") - if 'raw_obs_text' in obs: append_kwargs['raw_obs_text'] = obs['raw_obs_text'] elif 'raw_obs_text' in info: diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index e6ac44a2b..7d5e48e28 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -4,6 +4,7 @@ import json from datetime import datetime from typing import Any, Dict, List, Optional, Union +from collections import OrderedDict import gym import numpy as np @@ -49,12 +50,13 @@ class JerichoEnv(BaseEnv): 'max_seq_len': 512, 'remove_stuck_actions': False, 'add_location_and_inventory': False, - # 'for_unizero': False, 'for_unizero': True, 'save_replay': False, 'save_replay_path': None, 'env_type': "zork1", - 'collect_policy_mode': "agent" + 'collect_policy_mode': "agent", + 'use_cache': True, + 'cache_size': 100000, } def __init__(self, cfg: Dict[str, Any]) -> None: @@ -93,6 +95,12 @@ def __init__(self, cfg: Dict[str, Any]) -> None: self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory'] self.for_unizero: bool = self.cfg['for_unizero'] + self.use_cache = self.cfg['use_cache'] + if self.use_cache: + self.cache_size = self.cfg['cache_size'] + self.cache_buffer = OrderedDict() + print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}') + # Initialize the tokenizer once (only in rank 0 process if distributed) if JerichoEnv.tokenizer is None: if self.rank == 0: @@ -138,7 +146,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]: raw_obs_text = obs # Save original text BEFORE any modification if self._action_list is None: - self._action_list = self._env.get_valid_actions() + if self.use_cache: + cache_key = self._env.get_world_state_hash() + if cache_key in self.cache_buffer: + self.cache_buffer.move_to_end(cache_key) + self._action_list = self.cache_buffer[cache_key] + else: + self._action_list = self._env.get_valid_actions() + self.cache_buffer[cache_key] = self._action_list + if len(self.cache_buffer) > self.cache_size: + self.cache_buffer.popitem(last=False) + else: + self._action_list = self._env.get_valid_actions() # Filter available actions based on whether stuck actions are removed. if self.remove_stuck_actions: @@ -344,6 +363,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) -> previous_obs: Optional[str] = self.last_observation if (self.remove_stuck_actions and self.last_observation is not None) else None observation, reward, done, info = self._env.step(action_str) + info['action_str'] = action_str self._timestep += 1 if not self.for_unizero: diff --git a/zoo/jericho/priorzero/ensure_local_lightzero.py b/zoo/jericho/priorzero/ensure_local_lightzero.py deleted file mode 100644 index 7a697176b..000000000 --- a/zoo/jericho/priorzero/ensure_local_lightzero.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Utility module to ensure local LightZero is used across all PriorZero modules. 
- -This ensures PriorZero uses the local LightZero installation at: -/mnt/nfs/zhangjinouwen/puyuan/LightZero - -Usage: - Import this at the beginning of any PriorZero module: - - from ensure_local_lightzero import ensure_local_lightzero - ensure_local_lightzero() -""" - -import sys -from pathlib import Path - - -def ensure_local_lightzero(): - """ - Ensures the local LightZero path is first in sys.path. - - This allows PriorZero to use a LightZero version that has been - specifically adapted for PriorZero, rather than a globally installed version. - - Also adds the PriorZero directory to sys.path to ensure PriorZero modules - can be imported. - """ - LIGHTZERO_ROOT = Path("/mnt/nfs/zhangjinouwen/puyuan/LightZero").resolve() - PRIORZERO_DIR = Path(__file__).parent.resolve() - - if not LIGHTZERO_ROOT.exists(): - print(f"⚠️ Warning: LightZero root not found at {LIGHTZERO_ROOT}") - return False - - lightzero_str = str(LIGHTZERO_ROOT) - priorzero_str = str(PRIORZERO_DIR) - - # Remove any existing LightZero paths from sys.path - sys.path = [p for p in sys.path if 'LightZero' not in p or p == lightzero_str] - - # Insert local LightZero at the beginning - if lightzero_str not in sys.path: - sys.path.insert(0, lightzero_str) - - # Also ensure PriorZero directory is in sys.path for module imports - if priorzero_str not in sys.path: - sys.path.insert(0, priorzero_str) - - # Verify - try: - import lzero - lzero_path = Path(lzero.__file__).parent.parent - - if lzero_path == LIGHTZERO_ROOT: - print(f"✓ Using local LightZero: {lzero_path}") - print(f"✓ PriorZero modules path: {priorzero_str}") - return True - else: - print(f"⚠️ Warning: Using LightZero from {lzero_path}") - print(f" Expected: {LIGHTZERO_ROOT}") - return False - except ImportError as e: - print(f"⚠️ Warning: Could not import lzero: {e}") - return False - - -# Auto-ensure on import -ensure_local_lightzero() diff --git a/zoo/jericho/priorzero/fix_environment.sh b/zoo/jericho/priorzero/fix_environment.sh deleted file mode 100644 index 8876f54df..000000000 --- a/zoo/jericho/priorzero/fix_environment.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# fix_environment.sh -# Fix numpy version conflicts and other dependency issues - -echo "==========================================" -echo "Fixing PriorZero Environment Dependencies" -echo "==========================================" - -# 1. Fix numpy version (downgrade to 1.26.4 for compatibility) -echo "" -echo "1. Fixing numpy version..." -pip install "numpy<2,>=1.24.1" --force-reinstall --no-deps - -# 2. Reinstall conflicting packages -echo "" -echo "2. Reinstalling di-engine and lightzero..." -pip install di-engine==0.5.3 --no-deps -pip install lightzero==0.2.0 --no-deps - -# 3. Verify installations -echo "" -echo "3. Verifying installations..." -python -c "import numpy; print(f'numpy version: {numpy.__version__}')" -python -c "import torch; print(f'torch version: {torch.__version__}')" -python -c "import vllm; print(f'vllm version: {vllm.__version__}')" - -echo "" -echo "==========================================" -echo "Environment fix complete!" 
-echo "==========================================" -echo "" -echo "Now you can run:" -echo " python priorzero_config.py" -echo " python game_segment_priorzero.py" -echo " python priorzero_entry.py --quick_test" diff --git a/zoo/jericho/priorzero/game_segment_priorzero.py b/zoo/jericho/priorzero/game_segment_priorzero.py index 654b93e5c..46eb46a18 100644 --- a/zoo/jericho/priorzero/game_segment_priorzero.py +++ b/zoo/jericho/priorzero/game_segment_priorzero.py @@ -1,20 +1,3 @@ -# game_segment_priorzero.py -""" -[PRIORZERO] Enhanced Game Segment for PriorZero - -This module extends the standard GameSegment to store additional information -needed for LLM policy training (SFT + RFT). - -Key Features: -- Store MCTS policy distributions for SFT training -- Store raw text observations for LLM prompt construction -- Store LLM generated priors for analysis and debugging -- Store search values for priority calculation - -Author: PriorZero Team -Date: 2025-01-20 -""" - import numpy as np from typing import Optional, List, Any from lzero.mcts.buffer.game_segment import GameSegment as OriginalGameSegment @@ -50,13 +33,11 @@ def __init__( """ super().__init__(action_space, game_segment_length, config, task_id) - # [PRIORZERO-NEW] Additional segments for LLM training - self.mcts_policy_segment = [] # MCTS visit count distributions self.raw_obs_segment = [] # Raw text observations - self.llm_prior_segment = [] # LLM generated priors (for debugging) - self.search_value_segment = [] # MCTS search values + self.history_obs_segment = [] + self.action_logprob_segment = [] # Logprob of chosen action (for PPO/RFT) - def reset(self, init_observations: List[np.ndarray]) -> None: + def reset(self, init_observations: List[np.ndarray], init_raw_obs, init_history_obs, init_action_logprob) -> None: """ [PRIORZERO-MODIFIED] Reset the segment with initial observations. 
@@ -65,12 +46,13 @@ def reset(self, init_observations: List[np.ndarray]) -> None: init_observations: List of initial frame stack observations """ super().reset(init_observations) - - # Clear PriorZero-specific segments - self.mcts_policy_segment.clear() self.raw_obs_segment.clear() - self.llm_prior_segment.clear() - self.search_value_segment.clear() + self.history_obs_segment.clear() + self.action_logprob_segment.clear() + + self.raw_obs_segment.append(init_raw_obs) + self.history_obs_segment.append(init_history_obs) + self.action_logprob_segment.append(init_action_logprob) def append( self, @@ -79,6 +61,11 @@ def append( reward: float, action_mask: np.ndarray, to_play: int, + timestep: int = 0, + chance: int = 0, + raw_obs_text: Optional[str] = None, + history_obs: Optional[List[str]] = None, + action_logprob: Optional[float] = None, **kwargs ) -> None: """ @@ -93,36 +80,13 @@ def append( to_play: Player ID (for multi-agent) **kwargs: Additional arguments (timestep, chance, raw_obs_text, llm_prior_text) """ - # [PRIORZERO-NEW] Extract PriorZero-specific kwargs before passing to parent - raw_obs_text = kwargs.pop('raw_obs_text', None) - llm_prior_text = kwargs.pop('llm_prior_text', None) - - # [DEBUG] Log first few appends to see what's being passed - if len(self.raw_obs_segment) < 3: - print(f"[SEGMENT_DEBUG] append() called: kwargs keys = {list(kwargs.keys())}") - print(f"[SEGMENT_DEBUG] raw_obs_text = {raw_obs_text[:50] if raw_obs_text else 'None'}...") - # Call parent append with remaining kwargs - super().append(action, obs, reward, action_mask, to_play, **kwargs) - - # [PRIORZERO-NEW] Initialize placeholders for new segments - # These will be filled in by store_search_stats() - self.mcts_policy_segment.append(None) - self.search_value_segment.append(None) - - # [PRIORZERO-NEW] Store raw text observation if provided + super().append(action, obs, reward, action_mask, to_play, timestep, chance) self.raw_obs_segment.append(raw_obs_text) + self.history_obs_segment.append(history_obs) + self.action_logprob_segment.append(action_logprob) - # [PRIORZERO-NEW] Store LLM prior text if provided (for debugging) - self.llm_prior_segment.append(llm_prior_text) - - def store_search_stats( - self, - root_visit_dist: List[float], - value: float, - *args, - **kwargs - ) -> None: + def store_search_stats(self, visit_counts: List, root_value: List) -> None: """ [PRIORZERO-MODIFIED] Store MCTS search statistics. @@ -138,32 +102,7 @@ def store_search_stats( *args: Additional positional arguments (for compatibility) **kwargs: Additional keyword arguments (improved_policy, etc.) 
""" - # [FIX] Handle NaN values - import numpy as np - if value is None or (isinstance(value, float) and np.isnan(value)): - # Use 0.0 as default for NaN values - value = 0.0 - - # Call parent method to store standard statistics - super().store_search_stats(root_visit_dist, value, *args, **kwargs) - - # [PRIORZERO-NEW] Store MCTS policy distribution - # Convert to numpy array and normalize to probability distribution - policy_array = np.array(root_visit_dist, dtype=np.float32) - - if policy_array.sum() > 0: - policy_array = policy_array / policy_array.sum() - else: - # If no visits (shouldn't happen), use uniform distribution - policy_array = np.ones_like(policy_array) / len(policy_array) - - # Update the most recent position (corresponding to last append) - if len(self.mcts_policy_segment) > 0: - self.mcts_policy_segment[-1] = policy_array - - # [PRIORZERO-NEW] Store search value - if len(self.search_value_segment) > 0: - self.search_value_segment[-1] = float(value) + super().store_search_stats(visit_counts, root_value) def game_segment_to_array(self) -> None: """ @@ -175,287 +114,74 @@ def game_segment_to_array(self) -> None: """ # Call parent method to convert standard segments super().game_segment_to_array() - - # [PRIORZERO-NEW] Convert PriorZero-specific segments to arrays - # Use object dtype to handle variable-length arrays and None values - self.mcts_policy_segment = np.array(self.mcts_policy_segment, dtype=object) - self.search_value_segment = np.array(self.search_value_segment, dtype=np.float32) - - # For text data, keep as list (more flexible for variable-length strings) - # self.raw_obs_segment and self.llm_prior_segment remain as lists - - def get_stats(self) -> dict: - """ - [PRIORZERO-NEW] - Get statistics about this game segment. - - Returns: - stats: Dictionary of statistics - """ - stats = { - 'segment_length': len(self.reward_segment) if hasattr(self, 'reward_segment') else 0, - 'total_reward': sum(self.reward_segment) if hasattr(self, 'reward_segment') else 0, - 'num_mcts_policies': sum(1 for p in self.mcts_policy_segment if p is not None), - 'num_raw_obs': sum(1 for o in self.raw_obs_segment if o is not None), - 'num_llm_priors': sum(1 for p in self.llm_prior_segment if p is not None), - 'avg_search_value': np.mean([v for v in self.search_value_segment if v is not None]) if any(v is not None for v in self.search_value_segment) else 0.0, - } - return stats - - def get_mcts_policy_for_training(self, index: int) -> Optional[np.ndarray]: - """ - [PRIORZERO-NEW] - Get MCTS policy at a specific index for training. - - Args: - index: Index in the segment - - Returns: - policy: MCTS policy distribution, or None if not available - """ - if 0 <= index < len(self.mcts_policy_segment): - return self.mcts_policy_segment[index] - return None - - def get_raw_obs_for_training(self, index: int) -> Optional[str]: - """ - [PRIORZERO-NEW] - Get raw text observation at a specific index for training. - - Args: - index: Index in the segment - - Returns: - raw_obs: Raw text observation, or None if not available - """ - if 0 <= index < len(self.raw_obs_segment): - return self.raw_obs_segment[index] - return None - - def get_history_for_training(self, index: int, history_length: int = 5) -> List[tuple]: - """ - [PRIORZERO-NEW] - Get history context for LLM prompting. 
- - Args: - index: Current index in the segment - history_length: Number of past transitions to include - - Returns: - history: List of (obs, action, reward) tuples - """ - history = [] - - # Get recent transitions - start_idx = max(0, index - history_length) - for i in range(start_idx, index): - if i < len(self.raw_obs_segment) and i < len(self.action_segment) and i < len(self.reward_segment): - obs_text = self.raw_obs_segment[i] - action_id = self.action_segment[i] - reward = self.reward_segment[i] - - # Only add if observation is available - if obs_text is not None: - history.append((obs_text, action_id, reward)) - - return history - - def __repr__(self) -> str: - """ - [PRIORZERO-MODIFIED] - String representation with PriorZero statistics. - """ - base_repr = super().__repr__() - stats = self.get_stats() - - priorzero_info = ( - f"\n MCTS policies: {stats['num_mcts_policies']}" - f"\n Raw observations: {stats['num_raw_obs']}" - f"\n LLM priors: {stats['num_llm_priors']}" - f"\n Avg search value: {stats['avg_search_value']:.3f}" + self.action_logprob_segment = np.asarray(self.action_logprob_segment) + + def pad_over( + self, next_segment_observations: List, next_segment_rewards: List, next_segment_actions: List, next_segment_root_values: List, + next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None, + next_segment_raw_obs: List = None, next_segment_history_obs: List = None, next_segment_action_logprob: List = None + ) -> None: + super().pad_over( + next_segment_observations, next_segment_rewards, next_segment_actions, next_segment_root_values, + next_segment_child_visits, next_segment_improved_policy, next_chances ) - - return base_repr + priorzero_info - + assert len(next_segment_raw_obs) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_history_obs) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_action_logprob) <= self.num_unroll_steps + self.td_steps + import copy + for raw_obs in next_segment_raw_obs: + self.raw_obs_segment.append(copy.deepcopy(raw_obs)) + for history_obs in next_segment_history_obs: + self.history_obs_segment.append(copy.deepcopy(history_obs)) + for lp in next_segment_action_logprob: + self.action_logprob_segment.append(copy.deepcopy(lp)) + + def get_unroll_raw_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Overview: + Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps]. + Arguments: + - timestep (int): The time step. + - num_unroll_steps (int): The extra length of the observation frames. + - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. + """ + stacked_raw_obs = self.raw_obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_raw_obs) + if pad_len > 0: + pad_frames = np.array([stacked_raw_obs[-1] for _ in range(pad_len)]) + stacked_raw_obs = np.concatenate((stacked_raw_obs, pad_frames)) + return stacked_raw_obs + + def get_unroll_histroy_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Overview: + Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps]. + Arguments: + - timestep (int): The time step. + - num_unroll_steps (int): The extra length of the observation frames. + - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory. 
+ """ + stacked_histroy_obs = self.history_obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_histroy_obs) + if pad_len > 0: + pad_frames = np.array([stacked_histroy_obs[-1] for _ in range(pad_len)]) + stacked_histroy_obs = np.concatenate((stacked_histroy_obs, pad_frames)) + return stacked_histroy_obs + + def get_unroll_action_logprob(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: + """ + Return action logprobs aligned with actions for unroll window. + """ + stacked_logprob = list(self.action_logprob_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]) + if padding: + pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_logprob) + if pad_len > 0: + pad_frames = np.array([stacked_logprob[-1] for _ in range(pad_len)]) + stacked_logprob = np.concatenate((stacked_logprob, pad_frames)) + return stacked_logprob # ============================================================================== # Utility Functions # ============================================================================== - -def create_priorzero_game_segment( - action_space, - game_segment_length: int = 200, - config: Optional[Any] = None, - task_id: Optional[int] = None -) -> GameSegment: - """ - Factory function to create a PriorZero GameSegment. - - Args: - action_space: Action space from environment - game_segment_length: Maximum length of the segment - config: Policy configuration - task_id: Task ID for multi-task learning - - Returns: - segment: PriorZero GameSegment instance - """ - return GameSegment(action_space, game_segment_length, config, task_id) - - -def validate_game_segment(segment: GameSegment) -> bool: - """ - Validate that a GameSegment has consistent data. 
- - Args: - segment: GameSegment to validate - - Returns: - is_valid: True if segment is valid, False otherwise - """ - try: - # Check basic lengths - if not hasattr(segment, 'obs_segment'): - return False - - base_length = len(segment.obs_segment) - - # Check that all segments have compatible lengths - if hasattr(segment, 'action_segment'): - if len(segment.action_segment) != base_length: - return False - - if hasattr(segment, 'reward_segment'): - if len(segment.reward_segment) != base_length: - return False - - # Check PriorZero-specific segments - if len(segment.mcts_policy_segment) != base_length: - return False - - if len(segment.raw_obs_segment) != base_length: - return False - - # Check that MCTS policies are valid when present - for policy in segment.mcts_policy_segment: - if policy is not None: - if not isinstance(policy, np.ndarray): - return False - if policy.sum() < 0.99 or policy.sum() > 1.01: # Should sum to ~1.0 - return False - if np.any(policy < 0): # Should be non-negative - return False - - return True - - except Exception as e: - print(f"Validation error: {e}") - return False - - -# ============================================================================== -# Example Usage and Testing -# ============================================================================== - -if __name__ == "__main__": - print("="*80) - print("Testing PriorZero GameSegment") - print("="*80) - - # Create a mock action space - class MockActionSpace: - def __init__(self, n): - self.n = n - - # Create a mock config with all required attributes - class MockConfig: - def __init__(self): - self.num_unroll_steps = 10 - self.td_steps = 5 - self.discount_factor = 0.99 - self.gray_scale = False - self.transform2string = False - self.sampled_algo = False - self.gumbel_algo = False - self.use_ture_chance_label_in_chance_encoder = False - self.model = type('obj', (object,), { - 'frame_stack_num': 4, - 'action_space_size': 10, - 'observation_shape': (84, 84, 3), - 'image_channel': 3 - })() - - action_space = MockActionSpace(n=10) - mock_config = MockConfig() - - # Create a game segment - segment = GameSegment(action_space, game_segment_length=100, config=mock_config) - - # Reset with initial observations - init_obs = [np.zeros((84, 84, 3)) for _ in range(4)] - segment.reset(init_obs) - - print("\n1. Empty segment:") - print(f" Length: {len(segment.obs_segment)}") - print(f" MCTS policies: {len(segment.mcts_policy_segment)}") - - # Simulate some transitions - print("\n2. Adding transitions...") - for i in range(5): - obs = np.random.rand(84, 84, 3) - action = np.random.randint(0, 10) - reward = np.random.randn() - action_mask = np.ones(10) - - # Append transition - segment.append( - action, obs, reward, action_mask, to_play=0, - raw_obs_text=f"You see a room. Step {i}.", - llm_prior_text=f"Top actions: go north, take key" - ) - - # Store MCTS stats - visit_dist = np.random.dirichlet([1.0] * 10).tolist() - value = np.random.randn() - segment.store_search_stats(visit_dist, value) - - print(f" Added {len(segment.obs_segment)} transitions") - - # Get statistics - print("\n3. Segment statistics:") - stats = segment.get_stats() - for key, value in stats.items(): - print(f" {key}: {value}") - - # Test retrieval functions - print("\n4. 
Testing retrieval functions:") - mcts_policy = segment.get_mcts_policy_for_training(2) - print(f" MCTS policy at index 2: {mcts_policy is not None}") - if mcts_policy is not None: - print(f" Shape: {mcts_policy.shape}") - print(f" Sum: {mcts_policy.sum():.3f}") - - raw_obs = segment.get_raw_obs_for_training(2) - print(f" Raw obs at index 2: {raw_obs}") - - history = segment.get_history_for_training(4, history_length=3) - print(f" History for index 4: {len(history)} transitions") - - # Validate segment - print("\n5. Validating segment:") - is_valid = validate_game_segment(segment) - print(f" Is valid: {is_valid}") - - # Convert to array - print("\n6. Converting to array:") - segment.game_segment_to_array() - print(f" MCTS policy type: {type(segment.mcts_policy_segment)}") - print(f" Search value type: {type(segment.search_value_segment)}") - - # Print representation - print("\n7. Segment representation:") - print(segment) - - print("\n" + "="*80) - print("✓ All tests passed!") - print("="*80) diff --git a/zoo/jericho/priorzero/priorzero_collector.py b/zoo/jericho/priorzero/priorzero_collector.py index 1fb6e53c7..d852311d1 100644 --- a/zoo/jericho/priorzero/priorzero_collector.py +++ b/zoo/jericho/priorzero/priorzero_collector.py @@ -1,43 +1,26 @@ -# priorzero_collector.py -""" -[PRIORZERO] PriorZero Collector Implementation - -This module implements async data collection with LLM prior integration. - -Key Features: -- Async LLM inference using vLLM for efficient batch generation -- History buffer management for context-aware prompting -- Error handling and retry logic for robust LLM calls -- Full alignment with UniZero collector architecture - -Author: PriorZero Team -Date: 2025-01-20 -""" - import asyncio import logging import sys import time +import cProfile +from contextlib import contextmanager from collections import deque, defaultdict from pathlib import Path from typing import Optional, Any, List, Dict, Tuple -# [CRITICAL] Ensure local LightZero is used -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - import numpy as np import torch from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY -from vllm import AsyncLLMEngine, SamplingParams +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, allreduce_data +from vllm import SamplingParams +import os # Import from local LightZero from lzero.worker.muzero_segment_collector import MuZeroSegmentCollector as OriginalCollector from lzero.mcts.utils import prepare_observation from game_segment_priorzero import GameSegment - +from priorzero_policy import build_llm_prompt # ============================================================================== # Helper Functions @@ -91,10 +74,8 @@ def extract_raw_obs_text(obs_dict: Dict[str, Any]) -> str: class PriorZeroCollector(OriginalCollector): """ [PRIORZERO-MODIFIED] - Async collector that integrates LLM priors into MCTS-based data collection. Features: - - Async LLM inference with vLLM engine - History buffer for each environment (sliding window) - Robust error handling with retries - Detailed logging of LLM prior statistics @@ -102,7 +83,7 @@ class PriorZeroCollector(OriginalCollector): def __init__( self, - vllm_engine: AsyncLLMEngine, + llm_prior_generator, policy_config: Dict, **kwargs ): @@ -110,7 +91,7 @@ def __init__( Initialize PriorZeroCollector. 
Args: - vllm_engine: vLLM async engine for LLM inference + llm_prior_generator: LLM-based prior generator used to score candidate actions (per-action logprobs) policy_config: Policy configuration (contains llm_policy_cfg) **kwargs: Additional arguments for parent class """ # because parent class needs it kwargs['policy_config'] = policy_config - # Extract debug_mode before passing to parent (parent doesn't accept this parameter) - self.debug_mode = kwargs.pop('debug_mode', False) - super().__init__(**kwargs) - self.vllm_engine = vllm_engine - # self.policy_config already set by parent class from kwargs + self.llm_prior_generator = llm_prior_generator self.llm_policy_cfg = policy_config.llm_policy_cfg # [PRIORZERO-NEW] History buffer for each environment @@ -132,206 +109,149 @@ def __init__( self.history_buffers = defaultdict( lambda: deque(maxlen=self.llm_policy_cfg.history_length) ) - - # [PRIORZERO-NEW] Statistics for logging - self.llm_stats = { - 'total_calls': 0, - 'successful_calls': 0, - 'failed_calls': 0, - 'retry_count': 0, - 'total_latency': 0.0, - 'llm_prior_top1_match_count': 0, # How often LLM top-1 matches MCTS choice - } + self.prompt_log_interval = getattr(self.llm_policy_cfg, 'prompt_log_interval', 0) + + self.profile_cfg = getattr(self.policy_config, 'profile_cfg', {}) + self._profile_enabled = bool(self.profile_cfg.get('enable_cprofile', False)) + self._profile_log_interval = int(self.profile_cfg.get('log_interval', 50)) + self._profile_dir = f"./{self._exp_name}/log/profile" + self._profile_stats = { 'collect_get_llm_prior_profile': {'count': 0, 'total': 0.0, 'max': 0.0}, + 'collect_step_profile': {'count': 0, 'total': 0.0, 'max': 0.0}, + 'collect_forward_profile': {'count': 0, 'total': 0.0, 'max': 0.0} + } + self._profile_stats_file = f'{self._profile_dir}/collector_time.log' + if self._profile_enabled: + os.makedirs(self._profile_dir, exist_ok=True) + + # Where to persist sampled LLM outputs during collect + self._llm_output_log_path = f"./{self._exp_name}/log/collector/llm_output.log" + self._llm_call_count = 0 + self._llm_prior_req_counter = 0 self._logger.info("✓ PriorZeroCollector initialized with vLLM engine") self._logger.info(f" - History length: {self.llm_policy_cfg.history_length}") self._logger.info(f" - Generate max length: {self.llm_policy_cfg.generate_max_len}") - # [PRIORZERO-NEW] Use custom GameSegment - self.GameSegment = GameSegment + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps + + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_raw_obs_lst = game_segments[i].raw_obs_segment[beg_index:end_index] + pad_history_obs_lst = game_segments[i].history_obs_segment[beg_index:end_index] + pad_action_logprob_lst = game_segments[i].action_logprob_segment[beg_index:end_index] + + # NOTE: Specific padding logic for UniZero.
+ pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps - 1 + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + if self.policy_config.gumbel_algo: + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + + # Pad and finalize the last game segment. + if self.policy_config.gumbel_algo: + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob + ) + else: + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst, next_segment_raw_obs=pad_raw_obs_lst, + next_segment_history_obs=pad_history_obs_lst, next_segment_action_logprob=pad_action_logprob_lst + ) + else: + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_raw_obs=pad_raw_obs_lst, next_segment_history_obs=pad_history_obs_lst, + next_segment_action_logprob=pad_action_logprob_lst + ) + + last_game_segments[i].game_segment_to_array() - async def _async_get_llm_prior( + # Add the completed game segment to the pool. + self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) + + # Reset placeholders for the next collection cycle. + last_game_segments[i] = None + last_game_priorities[i] = None + + def _get_llm_prior( self, states: List[str], - request_ids: List[str], + valid_actions_list: List[List[str]], histories: Optional[List[List[Tuple[str, str, float]]]] = None, - max_retries: int = 3, - timeout: float = 30.0 ) -> List[Any]: """ - [PRIORZERO-NEW] - Async call to LLM to get action ranking priors. - - Args: - states: List of current observation texts - request_ids: List of unique request IDs for tracking - histories: Optional list of history tuples for each state - max_retries: Maximum number of retries on failure - timeout: Timeout in seconds for each request - - Returns: - llm_outputs: List of vLLM output objects + [PRIORZERO-SEQUENCE-SCORING] + Ensures every action has a logprob by retrying and falling back if needed. 
""" - # [FIX] Check if vLLM engine is available - if self.vllm_engine is None: - self._logger.info("INFO: vLLM engine not available, skipping LLM prior") - return [None] * len(states) - - from priorzero_policy import build_llm_prompt - - # Build prompts - prompts = [] - for i, state in enumerate(states): - history = histories[i] if histories is not None else None - - # Build instruction using the helper function from policy - instruction = build_llm_prompt( - current_obs=state, - history=history, - use_cot=self.llm_policy_cfg.use_cot - ) - # Apply chat template if policy has tokenizer - if hasattr(self._policy, 'llm_tokenizer'): - prompt = self._policy.llm_tokenizer.apply_chat_template( - [{"role": "user", "content": instruction}], - tokenize=False, - add_generation_prompt=True + assert self.llm_prior_generator is not None, "llm_prior_generator is None." + all_prompts = [] + all_labels = [] + for i, actions in enumerate(valid_actions_list): + state = states[i] + history = histories[i] + prompt = build_llm_prompt(current_obs=state, history=history, use_cot=self.llm_policy_cfg.use_cot) + for action in actions: + all_prompts.append(prompt) + all_labels.append(action) + + all_prior_scores = self.llm_prior_generator._generate_vllm(all_prompts, all_labels, reduction='mean') + llm_prior, idx = [], 0 + for env_id in range(len(states)): + tmp_dict = {} + for action in valid_actions_list[env_id]: + tmp_dict[action] = all_prior_scores[idx] + idx = idx + 1 + llm_prior.append(tmp_dict) + return llm_prior + + @contextmanager + def _profile_block(self, name: str): + if not self._profile_enabled: + yield None + return + profiler = cProfile.Profile() + start_time = time.perf_counter() + profiler.enable() + try: + yield profiler + finally: + profiler.disable() + elapsed = time.perf_counter() - start_time + self._record_profile_time(name, elapsed) + + def _record_profile_time(self, name: str, elapsed: float) -> None: + log_every = max(1, self._profile_log_interval) + self._profile_stats[name]['count'] += 1 + self._profile_stats[name]['total'] += elapsed + self._profile_stats[name]['max'] = max(self._profile_stats[name]['max'], elapsed) + if self._profile_stats[name]['count'] % log_every == 0: + avg = self._profile_stats[name]['total'] / self._profile_stats[name]['count'] + with open(self._profile_stats_file, mode='a', encoding='utf-8') as f: + f.write( + f"{time.time():.3f}\tname={name}\tcount={self._profile_stats[name]['count']}\t" + f"total_s={self._profile_stats[name]['total']:.4f}\tavg_s={avg:.4f}\tmax_s={self._profile_stats[name]['max']:.4f}\n" ) - else: - prompt = instruction - - # [FIX] Ensure prompt is a string - if prompt is None: - self._logger.error(f"[ERROR] Prompt {i} is None! Instruction was: {instruction[:100] if instruction else 'None'}") - prompt = "" # Fallback to empty string - elif not isinstance(prompt, str): - self._logger.error(f"[ERROR] Prompt {i} is not a string! 
Type: {type(prompt)}, Value: {prompt}") - prompt = str(prompt) # Force conversion to string - - prompts.append(prompt) - - # Configure sampling parameters - sampling_params = SamplingParams( - temperature=1.0, - top_p=1.0, - max_tokens=self.llm_policy_cfg.generate_max_len, - skip_special_tokens=False, - ) - - # Retry logic - for attempt in range(max_retries): - try: - start_time = time.time() - - # [DEBUG] Log prompts and parameters before generation - if self.debug_mode and attempt == 0: - self._logger.info(f"[DEBUG] Sending {len(prompts)} prompts to vLLM engine") - for i, prompt in enumerate(prompts[:2]): # Show first 2 prompts - self._logger.info(f"[DEBUG] Prompt {i} (len={len(prompt)}): {prompt[:200]}...") - self._logger.info(f"[DEBUG] Sampling params: temp={sampling_params.temperature}, max_tokens={sampling_params.max_tokens}, top_p={sampling_params.top_p}") - self._logger.info(f"[DEBUG] Request IDs: {request_ids[:2]}...") - - # [FIX] vLLM V1 generate() takes single prompt, not list - # Create generators for each prompt individually - generators = [] - for i, (prompt, req_id) in enumerate(zip(prompts, request_ids)): - gen = self.vllm_engine.generate( - prompt, # Single prompt string - sampling_params, - req_id # Single request_id string - ) - generators.append((i, gen)) - - # Collect results - llm_outputs = [None] * len(prompts) - - try: - # Collect all results concurrently - async def collect_from_generator(idx, gen): - """Collect final result from a generator""" - final_result = None - async for result in gen: - final_result = result - # Check timeout - if time.time() - start_time > timeout: - raise asyncio.TimeoutError(f"LLM generation timeout after {timeout}s") - return idx, final_result - - # Gather all results concurrently - tasks = [collect_from_generator(idx, gen) for idx, gen in generators] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Process results - for result in results: - if isinstance(result, Exception): - raise result - idx, output = result - llm_outputs[idx] = output - - except asyncio.TimeoutError: - self._logger.warning(f"⚠ LLM generation timeout after {timeout}s (attempt {attempt+1}/{max_retries})") - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - continue - else: - # On final timeout, return None for all - self.llm_stats['failed_calls'] += len(prompts) - return [None] * len(prompts) - - # Check if all outputs were received - if None in llm_outputs: - missing_count = llm_outputs.count(None) - self._logger.warning(f"⚠ {missing_count}/{len(prompts)} LLM outputs missing (attempt {attempt+1}/{max_retries})") - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - continue - - # Success - elapsed = time.time() - start_time - self.llm_stats['total_calls'] += len(prompts) - self.llm_stats['successful_calls'] += len([o for o in llm_outputs if o is not None]) - self.llm_stats['failed_calls'] += len([o for o in llm_outputs if o is None]) - self.llm_stats['total_latency'] += elapsed - - self._logger.debug(f"✓ LLM generation completed in {elapsed:.2f}s ({len(prompts)} prompts)") - - # [DEBUG] Log detailed LLM outputs if debug mode is enabled - if self.debug_mode: - for i, (prompt, output) in enumerate(zip(prompts, llm_outputs)): - if output is not None: - output_text = output.outputs[0].text if output.outputs else "[No output]" - self._logger.info(f"[DEBUG] Env {i} - Prompt: {prompt[:100]}... 
-> LLM Output: {output_text[:100]}...") - else: - self._logger.warning(f"[DEBUG] Env {i} - LLM output is None") - - return llm_outputs - - except Exception as e: - import traceback - error_msg = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__ - error_trace = traceback.format_exc() - - # [FIX] Always log the full traceback on first attempt or in debug mode - if attempt == 0 or self.debug_mode: - self._logger.error(f"✗ LLM generation error (attempt {attempt+1}/{max_retries}): {error_msg}") - self._logger.error(f"Full traceback:\n{error_trace}") - else: - self._logger.error(f"✗ LLM generation error (attempt {attempt+1}/{max_retries}): {error_msg}") - - if attempt < max_retries - 1: - self.llm_stats['retry_count'] += 1 - await asyncio.sleep(0.5) # Brief pause before retry - else: - # Final failure - self._logger.error(f"✗ LLM generation failed after {max_retries} attempts. Last error: {error_msg}") - self._logger.error(f"Final traceback:\n{error_trace}") - self.llm_stats['failed_calls'] += len(prompts) - return [None] * len(prompts) - - return [None] * len(prompts) - - async def collect( + + def collect( self, num_segments: Optional[int] = None, train_iter: int = 0, @@ -344,9 +264,8 @@ async def collect( Main changes from parent: 1. Extract text observations from environment - 2. Async call to LLM to get action priors - 3. Pass LLM priors to policy forward pass - 4. Update history buffers after each step + 2. Pass LLM priors to policy forward pass + 3. Update history buffers after each step Args: num_segments: Number of segments to collect @@ -372,29 +291,26 @@ async def collect( temperature = policy_kwargs.get('temperature', 1.0) epsilon = policy_kwargs.get('epsilon', 0.0) - # ================================================================== - # Initialization - # ================================================================== collected_episode = 0 collected_step = 0 + llm_prior_entropy = [[] for _ in range(self._env_num)] env_nums = self._env_num init_obs = self._env.ready_obs - # Wait for all environments to be ready retry_waiting_time = 0.05 while len(init_obs.keys()) != env_nums: self._logger.info(f'Waiting for all environments to reset. 
Ready: {list(init_obs.keys())}') time.sleep(retry_waiting_time) init_obs = self._env.ready_obs - # Initialize state tracking for env_id in range(env_nums): if env_id in init_obs: self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) self.timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - # Initialize game segments + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] game_segments = [ GameSegment( self._env.action_space, @@ -404,7 +320,6 @@ async def collect( ) for _ in range(env_nums) ] - # Initialize observation stacks observation_window_stack = [ deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums) @@ -415,13 +330,12 @@ async def collect( for _ in range(self.policy_config.model.frame_stack_num) ] observation_window_stack[env_id].extend(initial_frames) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_raw_obs=extract_raw_obs_text(init_obs[env_id]), + init_history_obs=list(self.history_buffers[env_id]), init_action_logprob=None) - # Priority calculation lists search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] - # Logging variables eps_steps_lst = np.zeros(env_nums) visit_entropies_lst = np.zeros(env_nums) @@ -433,21 +347,18 @@ async def collect( # ================================================================== while True: with self._timer: - # Get ready environments obs = self._env.ready_obs ready_env_id = set(obs.keys()) if len(ready_env_id) < self._env_num: self._logger.debug(f'Only {len(ready_env_id)}/{self._env_num} envs ready') - # Prepare stacked observations for world model stack_obs_dict = { env_id: game_segments[env_id].get_obs() for env_id in ready_env_id } stack_obs_list = [stack_obs_dict[env_id] for env_id in sorted(list(ready_env_id))] - # Prepare action masks and other info action_mask = [self.action_mask_dict[env_id] for env_id in sorted(list(ready_env_id))] to_play = [self.to_play_dict[env_id] for env_id in sorted(list(ready_env_id))] timestep = [self.timestep_dict[env_id] for env_id in sorted(list(ready_env_id))] @@ -463,59 +374,45 @@ async def collect( # ============================================================== # [PRIORZERO-NEW] Get LLM Priors # ============================================================== - if not collect_with_pure_policy: + if collect_with_pure_policy: + continue + else: # Extract text observations and valid actions raw_obs_list = [] histories_list = [] - valid_actions_list = [] # [PRIORZERO] Store valid actions for each env + valid_actions_list = [] for env_id in sorted(list(ready_env_id)): - # Extract raw text raw_obs_text = extract_raw_obs_text(obs[env_id]) raw_obs_list.append(raw_obs_text) - # Get history for this environment history = list(self.history_buffers[env_id]) histories_list.append(history) - # [PRIORZERO] Extract valid actions from observation valid_actions = obs[env_id].get('valid_actions', []) valid_actions_list.append(valid_actions) - # Generate request IDs - request_ids = [ - f"collect_{train_iter}_{i}" - for i in range(len(raw_obs_list)) - ] - - # Async call to LLM - llm_outputs = await self._async_get_llm_prior( - raw_obs_list, - request_ids, - histories_list - ) - - # Add to policy kwargs - policy_kwargs['llm_prior_outputs'] = llm_outputs - policy_kwargs['valid_actions_list'] = valid_actions_list # 
[PRIORZERO] Pass valid actions - else: - policy_kwargs['llm_prior_outputs'] = None - policy_kwargs['valid_actions_list'] = None + if self.policy_config.llm_policy_cfg.enable_llm: + with self._profile_block(name='collect_get_llm_prior_profile'): + llm_prior_logprob = self._get_llm_prior( + states=raw_obs_list, + valid_actions_list=valid_actions_list, # [PRIORZERO] Pass valid actions + histories=histories_list + ) + else: + llm_prior_logprob = [None for i in range(len(valid_actions_list))] - # ============================================================== - # Policy Forward Pass - # ============================================================== - policy_args = (stack_obs_tensor, action_mask, temperature, to_play, epsilon) policy_kwargs_forward = { - 'ready_env_id': sorted(list(ready_env_id)), - 'timestep': timestep, - 'llm_prior_outputs': policy_kwargs.get('llm_prior_outputs'), - 'valid_actions_list': policy_kwargs.get('valid_actions_list') # [PRIORZERO] Pass valid actions + 'llm_prior_logprob': llm_prior_logprob, + 'valid_actions_list': valid_actions_list, } if self.task_id is not None: policy_kwargs_forward['task_id'] = self.task_id - - policy_output = self._policy.forward(*policy_args, **policy_kwargs_forward) + with self._profile_block(name='collect_forward_profile'): + policy_output = self._policy.forward(data=stack_obs_tensor, action_mask=action_mask, + temperature=temperature, to_play=to_play, epsilon=epsilon, + ready_env_id=sorted(list(ready_env_id)), timestep=timestep, + **policy_kwargs_forward) # Extract outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -538,12 +435,8 @@ async def collect( # ============================================================== # Step Environments # ============================================================== - timesteps = self._env.step(actions) - - # [DEBUG] Log actions taken if debug mode is enabled - if self.debug_mode: - for env_id, action in actions.items(): - self._logger.info(f"[DEBUG] Env {env_id} - Action taken: {action}") + with self._profile_block(name='collect_step_profile'): + timesteps = self._env.step(actions) interaction_duration = self._timer.value / len(timesteps) @@ -566,25 +459,21 @@ async def collect( episode_timestep.done, episode_timestep.info ) - - # [DEBUG] Log observation and reward if debug mode is enabled - if self.debug_mode: - raw_obs_preview = extract_raw_obs_text(obs_new)[:150] - self._logger.info(f"[DEBUG] Env {env_id} - Obs: {raw_obs_preview}... 
| Reward: {reward} | Done: {done}") - - # Store search statistics - if collect_with_pure_policy: - game_segments[env_id].store_search_stats(temp_visit_list, 0) + game_segments[env_id].store_search_stats( + distributions_dict_with_env_id[env_id], + value_dict_with_env_id[env_id]) + # =========================================================== + # [PRIORZERO-NEW] Update History Buffer + # =========================================================== + raw_obs_text = extract_raw_obs_text(obs[env_id]) + if env_id < len(valid_actions_list) and actions[env_id] < len(valid_actions_list[env_id]): + action = valid_actions_list[env_id][actions[env_id]] else: - game_segments[env_id].store_search_stats( - distributions_dict_with_env_id[env_id], - value_dict_with_env_id[env_id] - ) - + action = info.get('action_str', "go") + + self.history_buffers[env_id].append((raw_obs_text, action, float(reward))) + # Append transition to game segment - # [PRIORZERO-FIX] Extract and pass raw_obs_text to GameSegment - raw_obs_text_for_segment = extract_raw_obs_text(obs_new) - game_segments[env_id].append( actions[env_id], to_ndarray(obs_new['observation']), @@ -592,25 +481,11 @@ async def collect( self.action_mask_dict[env_id], self.to_play_dict[env_id], timestep=to_ndarray(obs_new.get('timestep', -1)), - raw_obs_text=raw_obs_text_for_segment + raw_obs_text=extract_raw_obs_text(obs_new), + history_obs=list(self.history_buffers[env_id]), + action_logprob=llm_prior_logprob[env_id] # a dict mapping action text to log-prob, e.g. {'open': -151, "down": -231} ) - # =========================================================== - # [PRIORZERO-NEW] Update History Buffer - # =========================================================== - raw_obs_text = extract_raw_obs_text(obs[env_id]) - # [PRIORZERO] Use dynamic action mapping if available - dynamic_action_inv_map = policy_output.get(env_id, {}).get('dynamic_action_inv_map', None) - if dynamic_action_inv_map is not None: - action_text = dynamic_action_inv_map.get(actions[env_id], f"action_{actions[env_id]}") - else: - # Fallback to static mapping - action_text = getattr(self._policy, 'action_inv_map', {}).get( - actions[env_id], - f"action_{actions[env_id]}" - ) - self.history_buffers[env_id].append((raw_obs_text, action_text, float(reward))) - # Update state self.action_mask_dict[env_id] = to_ndarray(obs_new['action_mask']) self.to_play_dict[env_id] = to_ndarray(obs_new['to_play']) @@ -642,26 +517,17 @@ async def collect( # Save Full Game Segment # =========================================================== if game_segments[env_id].is_full(): - if self.last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, - self.last_game_segments, - self.last_game_priorities, - game_segments, - self.dones - ) + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, + game_segments, self.dones) # Calculate priorities - priorities = self._compute_priorities( - env_id, - pred_values_lst, - search_values_lst - ) + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) pred_values_lst[env_id], search_values_lst[env_id] = [], [] # Save segment - self.last_game_segments[env_id] = game_segments[env_id] - self.last_game_priorities[env_id] = priorities + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities # Create new segment game_segments[env_id] = GameSegment( @@ -670,9 +536,15 @@ async def collect( config=self.policy_config, task_id=self.task_id ) - 
game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_raw_obs=extract_raw_obs_text(obs_new), init_history_obs=list(self.history_buffers[env_id]), init_action_logprob=None) self._env_info[env_id]['step'] += 1 + if llm_prior_logprob[env_id] is not None: + llm_prior_tensor = torch.tensor([logit for k, logit in llm_prior_logprob[env_id].items()]) + llm_prior_prob = torch.softmax(llm_prior_tensor, dim=-1) + llm_prior_entropy[env_id].append(-torch.sum(llm_prior_prob * torch.log(llm_prior_prob + 1e-9), dim=-1)) + else: + llm_prior_entropy[env_id].append(0.0) collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration @@ -683,13 +555,12 @@ async def collect( if episode_timestep.done: self._logger.info(f'======== Env {env_id} episode finished! ========') self._total_episode_count += 1 - # Logging info_log = { 'reward': episode_timestep.info['eval_episode_return'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], - } + 'llm_prior_entropy': sum(llm_prior_entropy[env_id])/len(llm_prior_entropy[env_id])} if not collect_with_pure_policy: info_log['visit_entropy'] = ( visit_entropies_lst[env_id] / eps_steps_lst[env_id] @@ -698,23 +569,11 @@ async def collect( collected_episode += 1 self._episode_info.append(info_log) - # Save remaining segments - if self.last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, - self.last_game_segments, - self.last_game_priorities, - game_segments, - self.dones - ) - - priorities = self._compute_priorities( - env_id, - pred_values_lst, - search_values_lst - ) + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( env_id, last_game_segments, last_game_priorities, game_segments, self.dones) + priorities = self._compute_priorities( env_id, pred_values_lst, search_values_lst) game_segments[env_id].game_segment_to_array() if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append(( @@ -722,7 +581,6 @@ async def collect( priorities, self.dones[env_id] )) - # Reset pred_values_lst[env_id], search_values_lst[env_id] = [], [] eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 @@ -732,15 +590,22 @@ async def collect( # Clear history buffer for this environment self.history_buffers[env_id].clear() - # Re-initialize game segment + init_obs = self._env.ready_obs + observation_window_stack[env_id] = deque( + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id ) - game_segments[env_id].reset(observation_window_stack[env_id]) + game_segments[env_id].reset(observation_window_stack[env_id], init_raw_obs=extract_raw_obs_text(init_obs[env_id]), init_history_obs=list(self.history_buffers[env_id]), init_action_logprob=None) + last_game_segments[env_id] = None + last_game_priorities[env_id] = None # ================================================================== # Check if Enough Segments Collected @@ -771,25 +636,22 @@ async def collect( # ================================================================== collected_duration = sum([d['time'] for d in self._episode_info]) + if self._world_size > 1: + # Before allreduce + self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, 
collected_episode={collected_episode}") + collected_step = allreduce_data(collected_step, 'sum') + collected_episode = allreduce_data(collected_episode, 'sum') + collected_duration = allreduce_data(collected_duration, 'sum') + # After allreduce + self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration self._output_log(train_iter) - # [PRIORZERO-NEW] Log LLM statistics - if self.llm_stats['total_calls'] > 0: - avg_latency = self.llm_stats['total_latency'] / self.llm_stats['total_calls'] - success_rate = self.llm_stats['successful_calls'] / self.llm_stats['total_calls'] - - self._logger.info( - f"📊 LLM Prior Statistics:\n" - f" - Total calls: {self.llm_stats['total_calls']}\n" - f" - Success rate: {success_rate*100:.1f}%\n" - f" - Avg latency: {avg_latency:.3f}s\n" - f" - Retry count: {self.llm_stats['retry_count']}" - ) - return return_data def _output_log(self, train_iter: int) -> None: @@ -797,4 +659,58 @@ def _output_log(self, train_iter: int) -> None: [INHERITED] Log collection statistics (inherited from parent). """ - super()._output_log(train_iter) + if self._rank != 0: + return + + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward = [d['reward'] for d in self._episode_info] + episode_llm_prior_entropy = [d['llm_prior_entropy'] for d in self._episode_info] + + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, + 'collect_time': duration, + 'reward_mean': np.mean(episode_reward), + 'reward_std': np.std(episode_reward), + 'reward_max': np.max(episode_reward), + 'reward_min': np.min(episode_reward), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + 'llm_prior_entropy_mean': np.mean(episode_llm_prior_entropy), + 'llm_prior_entropy_max': np.max(episode_llm_prior_entropy), + 'llm_prior_entropy_min': np.min(episode_llm_prior_entropy) + } + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + info['visit_entropy_mean'] = np.mean(visit_entropy) + if self.policy_config.gumbel_algo: + completed_value = [d['completed_value'] for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + + self._episode_info.clear() + + # Log to console + self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) + + # Log to TensorBoard and WandB + for k, v in info.items(): + if self.task_id is None: + tb_prefix_iter = f'{self._instance_name}_iter/' + tb_prefix_step = f'{self._instance_name}_step/' + else: + tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' + tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + + self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) + self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + 
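For reference, the `llm_prior_entropy` statistic aggregated in `_output_log` above can be reproduced in isolation. The following is an illustrative sketch, not part of the diff: it assumes the per-step LLM prior arrives as a dict mapping action text to a log-probability (as in the collector changes above), mirrors the softmax-then-Shannon-entropy computation added in the collection loop, and averages per-step values the way the episode statistic is averaged. The action names and values below are made up for the example.

import torch

def llm_prior_step_entropy(logprob_dict: dict) -> float:
    # Softmax the raw action log-probabilities, then take the Shannon entropy.
    logits = torch.tensor(list(logprob_dict.values()), dtype=torch.float32)
    probs = torch.softmax(logits, dim=-1)
    return float(-(probs * torch.log(probs + 1e-9)).sum())

# Hypothetical per-step priors for one short episode.
episode_priors = [
    {'open mailbox': -0.7, 'go north': -0.8, 'look': -2.5},  # spread-out prior -> higher entropy
    {'read leaflet': -0.1, 'drop leaflet': -4.0},            # peaked prior -> lower entropy
]
step_entropies = [llm_prior_step_entropy(p) for p in episode_priors]
# Corresponds to the per-episode value logged as 'llm_prior_entropy'.
episode_llm_prior_entropy = sum(step_entropies) / len(step_entropies)

A near-uniform prior over the valid actions pushes the statistic toward log(num_actions), while a confident prior pushes it toward zero, which is what makes it a useful proxy for how decisive the LLM prior is during collection.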
+ diff --git a/zoo/jericho/priorzero/priorzero_config.py b/zoo/jericho/priorzero/priorzero_config.py index 1614aaed4..d22f780bf 100644 --- a/zoo/jericho/priorzero/priorzero_config.py +++ b/zoo/jericho/priorzero/priorzero_config.py @@ -1,66 +1,12 @@ -# priorzero_config.py -""" -[PRIORZERO] PriorZero Configuration - -This module provides complete configuration for PriorZero algorithm. - -Key Features: -- Complete UniZero world model configuration -- LLM policy configuration (ORZ-style) -- Action space mapping for text environments -- Flexible switches to enable/disable components - -Author: PriorZero Team -Date: 2025-01-20 -""" - import os from typing import Dict, Tuple from easydict import EasyDict - - -def get_jericho_action_mapping(env_id: str = 'zork1.z5') -> Tuple[Dict[str, int], Dict[int, str]]: - """ - Get action mapping for Jericho environments. - - In Jericho, the action space is typically defined by the game's valid actions. - For simplicity, we'll provide a basic mapping that can be extended. - - Args: - env_id: Jericho game ID - - Returns: - action_map: Mapping from action text to action index - action_inv_map: Mapping from action index to action text - """ - # Basic common actions for text adventure games - # These should ideally be loaded from the environment's action space - common_actions = [ - # Movement - "go north", "go south", "go east", "go west", - "go up", "go down", "go northeast", "go northwest", - "go southeast", "go southwest", - # Object interaction - "take all", "drop all", "inventory", "look", - "examine", "open", "close", "unlock", - # Common verbs - "read", "eat", "drink", "wear", "remove", - ] - - # Create mapping - action_map = {action.lower(): idx for idx, action in enumerate(common_actions)} - action_inv_map = {idx: action for action, idx in action_map.items()} - - return action_map, action_inv_map - +import torch.distributed as dist def get_priorzero_config( env_id: str = 'zork1.z5', seed: int = 0, exp_name: str = None, - enable_llm: bool = True, - enable_rft: bool = True, - debug_mode: bool = False, ) -> Tuple[EasyDict, EasyDict]: """ Generate complete PriorZero configuration. @@ -71,17 +17,11 @@ def get_priorzero_config( exp_name: Experiment name (auto-generated if None) enable_llm: Whether to enable LLM policy (if False, degrades to pure UniZero) enable_rft: Whether to enable RFT training (if False, only use SFT) - debug_mode: Whether to enable detailed debug logging (obs, action, LLM output, etc.) Returns: main_config: Main configuration dictionary create_config: Creation configuration for DI-engine components """ - - # ============================================================================== - # 1. 
Basic Settings - # ============================================================================== - # Action space and max steps per environment (from jericho_unizero_config.py) env_configurations = { 'detective.z5': (12, 100), 'omniquest.z5': (25, 100), @@ -89,404 +29,206 @@ def get_priorzero_config( 'zork1.z5': (55, 500), } action_space_size, max_steps = env_configurations.get(env_id, (20, 100)) - - # World model encoder (for processing text observations) - wm_encoder_option = 'legacy' # Options: 'legacy', 'clip', 'custom' - wm_model_name = 'BAAI/bge-base-en-v1.5' # Sentence transformer for text encoding - - # LLM policy model + wm_encoder_option = 'legacy' + wm_model_name = 'BAAI/bge-base-en-v1.5' + multi_gpu = False + GPUs = 1 + + collector_env_num = 4 + evaluator_env_num = 2 + n_episode = collector_env_num + + num_unroll_steps = 10 + infer_context_length = 4 + game_segment_length = 50 + num_layers = 2 + embed_dim = 768 + replay_ratio = 0.1 + batch_size = 64 + collect_num_simulations=25 + eval_num_simulations=25 + + if multi_gpu: + n_episode = int(GPUs * collector_env_num) + batch_size = int(batch_size * GPUs) + + ## LLM parameters # llm_model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Smaller model for faster iteration - llm_model_name = "Qwen/Qwen2.5-0.5B-Instruct" # Smaller model for faster iteration - - # Get action mappings - action_map, action_inv_map = get_jericho_action_mapping(env_id) - - # Convert action_inv_map to use string keys for EasyDict compatibility - action_inv_map_str = {str(k): v for k, v in action_inv_map.items()} - - # ============================================================================== - # 2. Environment Configuration - # ============================================================================== + llm_model_name = "/mnt/afs/wanzunian/niuyazhe/xiongjyu/models/Qwen2.5-0.5B-Instruct" + train_batch_size = 128 # Total batch size across all GPUs + GPUS = 1 + micro_batch_size = 16 # Micro batch size per GPU + gradient_accumulation_steps = train_batch_size // micro_batch_size // GPUS + rft_loss_type = 'reinforce++' # 'reinforce' | 'reinforce++' | 'ppo-simple-adv' + use_cot = False # Whether to use chain-of-thought prompting + history_length = 5 + llm_learn_num_samples = 512 + replay_buffer_size = llm_learn_num_samples + env_config = dict( - # Stop conditions stop_value=int(1e6), max_steps=max_steps, - - # Observation and action space - observation_shape=512, # BGE embedding dimension - action_space_size=action_space_size, - - # [FIX] Jericho environment expects these at top level + observation_shape=512, env_id=env_id, - game_path=f"/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + game_path=f"/mnt/afs/wanzunian/niuyazhe/xiongjyu/jericho/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, tokenizer_path=wm_model_name, - env_type="jericho", max_action_num=action_space_size, max_seq_len=512, - save_replay=False, - save_replay_path="", - collect_policy_mode="default", - - # Parallelization - collector_env_num=4, - evaluator_env_num=2, - n_evaluator_episode=2, - - # Environment manager + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, manager=dict( shared_memory=False, - reset_timeout=60, # Increased timeout for text env initialization - ), - ) - - # ============================================================================== - # 3. 
UniZero World Model Configuration - # ============================================================================== - world_model_config = dict( - # [CRITICAL] DI-engine requires 'type' field to identify model class - type='UniZeroModel', - - # [FIX] EasyDict.pop() doesn't handle default values properly, must include import_names - import_names=[], # Empty list since UniZeroModel is already registered - - # Model type - model_type='mlp', # For vector observations (text embeddings) - continuous_action_space=False, - - # Observation and action - observation_shape=512, - action_space_size=action_space_size, - - # [FIX] Encoder settings must be at top level for UniZeroModel.__init__ - encoder_option=wm_encoder_option, - encoder_url=wm_model_name, - - # World model architecture - world_model_cfg=dict( - # Obs type - obs_type="text", # Important: text-based observations - - # Environment settings - env_num=max(4, 2), # max(collector_env_num, evaluator_env_num), will be updated in quick_test - action_space_size=action_space_size, - - # Transformer settings - # num_layers=4, # Reduced for faster training - num_layers=2, # Reduced for faster training # TODO - num_heads=8, - embed_dim=512, - - # Context and unroll - # Note: Each timestep contains 2 tokens: observation and action - num_unroll_steps=10, # Number of steps to unroll in training - infer_context_length=4, # Inference context length - tokens_per_block=2, # obs + action - max_blocks=10, # num_unroll_steps (default) - max_tokens=2 * 10, # 2 * num_unroll_steps - context_length=2 * 4, # 2 * infer_context_length - - # Regularization - embed_pdrop=0.1, - resid_pdrop=0.1, - attn_pdrop=0.1, - - # Loss weights - latent_recon_loss_weight=0.0, # Latent reconstruction loss - perceptual_loss_weight=0.0, - policy_entropy_weight=0.0, # Entropy regularization - - # Normalization - final_norm_option_in_head="LayerNorm", - final_norm_option_in_encoder="LayerNorm", - predict_latent_loss_type='mse', # or 'group_kl' with SimNorm - - # Device - device="cuda", - - # Advanced settings - gru_gating=False, - attention='causal', - support_size=101, # For distributional RL - - # Analysis flags - analysis_sim_norm=False, - analysis_dormant_ratio_weight_rank=False, - # use_priority=False, - use_priority=True, - - # Position encoding - rotary_emb=False, # Whether to use RoPE - rope_theta=10000, - max_seq_len=8192, - - # LoRA (optional, for world model) - lora_r=0, # Set > 0 to enable LoRA - - # Other - decode_loss_mode=None, # 'after_backbone', 'before_backbone', or None - gamma=1.0, # Discount factor - dormant_threshold=0.025, - - task_embed_option=None, - use_task_embed=False, - use_normal_head=True, - use_softmoe_head=False, - use_moe_head=False, - num_experts_in_moe_head=4, - moe_in_transformer=False, - multiplication_moe_in_transformer=False, - n_shared_experts=1, - num_experts_per_tok=1, - num_experts_of_moe_in_transformer=8, - # game_segment_length=200, - game_segment_length=50, ), - - # Distributional RL - categorical_distribution=True, - reward_support_range=(-50., 51., 1.), # (min, max, step) for reward support - value_support_range=(-50., 51., 1.), # (min, max, step) for value support - - # Self-supervised learning - self_supervised_learning_loss=True, - - # Model architecture details - frame_stack_num=1, - bias=True, - res_connection_in_dynamics=True, - norm_type='LN', # LayerNorm for text + use_cache=True, + cache_size=100000, ) - - # ============================================================================== - # 4. 
LLM Policy Configuration (ORZ-style) - # ============================================================================== - llm_policy_config = dict( - # Model path - pretrain_llm_path=llm_model_name, - - # LoRA for parameter-efficient fine-tuning - use_lora=False, # Set to True to enable LoRA - lora_r=8, - lora_alpha=16, - lora_dropout=0.05, - - # Training - llm_learning_rate=1e-6, - llm_weight_decay=0.01, - llm_loss_weight=0.5, # Weight of SFT loss in total loss - rft_loss_weight=0.3, # Weight of RFT loss in total loss - - # [PRIORZERO-OOM-FIX] Gradient accumulation for memory efficiency - # Process LLM training in smaller micro-batches to avoid OOM - llm_micro_batch_size=4, # Small batch size per forward pass (reduce if still OOM) - llm_gradient_accumulation_steps=8, # Accumulate gradients over 8 steps (effective batch = 4*8=32) - # Note: Effective batch size = llm_micro_batch_size * llm_gradient_accumulation_steps - - # Generation - prompt_max_len=2048, - generate_max_len=256, # Max tokens for LLM output - - # Prompting strategy - history_length=5, # Number of recent (obs, action, reward) tuples to include - use_cot=True, # Whether to use Chain-of-Thought prompting - - # Training strategy - sft_target='mcts_policy', # 'mcts_policy' or 'oracle_policy' - enable_rft=enable_rft, # Whether to enable RFT with env rewards - # enable_rft=False, # Whether to enable RFT with env rewards # TODO - - # vLLM settings - vllm_tensor_parallel_size=1, - gpu_memory_utilization=0.3, # Adjust based on your GPU memory - ) - - # ============================================================================== - # 5. Policy Configuration (Combines World Model + LLM) - # ============================================================================== policy_config = dict( + type='priorzero', + multi_gpu=multi_gpu, + use_wandb=False, + profile_cfg=dict( + enable_cprofile=False, # Enable cProfile for collect/train hot paths + log_interval=100, # Aggregate wall-time stats every N profiled sections + ), learn=dict( learner=dict( hook=dict( - save_ckpt_after_iter=1000000, # To save memory, set a large value. If intermediate checkpoints are needed, reduce this value. 
+ save_ckpt_after_iter=1000000, ), ), ), - type='priorzero', - - # Environment settings (must match env config) - collector_env_num=env_config['collector_env_num'], - evaluator_env_num=env_config['evaluator_env_num'], - - # Model config (world model) - model=world_model_config, - - # [PRIORZERO-NEW] LLM policy config - llm_policy_cfg=llm_policy_config, - - # [PRIORZERO-NEW] Action mappings (use original dict, not EasyDict) - # These will be set directly on policy instance, not through EasyDict - _action_map=action_map, # Prefix with _ to avoid EasyDict conversion - _action_inv_map=action_inv_map, - - # ============================================================================== - # [ASYNC-NEW] Async Training Configuration - # ============================================================================== - # off_policy_degree controls the degree of asynchrony between collect and train: - # - 0: Fully synchronous (serial) mode - collect -> train -> eval - # - 1-10: Low async - train can lag behind collect by a few batches - # - 10-50: Medium async - train can lag more, higher throughput - # - >50: High async - maximum throughput, highest off-policy bias - # - # Special value -1: Auto-tune based on buffer size and batch size - off_policy_degree=0, # Default to synchronous mode for stability - # off_policy_degree=5, - - # Whether to enable async evaluation (runs eval in background) - enable_async_eval=False, - - # MCTS settings - num_simulations=25, - collect_num_simulations=25, - eval_num_simulations=25, - - # MCTS exploration - root_dirichlet_alpha=0.3, - root_noise_weight=0.25, - - # MCTS variants (set one to True to use that variant) - sampled_algo=False, # Sampled MuZero - gumbel_algo=False, # Gumbel MuZero - mcts_ctree=True, # Use C++ MCTS (faster) - - # Training settings - batch_size=32, - learning_rate=3e-4, # World model learning rate + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_option=wm_encoder_option, + encoder_url=wm_model_name, + model_type="mlp", + continuous_action_space=False, + norm_type="LN", + world_model_cfg=dict( + norm_type="LN", + final_norm_option_in_head="LayerNorm", + final_norm_option_in_encoder="LayerNorm", + predict_latent_loss_type='mse', + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + embed_dim=embed_dim, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + decode_loss_mode=None, + latent_recon_loss_weight=0, + + task_embed_option=None, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + game_segment_length=game_segment_length, + ) + ), + update_per_collect=None, + num_segments=collector_env_num, + action_type="varied_action_space", + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=3e-4, weight_decay=1e-4, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=0, + replay_buffer_size=replay_buffer_size, + eval_freq=int(3e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=1 / 1000000, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + device='cuda', + + collect_num_simulations=collect_num_simulations, + 
eval_num_simulations=eval_num_simulations, + game_segment_length=game_segment_length, + off_policy_degree=0, + enable_async_eval=False, + optim_type='AdamW', grad_clip_value=10.0, - - # Loss components - value_loss_weight=1.0, + value_loss_weight=0.25, policy_loss_weight=1.0, reward_loss_weight=1.0, - # Adaptive entropy weight (for exploration) - use_adaptive_entropy_weight=True, + use_adaptive_entropy_weight=False, adaptive_entropy_alpha_lr=1e-4, - - # Encoder gradient clipping with annealing - use_encoder_clip_annealing=True, + use_encoder_clip_annealing=False, encoder_clip_anneal_type='cosine', encoder_clip_start_value=30.0, encoder_clip_end_value=10.0, encoder_clip_anneal_steps=100000, - - # Training schedule - num_unroll_steps=10, - td_steps=5, - train_start_after_envsteps=0, - # train_start_after_envsteps=1000, - update_per_collect=None, # Will be set automatically - replay_ratio=0.25, - - # Replay buffer - # replay_buffer_size=int(1e4), - replay_buffer_size=int(1e5), - use_priority=True, # Prioritized experience replay + use_priority=False, # Prioritized experience replay priority_prob_alpha=0.6, priority_prob_beta=0.4, - - # Evaluation - eval_freq=500, - - # Game segments - # game_segment_length=200, - game_segment_length=50, - num_segments=env_config['collector_env_num'], # Must equal collector_env_num - - # Misc - ignore_done=False, - collect_with_pure_policy=False, - monitor_extra_statistics=True, - - # Device - cuda=True, - device='cuda', - multi_gpu=False, - - # Environment type - env_type='not_board_games', - action_type='varied_action_space', # Jericho has varied action space per state - battle_mode='play_with_bot_mode', - - # Data processing - transform2string=False, - gray_scale=False, - use_augmentation=False, - - # Advanced - use_rnd_model=False, # Random Network Distillation for exploration - analysis_sim_norm=False, - sample_type='transition', - - # ============================================================================== - # [ALIGN WITH UNIZERO] Reanalyze Configuration (atari_unizero_segment_config.py line 201-206) - # ============================================================================== - # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, - # 2 means reanalyze once every two epochs, 1/50 means reanalyze once every 50 epochs. - buffer_reanalyze_freq=1/5000000000, # Effectively disabled for Jericho (set very low) - # Each reanalyze process will reanalyze sequences - # ( transitions per sequence) - reanalyze_batch_size=160, - # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, - # 0.5 means samples from the first half of the buffer. - reanalyze_partition=0.75, - # Reanalyze ratio (used in some algorithms, kept for compatibility) - reanalyze_ratio=0.0, - ) - - # ============================================================================== - # 6. 
Replay Buffer Configuration - # ============================================================================== - replay_buffer_config = dict( - type='game', - replay_buffer_size=policy_config['replay_buffer_size'], - batch_size=policy_config['batch_size'], + llm_policy_cfg=dict( + # Switches for the LLM components + enable_llm=True, + enable_sft=False, + enable_rft=True, + sft_loss_weight=1, # Weight of SFT loss in total loss + rft_loss_weight=1, + prompt_log_interval=1000, # How often (in steps) to log the model's responses for comparison against the valid actions + + # Model parameters + pretrain_llm_path=llm_model_name, + history_length=history_length, + use_cot=use_cot, + prompt_max_len=2048, + generate_max_len=128, + temperature=1.0, + top_p=1.0, + + # Training parameters + zero_stage=0, + train_batch_size=train_batch_size, + micro_batch_size=micro_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + learning_rate=1e-5, + weight_decay=0.01, + + # Loss parameters + rft_loss_type=rft_loss_type, + rft_clip_epsilon=0.2, + rft_kl_coef=0.01, + + # vLLM parameters + vllm_tensor_parallel_size=1, + gpu_memory_utilization=0.2, + ), ) - - # ============================================================================== - # 6.5 Remove problematic nested dicts before EasyDict conversion - # ============================================================================== - # Store action mappings separately to avoid EasyDict issues with integer keys - _temp_action_map = action_map - _temp_action_inv_map = action_inv_map - - # ============================================================================== - # 7. Main Configuration Assembly - # ============================================================================== priorzero_config = dict( env=env_config, policy=policy_config, - replay_buffer=replay_buffer_config, - - # Experiment settings - exp_name=exp_name or f"priorzero_{env_id}_seed{seed}", - seed=seed, - - # Debug settings - debug_mode=debug_mode, + exp_name=exp_name, + seed=seed ) - # ============================================================================== - # 8. Create Configuration (for DI-engine component creation) - # ============================================================================== create_config = dict( env=dict( type="jericho", import_names=["zoo.jericho.envs.jericho_env"], ), env_manager=dict( type="base" # [FIX] Use 'base' for jericho to avoid daemon process issues + type="base" ), policy=dict( type="priorzero", @@ -506,183 +248,50 @@ def get_priorzero_config( ), ) - # ============================================================================== - # 9. 
Convert to EasyDict for convenient access - # ============================================================================== - # IMPORTANT: Remove _action_map and _action_inv_map from policy_config before EasyDict - # to avoid integer key issues - policy_config_copy = {k: v for k, v in policy_config.items() if not k.startswith('_')} - priorzero_config['policy'] = policy_config_copy - main_config = EasyDict(priorzero_config) create_config = EasyDict(create_config) - - # Set experiment path - main_config.exp_name = f"data_priorzero/{main_config.exp_name}" - - # [IMPORTANT] Set action mappings as regular attributes (not through EasyDict) - # Use object.__setattr__ to bypass EasyDict's __setattr__ which tries to convert dicts - object.__setattr__(main_config.policy, 'action_map', _temp_action_map) - object.__setattr__(main_config.policy, 'action_inv_map', _temp_action_inv_map) - return main_config, create_config -def get_priorzero_config_for_quick_test(env_id: str = 'zork1.z5', seed: int = 0, debug_mode: bool = False): - """ - Get a lightweight configuration for quick testing (reduced resources). - - This is useful for: - - Debugging - - CI/CD pipelines - - Local development without powerful GPUs - - IMPORTANT: All sequence-length related parameters must be consistent: - - num_unroll_steps: Number of timesteps in training unroll - - max_blocks: Should equal num_unroll_steps - - max_tokens: Should equal num_unroll_steps * tokens_per_block (= num_unroll_steps * 2) - - infer_context_length: Context length for inference - - context_length: Should equal infer_context_length * tokens_per_block (= infer_context_length * 2) - """ - main_config, create_config = get_priorzero_config(env_id, seed, debug_mode=debug_mode) - - # ============================================================================== - # [CRITICAL FIX] Define num_unroll_steps FIRST to ensure consistency - # ============================================================================== - quick_test_num_unroll_steps = 10 # Core parameter that determines sequence length - quick_test_infer_context_length = 4 # Inference context length - tokens_per_block = 2 # obs + action (fixed in UniZero architecture) - - # Reduce computational requirements - main_config.env.collector_env_num = 2 - main_config.env.evaluator_env_num = 1 - main_config.env.n_evaluator_episode = 1 - - # ============================================================================== - # Policy-level configurations - # ============================================================================== - main_config.policy.num_simulations = 5 - # main_config.policy.batch_size = 20 - main_config.policy.batch_size = 2 - main_config.policy.game_segment_length = 20 # Can be larger than num_unroll_steps - main_config.policy.num_segments = 2 # Must equal collector_env_num - main_config.policy.replay_buffer_size = 1000 - - # [CRITICAL] Set policy-level num_unroll_steps to match world model - main_config.policy.num_unroll_steps = quick_test_num_unroll_steps - - # ============================================================================== - # World model configurations - ALL must be consistent with num_unroll_steps - # ============================================================================== - main_config.policy.model.world_model_cfg.num_layers = 1 - main_config.policy.model.world_model_cfg.num_heads = 2 - - # Update env_num to match the reduced collector/evaluator counts - main_config.policy.model.world_model_cfg.env_num = max( - main_config.env.collector_env_num, - 
main_config.env.evaluator_env_num - ) - - # [CRITICAL] Sequence length parameters - must all be consistent - main_config.policy.model.world_model_cfg.num_unroll_steps = quick_test_num_unroll_steps - main_config.policy.model.world_model_cfg.max_blocks = quick_test_num_unroll_steps - main_config.policy.model.world_model_cfg.max_tokens = quick_test_num_unroll_steps * tokens_per_block # 3 * 2 = 6 - - main_config.policy.model.world_model_cfg.infer_context_length = quick_test_infer_context_length - main_config.policy.model.world_model_cfg.context_length = quick_test_infer_context_length * tokens_per_block # 2 * 2 = 4 - - # Verify tokens_per_block is set correctly (should already be 2 from base config) - main_config.policy.model.world_model_cfg.tokens_per_block = tokens_per_block - - # ============================================================================== - # LLM policy configurations - # ============================================================================== - main_config.policy.llm_policy_cfg.prompt_max_len = 1024 - main_config.policy.llm_policy_cfg.generate_max_len = 128 - main_config.policy.llm_policy_cfg.history_length = 3 - # [PRIORZERO-OOM-FIX] Reduce micro-batch size for quick test to avoid OOM - main_config.policy.llm_policy_cfg.llm_micro_batch_size = 2 - main_config.policy.llm_policy_cfg.llm_gradient_accumulation_steps = 4 - - main_config.exp_name = f"{main_config.exp_name}_debug" - - return main_config, create_config - - -# ============================================================================== -# Preset Configurations for Different Scenarios -# ============================================================================== - -def get_config_pure_unizero(env_id: str = 'zork1.z5', seed: int = 0): - """Get config for pure UniZero (without LLM).""" - main_config, create_config = get_priorzero_config( - env_id=env_id, - seed=seed, - enable_llm=False, - ) - main_config.exp_name = f"pure_unizero_{env_id}_seed{seed}" - main_config.policy.llm_policy_cfg.llm_loss_weight = 0.0 - main_config.policy.llm_policy_cfg.rft_loss_weight = 0.0 - return main_config, create_config - - -def get_config_llm_only_sft(env_id: str = 'zork1.z5', seed: int = 0): - """Get config for LLM with only SFT (no RFT).""" - main_config, create_config = get_priorzero_config( - env_id=env_id, - seed=seed, - enable_rft=False, - ) - main_config.exp_name = f"priorzero_sft_only_{env_id}_seed{seed}" - return main_config, create_config - - -def get_config_with_lora(env_id: str = 'zork1.z5', seed: int = 0): - """Get config with LoRA enabled for LLM (memory efficient).""" - main_config, create_config = get_priorzero_config(env_id=env_id, seed=seed) - main_config.policy.llm_policy_cfg.use_lora = True - main_config.exp_name = f"priorzero_lora_{env_id}_seed{seed}" +def get_priorzero_debug_config( + env_id: str = 'zork1.z5', + seed: int = 0, + exp_name: str = None, +) -> EasyDict: + + main_config, create_config = get_priorzero_config(env_id=env_id, seed=seed, exp_name=exp_name) + collector_env_num = 4 + evaluator_env_num = 1 + max_steps=10 + + num_unroll_steps = 5 + infer_context_length = 2 + batch_size = 16 + collect_num_simulations=2 + eval_num_simulations=2 + num_layers=1 + game_segment_length = 20 + llm_learn_num_samples = 64 + + create_config.collector_env_num = collector_env_num + create_config.evaluator_env_num = evaluator_env_num + create_config.max_steps = max_steps + + main_config.policy.model.world_model_cfg.max_blocks = num_unroll_steps + main_config.policy.model.world_model_cfg.max_tokens = 2 * 
num_unroll_steps + main_config.policy.model.world_model_cfg.context_length = 2 * infer_context_length + main_config.policy.model.world_model_cfg.num_layers = num_layers + main_config.policy.model.world_model_cfg.game_segment_length = game_segment_length + main_config.policy.num_unroll_steps = num_unroll_steps + main_config.policy.batch_size = batch_size + main_config.policy.collect_num_simulations = collect_num_simulations + main_config.policy.eval_num_simulations = eval_num_simulations + main_config.policy.model.world_model_cfg.env_num = collector_env_num + main_config.policy.num_segments = collector_env_num + main_config.policy.collector_env_num = collector_env_num + main_config.policy.update_per_collect = 2 + main_config.policy.game_segment_length = game_segment_length + main_config.policy.replay_buffer_size = llm_learn_num_samples + main_config.policy.llm_policy_cfg.llm_learn_num_samples = llm_learn_num_samples + return main_config, create_config - - -# ============================================================================== -# Example Usage -# ============================================================================== - -if __name__ == "__main__": - # Test configuration generation - print("="*80) - print("Testing PriorZero Configuration Generation") - print("="*80) - - # 1. Standard config - print("\n1. Standard PriorZero Config:") - main_cfg, create_cfg = get_priorzero_config(env_id='zork1.z5', seed=0) - print(f" Exp name: {main_cfg.exp_name}") - print(f" Action space size: {main_cfg.policy.model.action_space_size}") - print(f" LLM model: {main_cfg.policy.llm_policy_cfg.pretrain_llm_path}") - print(f" World model layers: {main_cfg.policy.model.world_model_cfg.num_layers}") - print(f" Num action mappings: {len(main_cfg.policy.action_map)}") - - # 2. Quick test config - print("\n2. Quick Test Config:") - test_cfg, _ = get_priorzero_config_for_quick_test() - print(f" Batch size: {test_cfg.policy.batch_size}") - print(f" Num simulations: {test_cfg.policy.num_simulations}") - print(f" Collector envs: {test_cfg.env.collector_env_num}") - - # 3. Pure UniZero config - print("\n3. Pure UniZero Config:") - unizero_cfg, _ = get_config_pure_unizero() - print(f" LLM loss weight: {unizero_cfg.policy.llm_policy_cfg.llm_loss_weight}") - print(f" RFT enabled: {unizero_cfg.policy.llm_policy_cfg.enable_rft}") - - # 4. Config with LoRA - print("\n4. Config with LoRA:") - lora_cfg, _ = get_config_with_lora() - print(f" Use LoRA: {lora_cfg.policy.llm_policy_cfg.use_lora}") - print(f" LoRA rank: {lora_cfg.policy.llm_policy_cfg.lora_r}") - - print("\n" + "="*80) - print("✓ All configurations generated successfully!") - print("="*80) diff --git a/zoo/jericho/priorzero/priorzero_entry.py b/zoo/jericho/priorzero/priorzero_entry.py deleted file mode 100644 index 65337f6f7..000000000 --- a/zoo/jericho/priorzero/priorzero_entry.py +++ /dev/null @@ -1,581 +0,0 @@ -# priorzero_entry.py -""" -[PRIORZERO] Main Training Entry Point - -This module provides the main async training loop for PriorZero. 
- -Key Features: -- Async training with vLLM integration -- Checkpoint management and recovery -- Comprehensive logging (TensorBoard + file logs) -- Graceful error handling - -Author: PriorZero Team -Date: 2025-01-20 -""" - -import asyncio -import os -import sys -from functools import partial -from pathlib import Path -from typing import Tuple, Optional -# from lzero.entry.utils import log_buffer_memory_usage -# from lzero.policy import visit_count_temperature -# from ding.rl_utils import get_epsilon_greedy_fn - -# ============================================================================== -# [CRITICAL] Ensure local LightZero is used for PriorZero-specific adaptations -# ============================================================================== -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - - -import ray -import torch -import wandb -from ding.config import compile_config -from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.worker import create_buffer, BaseLearner -from tensorboardX import SummaryWriter -from loguru import logger -from vllm import AsyncLLMEngine -from vllm.engine.arg_utils import AsyncEngineArgs - -# Import PriorZero components -from priorzero_config import get_priorzero_config, get_priorzero_config_for_quick_test -from priorzero_collector import PriorZeroCollector -from priorzero_evaluator import PriorZeroEvaluator -# Import policy to ensure registration happens -import priorzero_policy # noqa: F401 - - -async def train_priorzero( - cfg: dict, - create_cfg: dict, - seed: int = 0, - max_train_iter: int = int(1e6), - max_env_step: Optional[int] = int(1e10), - enable_save: bool = True, -): - """ - [PRIORZERO-MODIFIED] - Main async training function for PriorZero. - - Args: - cfg: Main configuration dictionary - create_cfg: Creation configuration for DI-engine components - seed: Random seed - max_train_iter: Maximum training iterations - enable_save: Whether to save checkpoints - """ - # ================================================================== - # 1. Compile Configuration - # ================================================================== - cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) - - # ================================================================== - # 2. Initialize Ray (for distributed vLLM) - # ================================================================== - # Note: vLLM will initialize Ray internally if needed. - # We skip manual Ray initialization to avoid conflicts with existing clusters. - if ray.is_initialized(): - logger.info(f"✓ Ray already initialized (connected to existing cluster)") - else: - logger.info(f"✓ Ray not initialized - vLLM will handle initialization if needed") - - # ================================================================== - # 3. 
Create vLLM Engine - # ================================================================== - logger.info("Creating vLLM engine...") - - # [ROBUST FIX] Handle shared GPU environment - # Issue: vLLM V1 engine fails when other processes release GPU memory during init - # Solution: Use alternative initialization method that bypasses V1 checks - import os - - # Note: In vLLM>=0.3.0, worker_use_ray is replaced by distributed_executor_backend - # For single GPU: use "mp" (multiprocessing) - # For multi-GPU: use "ray" if available - tensor_parallel = cfg.policy.llm_policy_cfg.vllm_tensor_parallel_size - distributed_backend = "ray" if tensor_parallel > 1 and ray.is_initialized() else None - - # [ROBUST FIX] Lower GPU memory utilization in shared environment - # This leaves more headroom for memory fluctuations - gpu_mem_util = cfg.policy.llm_policy_cfg.gpu_memory_utilization - if gpu_mem_util > 0.85: - gpu_mem_util = 0.75 # More conservative in shared environment - logger.info(f"✓ Adjusted GPU memory utilization to {gpu_mem_util} for stability") - - # [ROBUST FIX] Use alternative initialization to avoid V1 engine issues - # Set env var BEFORE importing to ensure it takes effect - use_v1_env = os.environ.get('VLLM_USE_V1', None) - if use_v1_env is None: - # Only set if not already set by user - os.environ['VLLM_USE_V1'] = '0' - logger.info("✓ Using vLLM V0 engine for stability in shared GPU environment") - - try: - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util, - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - # [ROBUST FIX] Disable prefix caching in shared environment to reduce memory complexity - enable_prefix_caching=False, - # [ROBUST FIX] Disable enforce_eager to avoid memory profiling issues - enforce_eager=False, - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created (backend: {distributed_backend or 'default'})") - except (ValueError, RuntimeError) as e: - if "VLLM_USE_V1" in str(e) or "memory profiling" in str(e): - # Fallback: Try without V1 env var - logger.warning(f"⚠️ Initial vLLM initialization failed: {e}") - logger.info("Retrying with alternative configuration...") - if 'VLLM_USE_V1' in os.environ: - del os.environ['VLLM_USE_V1'] - - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util * 0.9, # Even more conservative - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=True, # Force eager mode as fallback - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created with fallback configuration") - else: - raise - - # ================================================================== - # 4. 
Create Environments - # ================================================================== - logger.info("Creating environments...") - logger.info(f"[DEBUG] Config values: collector_env_num={cfg.env.collector_env_num}, " - f"evaluator_env_num={cfg.env.evaluator_env_num}, " - f"n_evaluator_episode={cfg.env.n_evaluator_episode}") - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - logger.info(f"[DEBUG] get_vec_env_setting returned: " - f"collector envs={len(collector_env_cfg)}, " - f"evaluator envs={len(evaluator_env_cfg)}") - collector_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in collector_env_cfg] - ) - evaluator_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in evaluator_env_cfg] - ) - - # Seed environments - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=True) - logger.info(f"✓ Environments created and seeded (seed={seed})") - logger.info(f"[DEBUG] Actual env counts: collector={collector_env.env_num}, " - f"evaluator={evaluator_env.env_num}") - - # ================================================================== - # 5. Create Policy, Buffer, and Components - # ================================================================== - logger.info("Creating policy, buffer, and components...") - - # Create policy (align with UniZero) - policy = create_policy( - cfg.policy, - enable_field=['learn', 'collect', 'eval'] - ) - logger.info("✓ Policy created") - - # Create TensorBoard logger (align with UniZero) - os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) - tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None - logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") - - # Create learner (align with UniZero - this sets up policy._logger) - learner = BaseLearner( - cfg.policy.learn.learner, - policy.learn_mode, - tb_logger, - exp_name=cfg.exp_name - ) - logger.info("✓ BaseLearner created") - - # [PRIORZERO-MODIFIED] Create PriorZero-specific replay buffer - # This buffer returns game_segments for LLM training (SFT/RFT) - from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized - replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) - logger.info("✓ PriorZero replay buffer created (with game_segments support)") - - # Create collector - collector = PriorZeroCollector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, - policy_config=cfg.policy, - debug_mode=cfg.get('debug_mode', False), - ) - logger.info("✓ Collector created") - - # Create evaluator - evaluator = PriorZeroEvaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, - policy_config=cfg.policy, - ) - logger.info("✓ Evaluator created") - - # Initialize WandB if enabled (PriorZero enhancement) - if cfg.policy.get('use_wandb', True): - if get_rank() == 0: - wandb.init( - project=cfg.policy.get('wandb_project', 'priorzero'), - name=cfg.exp_name, - config=cfg, - tags=['priorzero', 'unizero', 'llm-policy'], - ) - logger.info("✓ WandB initialized") - # Set train iter and env step for policy wandb logging - policy.set_train_iter_env_step(learner.train_iter, collector.envstep) - - # Call learner's before_run hook (align with UniZero) - 
learner.call_hook('before_run') - - # ================================================================== - # 6. Initialize Async Training Coordinator - # ================================================================== - from async_training_coordinator import AsyncTrainingCoordinator - - coordinator = AsyncTrainingCoordinator( - off_policy_degree=cfg.policy.off_policy_degree, - enable_async_eval=cfg.policy.enable_async_eval, - buffer_size=cfg.policy.replay_buffer_size, - batch_size=cfg.policy.batch_size, - ) - - # ================================================================== - # 7. Main Training Loop - # ================================================================== - logger.info("="*80) - logger.info("Starting PriorZero Training") - logger.info("="*80) - logger.info(f"Experiment: {cfg.exp_name}") - logger.info(f"Max iterations: {max_train_iter}") - logger.info(f"Batch size: {cfg.policy.batch_size}") - logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") - logger.info(f"World model layers: {cfg.policy.model.world_model_cfg.num_layers}") - logger.info(f"Off-policy degree: {cfg.policy.off_policy_degree} ({'SYNC' if cfg.policy.off_policy_degree == 0 else 'ASYNC'})") - logger.info(f"Async eval: {cfg.policy.enable_async_eval}") - logger.info("="*80) - - # [ALIGN WITH UNIZERO] Initialize reanalyze-related counters (train_unizero_segment.py line 119-121) - buffer_reanalyze_count = 0 - train_epoch = 0 - reanalyze_batch_size = cfg.policy.reanalyze_batch_size - batch_size = cfg.policy.batch_size - best_eval_reward = -float('inf') - policy_config = cfg.policy - - # Async control variables - collect_task = None - train_task = None - pending_new_data = None # Store collected data waiting to be added to buffer - - try: - while True: - # ================================================================== - # Determine if we're in synchronous or asynchronous mode - # ================================================================== - is_sync_mode = coordinator.is_synchronous - - # ================================================================== - # Evaluation (align with train_unizero_segment.py line 158-162) - # ================================================================== - if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): - # if learner.train_iter == 0 r evaluator.should_eval(learner.train_iter): - - logger.info(f"\n[Iter {learner.train_iter}] Evaluating...") - - # Define async eval function - async def eval_fn(): - return evaluator.eval( - save_ckpt_fn=learner.save_checkpoint if enable_save else None, - train_iter=learner.train_iter, - envstep=collector.envstep - ) - - # Run eval through coordinator (handles sync/async based on config) - eval_result = await coordinator.run_eval(eval_fn) - - # If sync eval, process result immediately - if not cfg.policy.enable_async_eval and eval_result is not None: - stop, eval_reward_dict = eval_result - mean_reward = eval_reward_dict.get('reward_mean', 0) - logger.info(f" ✓ Evaluation done: reward_mean={mean_reward:.2f}") - - if mean_reward > best_eval_reward: - best_eval_reward = mean_reward - - if stop: - logger.info(f" 🎉 Training converged! 
(reward >= {cfg.env.stop_value})") - break - else: - logger.info(f" ✓ Async evaluation started in background") - - # ================================================================== - # Collect Data (align with train_unizero_segment.py line 165) - # ================================================================== - collect_kwargs = { - 'temperature': 0.25, - 'epsilon': 0.0 - } - - if is_sync_mode: - # ============================================================ - # SYNCHRONOUS MODE: Original serial execution - # ============================================================ - logger.info(f"\n[Iter {learner.train_iter}] Collecting data...") - - new_data = await collector.collect( - train_iter=learner.train_iter, - policy_kwargs=collect_kwargs - ) - - # Update replay buffer - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=1) - - replay_buffer.push_game_segments(new_data) - replay_buffer.remove_oldest_data_to_fit() - buffer_size = replay_buffer.get_num_of_transitions() if hasattr(replay_buffer, 'get_num_of_transitions') else 0 - logger.info(f" ✓ Data collected, buffer size: {buffer_size} transitions") - - else: - # ============================================================ - # ASYNCHRONOUS MODE: Collect can overlap with train - # ============================================================ - # Start or check collect task - if collect_task is None or collect_task.done(): - if coordinator.can_collect(): - logger.info(f"\n[Iter {learner.train_iter}] Starting async collect...") - - # Define async collect function - async def collect_fn(): - return await collector.collect( - train_iter=learner.train_iter, - policy_kwargs=collect_kwargs - ) - - # Start collect task through coordinator - collect_task = asyncio.create_task(coordinator.run_collect(collect_fn)) - else: - logger.debug(f"Collect blocked (lag={coordinator.collect_train_lag}/{coordinator.off_policy_degree})") - - # Check if collect completed - if collect_task is not None and collect_task.done(): - new_data = await collect_task - collect_task = None - - # Store for buffer update - pending_new_data = new_data - logger.info(f" ✓ Async collect completed, data pending buffer update") - - # Update buffer if we have pending data - if pending_new_data is not None: - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, pending_new_data, world_size=1) - - replay_buffer.push_game_segments(pending_new_data) - replay_buffer.remove_oldest_data_to_fit() - buffer_size = replay_buffer.get_num_of_transitions() if hasattr(replay_buffer, 'get_num_of_transitions') else 0 - logger.info(f" ✓ Buffer updated, size: {buffer_size} transitions") - - pending_new_data = None - else: - # No new data yet, use previous update_per_collect or default - update_per_collect = cfg.policy.get('update_per_collect', 10) - - # ============================================================ - # Periodically reanalyze buffer (align with train_unizero_segment.py line 175-186) - # ============================================================ - if cfg.policy.buffer_reanalyze_freq >= 1: - # Reanalyze buffer times in one train_epoch - reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq - else: - # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch - if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > 
int(reanalyze_batch_size/cfg.policy.reanalyze_partition): - logger.info(f"[Reanalyze] Starting buffer reanalysis...") - replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) - buffer_reanalyze_count += 1 - logger.info(f" ✓ Buffer reanalyze count: {buffer_reanalyze_count}") - - # ============================================================ - # Training (align with train_unizero_segment.py line 189-221) - # ============================================================ - if collector.envstep > cfg.policy.train_start_after_envsteps: - # Check if there is sufficient data for training - if cfg.policy.sample_type == 'episode': - data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size - else: - data_sufficient = replay_buffer.get_num_of_transitions() > batch_size - - if not data_sufficient: - logger.warning( - f' ⚠ Data in replay_buffer is not sufficient: ' - f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' - ) - continue - - logger.info(f"[Iter {learner.train_iter}] Training...") - - # Define training function - async def train_one_batch(): - # Reanalyze buffer during training (align with train_unizero_segment.py line 202-210) - # Note: This is simplified - full reanalyze logic should be per-batch - - # Sample batch - train_data = replay_buffer.sample(batch_size, policy) - train_data.insert(2, learner.train_iter) - - # Train - log_vars = learner.train(train_data, collector.envstep) - - # Update priority if enabled - if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) - - return log_vars - - if is_sync_mode: - # Synchronous: train all batches sequentially - for i in range(update_per_collect): - await train_one_batch() - else: - # Asynchronous: train batches while allowing collect to proceed - # We still train sequentially per batch, but collect can run in parallel - if coordinator.can_train(): - # Train one batch through coordinator - await coordinator.run_train(train_one_batch) - else: - logger.debug(f"Train waiting for collect...") - - # Increment epoch counter (align with train_unizero_segment.py line 222) - train_epoch += 1 - - # [FIX] Clear KV cache BEFORE collection to prevent index overflow during MCTS - policy.recompute_pos_emb_diff_and_clear_cache() - - # ============================================================ - # Check stopping criteria (align with train_unizero_segment.py line 226-227) - # ============================================================ - if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: - logger.info("Stopping condition met, training ends!") - break - - # In async mode, yield to event loop - if not is_sync_mode: - await asyncio.sleep(0.001) - - except KeyboardInterrupt: - logger.warning("\n⚠ Training interrupted by user (Ctrl+C)") - - except Exception as e: - logger.error(f"\n✗ Training error: {e}") - import traceback - traceback.print_exc() - - finally: - # ============================================================ - # Cleanup (align with train_unizero_segment.py line 229) - # ============================================================ - learner.call_hook('after_run') - - # Wait for any pending async eval - if cfg.policy.enable_async_eval: - logger.info("Waiting for async eval to complete...") - await coordinator.wait_for_eval() - - # Print async training statistics - async_stats = coordinator.get_statistics() - logger.info("\n" + "="*80) - logger.info("Async Training Statistics:") - logger.info(f" Mode: 
{async_stats['mode'].upper()}") - logger.info(f" Collect iterations: {async_stats['collect_count']}") - logger.info(f" Train iterations: {async_stats['train_count']}") - logger.info(f" Final lag: {async_stats['collect_train_lag']}") - if 'collect_avg_time' in async_stats: - logger.info(f" Avg collect time: {async_stats['collect_avg_time']:.2f}s") - if 'train_avg_time' in async_stats: - logger.info(f" Avg train time: {async_stats['train_avg_time']:.2f}s") - if 'eval_avg_time' in async_stats: - logger.info(f" Avg eval time: {async_stats['eval_avg_time']:.2f}s") - logger.info("="*80) - - logger.info("\nCleaning up...") - collector_env.close() - evaluator_env.close() - tb_logger.close() - - logger.info("="*80) - logger.info("Training Complete!") - logger.info(f"Total iterations: {learner.train_iter}") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info("="*80) - - return policy - - -def main(): - """ - Main entry point with argument parsing. - """ - import argparse - - parser = argparse.ArgumentParser(description='PriorZero Training') - parser.add_argument('--env_id', type=str, default='zork1.z5', help='Jericho game ID') - parser.add_argument('--seed', type=int, default=0, help='Random seed') - parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') - parser.add_argument('--quick_test', action='store_true', help='Use quick test config') - parser.add_argument('--no_save', action='store_true', help='Disable checkpoint saving') - parser.add_argument('--debug', action='store_true', help='Enable detailed debug logging (obs, action, LLM output)') - - args = parser.parse_args() - - # args.quick_test = True # ONLY FOR DEBUG - - # Get configuration - if args.quick_test: - logger.info("Using quick test configuration") - main_cfg, create_cfg = get_priorzero_config_for_quick_test(args.env_id, args.seed, debug_mode=args.debug) - else: - main_cfg, create_cfg = get_priorzero_config(args.env_id, args.seed, debug_mode=args.debug) - - # Run training - asyncio.run(train_priorzero( - main_cfg, - create_cfg, - seed=args.seed, - max_train_iter=args.max_iter, - enable_save=not args.no_save - )) - - -if __name__ == "__main__": - import os - # Disable tokenizer parallelism to prevent multi-process conflicts - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - main() diff --git a/zoo/jericho/priorzero/priorzero_entry_async.py b/zoo/jericho/priorzero/priorzero_entry_async.py new file mode 100644 index 000000000..1f5d690d9 --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_entry_async.py @@ -0,0 +1,326 @@ +import asyncio +import os +import sys +from functools import partial +from pathlib import Path +from typing import Tuple, Optional + +import ray +import torch +import wandb +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import create_buffer, BaseLearner +from tensorboardX import SummaryWriter +from loguru import logger +from ding.utils import DDPContext +from lzero.config.utils import lz_to_ddp_config + +os.environ.setdefault("VLLM_USE_V1", "1") +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs + +from priorzero_config import get_priorzero_config, get_priorzero_debug_config +from priorzero_collector import PriorZeroCollector +from priorzero_evaluator import PriorZeroEvaluator +import priorzero_policy +from lzero.mcts.buffer.game_buffer_priorzero import 
PriorZeroGameBufferOptimized +from lzero.entry.utils import calculate_update_per_collect + +async def train_priorzero( + cfg: dict, + create_cfg: dict, + seed: int = 0, + max_train_iter: int = int(1e6), + max_env_step: Optional[int] = int(1e10), +): + """ + [PRIORZERO-MODIFIED] + Main async training function for PriorZero. + + Args: + cfg: Main configuration dictionary + create_cfg: Creation configuration for DI-engine components + seed: Random seed + max_train_iter: Maximum training iterations + """ + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + if ray.is_initialized(): + logger.info(f"✓ Ray already initialized (connected to existing cluster)") + else: + logger.info(f"✓ Ray not initialized - vLLM will handle initialization if needed") + + logger.info("Creating environments...") + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager( cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager( cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=True) + + logger.info("Creating policy, buffer, and components...") + policy = create_policy( cfg.policy, enable_field=['learn', 'collect', 'eval'], exp_name=cfg.exp_name) + logger.info("✓ Policy created") + + os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") + + if cfg.policy.llm_policy_cfg.enable_llm: + policy._init_llm_learn(tb_logger=tb_logger, exp_name=cfg.exp_name) + + logger.info("Creating vLLM engine...") + tensor_parallel = cfg.policy.llm_policy_cfg.vllm_tensor_parallel_size + distributed_backend = "ray" if tensor_parallel > 1 else None + + gpu_mem_util = cfg.policy.llm_policy_cfg.gpu_memory_utilization + + engine_args = AsyncEngineArgs( + model=policy.llm_ckpt_dir, + tensor_parallel_size=tensor_parallel, + gpu_memory_utilization=gpu_mem_util, + distributed_executor_backend=distributed_backend, + trust_remote_code=True, + enable_prefix_caching=False, + enforce_eager=False, + ) + vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) + logger.info(f"✓ vLLM Engine created (backend: {distributed_backend or 'default'})") + + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + logger.info("✓ BaseLearner created") + + + replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) + logger.info("✓ PriorZero replay buffer created (with game_segments support)") + + # Create collector + collector = PriorZeroCollector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + vllm_engine=vllm_engine, + policy_config=cfg.policy, + ) + logger.info("✓ Collector created") + + # Create evaluator + evaluator = PriorZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + vllm_engine=vllm_engine, + policy_config=cfg.policy, + ) + logger.info("✓ Evaluator created") + learner.call_hook('before_run') + + from async_training_coordinator import AsyncTrainingCoordinator + + coordinator = AsyncTrainingCoordinator( + off_policy_degree=cfg.policy.off_policy_degree, + 
enable_async_eval=cfg.policy.enable_async_eval, + buffer_size=cfg.policy.replay_buffer_size, + batch_size=cfg.policy.batch_size, + ) + assert not coordinator.is_synchronous, 'The async entry requires off_policy_degree > 0 (asynchronous mode).' + # ================================================================== + # Main Training Loop + # ================================================================== + logger.info("="*80) + logger.info("Starting PriorZero Training") + logger.info("="*80) + logger.info(f"Experiment: {cfg.exp_name}") + logger.info(f"Max iterations: {max_train_iter}") + logger.info(f"Batch size: {cfg.policy.batch_size}") + logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") + logger.info(f"World model layers: {cfg.policy.model.world_model_cfg.num_layers}") + logger.info(f"Off-policy degree: {cfg.policy.off_policy_degree} ({'SYNC' if cfg.policy.off_policy_degree == 0 else 'ASYNC'})") + logger.info(f"Async eval: {cfg.policy.enable_async_eval}") + logger.info("="*80) + + # [ALIGN WITH UNIZERO] Initialize reanalyze-related counters (train_unizero_segment.py line 119-121) + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + batch_size = cfg.policy.batch_size + + # Async control variables + collect_task = None + pending_new_data = None # Store collected data waiting to be added to buffer + + + if cfg.policy.multi_gpu: + world_size = get_world_size() + rank = get_rank() + else: + world_size = 1 + rank = 0 + + while True: + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): + logger.info(f"\n[Iter {learner.train_iter}] Evaluating...") + + async def eval_fn(): + return evaluator.eval( + save_ckpt_fn=learner.save_checkpoint, + train_iter=learner.train_iter, + envstep=collector.envstep + ) + stop, reward = await coordinator.run_eval(eval_fn) + if stop: + break + + collect_kwargs = { + 'temperature': 0.25, + 'epsilon': 0.0 + } + + if collect_task is None or collect_task.done(): + if coordinator.can_collect(): + logger.info(f"\n[Iter {learner.train_iter}] Starting async collect...") + + async def collect_fn(): + return await collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + + collect_task = asyncio.create_task(coordinator.run_collect(collect_fn)) + else: + logger.debug(f"Collect blocked (lag={coordinator.collect_train_lag}/{coordinator.off_policy_degree})") + + if collect_task is not None and collect_task.done(): + new_data = await collect_task + collect_task = None + + pending_new_data = new_data + logger.info(f" ✓ Async collect completed, data pending buffer update") + + if pending_new_data is not None: + update_per_collect = calculate_update_per_collect(cfg, pending_new_data, world_size=world_size) + + replay_buffer.push_game_segments(pending_new_data) + buffer_size = replay_buffer.get_num_of_transitions() if hasattr(replay_buffer, 'get_num_of_transitions') else 0 + logger.info(f" ✓ Buffer updated, size: {buffer_size} transitions") + + pending_new_data = None + else: + update_per_collect = cfg.policy.get('update_per_collect', 10) + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition): + logger.info(f"[Reanalyze] Starting buffer reanalysis...") + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) +
buffer_reanalyze_count += 1 + logger.info(f" ✓ Buffer reanalyze count: {buffer_reanalyze_count}") + + if collector.envstep > cfg.policy.train_start_after_envsteps: + if cfg.policy.sample_type == 'episode': + data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size + else: + data_sufficient = replay_buffer.get_num_of_transitions() > batch_size + + if not data_sufficient: + logger.warning( + f' ⚠ Data in replay_buffer is not sufficient: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' + ) + continue + + logger.info(f"[Iter {learner.train_iter}] Training...") + + async def train_one_batch(): + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(learner.train_iter) + + log_vars = learner.train(train_data, collector.envstep) + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + return log_vars + + if coordinator.can_train(): + await coordinator.run_train(train_one_batch) + else: + logger.debug(f"Train waiting for collect...") + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + logger.info("Stopping condition met, training ends!") + break + + await asyncio.sleep(0.001) + + if cfg.policy.enable_async_eval: + logger.info("Waiting for async eval to complete...") + await coordinator.wait_for_eval() + return policy + + +def main(): + """ + Main entry point with argument parsing. + """ + import argparse + + parser = argparse.ArgumentParser(description='PriorZero Training') + parser.add_argument('--env_id', type=str, default='zork1.z5', help='Jericho game ID') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') + parser.add_argument('--quick_test', action='store_true', help='Use quick test config') + parser.add_argument('--no_save', action='store_true', help='Disable checkpoint saving') + parser.add_argument('--debug', action='store_true', help='Enable detailed debug logging (obs, action, LLM output)') + + args = parser.parse_args() + + + args.quick_test = True + if args.quick_test: + logger.info("Using quick test configuration") + main_cfg, create_cfg = get_priorzero_debug_config(args.env_id, args.seed, exp_name=f'data_priorzero/priorzero_async_debug_{args.env_id}_seed0') + else: + main_cfg, create_cfg = get_priorzero_config(args.env_id, args.seed, exp_name=f'data_priorzero/priorzero_rft_reinforce++_{args.env_id}_seed0') + + main_cfg.policy.off_policy_degree = 1 + main_cfg.policy.enable_async_eval = True + + if main_cfg.policy.multi_gpu: + with DDPContext(): + main_cfg = lz_to_ddp_config(main_cfg) + asyncio.run(train_priorzero( + main_cfg, + create_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + )) + + else: + # Run training + asyncio.run(train_priorzero( + main_cfg, + create_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + )) + + +if __name__ == "__main__": + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main() diff --git a/zoo/jericho/priorzero/priorzero_entry_sync.py b/zoo/jericho/priorzero/priorzero_entry_sync.py new file mode 100644 index 000000000..af0792fd0 --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_entry_sync.py @@ -0,0 +1,288 @@ +import asyncio +import os +import sys +from functools import partial +from pathlib import Path +from typing import Tuple, Optional + +import ray +import torch +import wandb +from ding.config 
import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import create_buffer, BaseLearner +from tensorboardX import SummaryWriter +from loguru import logger +from ding.utils import DDPContext +from lzero.config.utils import lz_to_ddp_config + +os.environ.setdefault("VLLM_USE_V1", "1") +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs + +from priorzero_config import get_priorzero_config, get_priorzero_debug_config +from priorzero_collector import PriorZeroCollector +from priorzero_evaluator import PriorZeroEvaluator +from priorzero_policy import * +from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized +from lzero.entry.utils import calculate_update_per_collect +from priorzero_llm_modules import PriorZeroOpenRLHFLLMConfig, PriorZeroOpenRLHFLLMTrainer + + +def train_priorzero( + cfg: dict, + create_cfg: dict, + seed: int = 0, + max_train_iter: int = int(1e6), + max_env_step: Optional[int] = int(1e10), +): + """ + [PRIORZERO-MODIFIED] + Main synchronous training function for PriorZero. + + Args: + cfg: Main configuration dictionary + create_cfg: Creation configuration for DI-engine components + seed: Random seed + max_train_iter: Maximum training iterations + """ + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + if ray.is_initialized(): + logger.info(f"✓ Ray already initialized (connected to existing cluster)") + else: + logger.info(f"✓ Ray not initialized - vLLM will handle initialization if needed") + + logger.info("Creating environments...") + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager( cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager( cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(seed) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=True) + + logger.info("Creating policy, buffer, and components...") + policy = create_policy( cfg.policy, enable_field=['learn', 'collect', 'eval'], exp_name=cfg.exp_name) + logger.info("✓ Policy created") + + os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") + + vllm_engine = None + if cfg.policy.llm_policy_cfg.enable_llm: + llm_cfg = PriorZeroOpenRLHFLLMConfig( + model_name_or_path=policy.llm_policy_cfg.pretrain_llm_path, + zero_stage=policy.llm_policy_cfg.zero_stage, # e.g. ZeRO stage 2 (zero_stage2.json) + lr=policy.llm_policy_cfg.learning_rate, + weight_decay=policy.llm_policy_cfg.weight_decay, + prompt_max_len=policy.llm_policy_cfg.prompt_max_len, + generate_max_len=policy.llm_policy_cfg.generate_max_len, + use_cot=policy.llm_policy_cfg.use_cot, + rft_loss_type=policy.llm_policy_cfg.rft_loss_type, + rft_clip_epsilon=policy.llm_policy_cfg.rft_clip_epsilon, + rft_kl_coef=policy.llm_policy_cfg.rft_kl_coef, + train_batch_size=policy.llm_policy_cfg.train_batch_size, + micro_train_batch_size=policy.llm_policy_cfg.micro_batch_size, + gradient_accumulation_steps=policy.llm_policy_cfg.gradient_accumulation_steps, + bf16=True, + enable_vllm=True, + vllm_num_engines=1, + vllm_tensor_parallel_size=policy.llm_policy_cfg.vllm_tensor_parallel_size, +
gpu_memory_utilization=policy.llm_policy_cfg.gpu_memory_utilization, + seed=seed, + temperature=policy.llm_policy_cfg.temperature, + top_p=policy.llm_policy_cfg.top_p, + ) + trainer = PriorZeroOpenRLHFLLMTrainer(llm_cfg, tb_logger=tb_logger, exp_name=cfg.exp_name) + llm_prior_generator = trainer.llm_prior_generator + # policy._init_llm_learn(tb_logger=tb_logger, exp_name=cfg.exp_name, vllm_engine=vllm_engine) + + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + logger.info("✓ BaseLearner created") + + + replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) + logger.info("✓ PriorZero replay buffer created (with game_segments support)") + + # Create collector + collector = PriorZeroCollector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + llm_prior_generator=llm_prior_generator, + policy_config=cfg.policy, + ) + logger.info("✓ Collector created") + + # Create evaluator + evaluator = PriorZeroEvaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + vllm_engine=vllm_engine, + policy_config=cfg.policy, + ) + logger.info("✓ Evaluator created") + learner.call_hook('before_run') + # ================================================================== + # Main Training Loop + # ================================================================== + logger.info("="*80) + logger.info("Starting PriorZero Training") + logger.info("="*80) + logger.info(f"Experiment: {cfg.exp_name}") + logger.info(f"Max iterations: {max_train_iter}") + logger.info(f"Batch size: {cfg.policy.batch_size}") + logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") + logger.info(f"World model layers: {cfg.policy.model.world_model_cfg.num_layers}") + logger.info(f"Off-policy degree: {cfg.policy.off_policy_degree} ({'SYNC' if cfg.policy.off_policy_degree == 0 else 'ASYNC'})") + logger.info(f"Async eval: {cfg.policy.enable_async_eval}") + logger.info("="*80) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + batch_size = cfg.policy.batch_size + + if cfg.policy.multi_gpu: + world_size = get_world_size() + rank = get_rank() + else: + world_size = 1 + rank = 0 + + while True: + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): + logger.info(f"\n[Iter {learner.train_iter}] Evaluating...") + stop, reward = evaluator.eval( + save_ckpt_fn=learner.save_checkpoint, + train_iter=learner.train_iter, + envstep=collector.envstep + ) + if stop: + break + + collect_kwargs = { + 'temperature': 0.25, + 'epsilon': 0.0 + } + + new_data = collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=world_size) + + replay_buffer.push_game_segments(new_data) + num_of_transitions = replay_buffer.get_num_of_transitions() + logger.info(f" ✓ Data collected, num_of_transitions: {num_of_transitions} transitions") + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0: + logger.info(f"[Reanalyze] Starting buffer reanalysis...") + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 
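+ # NOTE: with buffer_reanalyze_freq < 1 (e.g. an illustrative 0.2), the whole-buffer reanalysis above fires every int(1/0.2) = 5 train epochs; with buffer_reanalyze_freq >= 1 it is instead spread across the update_per_collect steps via reanalyze_interval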
+ logger.info(f" ✓ Buffer reanalyze count: {buffer_reanalyze_count}") + + if collector.envstep <= cfg.policy.train_start_after_envsteps: + continue + + if cfg.policy.sample_type == 'episode': + data_sufficient = num_of_transitions > batch_size + else: + data_sufficient = num_of_transitions > batch_size + + if not data_sufficient: + logger.warning( + f' ⚠ Data in replay_buffer is not sufficient: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect...' + ) + continue + + logger.info(f"[Iter {learner.train_iter}] Training...") + for i in range(update_per_collect): + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(learner.train_iter) + + log_vars = learner.train(train_data, collector.envstep) + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if num_of_transitions >= replay_buffer.replay_buffer_size: + all_data = replay_buffer.sample(batch_size=cfg.policy.llm_policy_cfg.llm_learn_num_samples, policy=policy) + replay_buffer._clear() + trainer.train_rft_from_priorzero_batch(all_data) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + logger.info("Stopping condition met, training ends!") + break + + + return policy + + +def main(): + """ + Main entry point with argument parsing. + """ + import argparse + + parser = argparse.ArgumentParser(description='PriorZero Training') + parser.add_argument('--env_id', type=str, default='zork1.z5', help='Jericho game ID') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + parser.add_argument('--max_iter', type=int, default=int(1e6), help='Max training iterations') + parser.add_argument('--quick_test', action='store_true', help='Use quick test config') + parser.add_argument('--no_save', action='store_true', help='Disable checkpoint saving') + parser.add_argument('--debug', action='store_true', help='Enable detailed debug logging (obs, action, LLM output)') + + args = parser.parse_args() + + + args.quick_test = True + if args.quick_test: + logger.info("Using quick test configuration") + main_cfg, create_cfg = get_priorzero_debug_config(args.env_id, args.seed, exp_name=f'data_priorzero/priorzero_sync_debug_{args.env_id}_seed0') + else: + main_cfg, create_cfg = get_priorzero_config(args.env_id, args.seed, exp_name=f'data_priorzero/priorzero_sync_rft_reinforce++_{args.env_id}_seed0') + + if main_cfg.policy.multi_gpu: + with DDPContext(): + main_cfg = lz_to_ddp_config(main_cfg) + asyncio.run(train_priorzero( + main_cfg, + create_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + )) + + else: + # Run training + asyncio.run(train_priorzero( + main_cfg, + create_cfg, + seed=args.seed, + max_train_iter=args.max_iter, + )) + + +if __name__ == "__main__": + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main() diff --git a/zoo/jericho/priorzero/priorzero_evaluator.py b/zoo/jericho/priorzero/priorzero_evaluator.py index c8a25d0f9..0dc3abc09 100644 --- a/zoo/jericho/priorzero/priorzero_evaluator.py +++ b/zoo/jericho/priorzero/priorzero_evaluator.py @@ -1,15 +1,3 @@ -# priorzero_evaluator.py -""" -[PRIORZERO] PriorZero Evaluator - -Simple evaluator that inherits from MuZeroEvaluator. -Since the policy already integrates LLM priors in its _forward_collect method, -the evaluator can use the parent implementation directly. 
- -Author: PriorZero Team -Date: 2025-01-20 -""" - from typing import Optional from ding.worker.collector.base_serial_evaluator import SERIAL_EVALUATOR_REGISTRY diff --git a/zoo/jericho/priorzero/priorzero_llm_modules.py b/zoo/jericho/priorzero/priorzero_llm_modules.py new file mode 100644 index 000000000..67595b8f5 --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_llm_modules.py @@ -0,0 +1,410 @@ +from __future__ import annotations +import os +import copy +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import deepspeed +import ray +import numpy as np +from transformers import AutoTokenizer, AutoModelForCausalLM + +from ding.utils import build_logger +from utils.vllm_engine import create_vllm_engines, batch_vllm_engine_call +from utils.generator import SamplesGenerator +from priorzero_policy import build_llm_prompt +from openrlhf.utils import get_strategy +from openrlhf.trainer.ray.utils import get_physical_gpu_id +from priorzero_utils import compute_approx_kl + + +def torch_dist_barrier_and_cuda_sync(): + """Synchronize distributed training and CUDA operations. + This function ensures that: + 1. All distributed processes reach this point (barrier) + 2. All CUDA operations are completed (synchronize) + """ + torch.distributed.barrier() + torch.cuda.synchronize() + +@dataclass +class PriorZeroOpenRLHFLLMConfig: + model_name_or_path: str + bf16: bool = True + + prompt_max_len: int = 2048 + generate_max_len: int = 128 + use_cot: bool = True + + rft_loss_type: str = "reinforce++" # "reinforce" | "reinforce++" + rft_clip_epsilon: float = 0.2 + rft_kl_coef: float = 0.0 + + # DeepSpeed + zero_stage: int = 0 # only the zero_optimization section of the DeepSpeed config is provided + lr: float = 1e-6 + weight_decay: float = 0.01 + grad_clip: float = 1.0 + micro_train_batch_size: int = 1 + train_batch_size: int = 128 + gradient_accumulation_steps: int = 1 + ds_tensor_parallel_size: int = 1 + + # vLLM engines (OpenRLHF) + enable_vllm: bool = True + enable_prefix_caching: bool = True + vllm_num_engines: int = 1 + vllm_tensor_parallel_size: int = 1 + gpu_memory_utilization: float = 0.90 + temperature: float = 1.0 + top_p: float = 1.0 + seed: int = 0 + +class PriorZeroOpenRLHFLLMTrainer: + """ + Goals: + - Reuse OpenRLHF's vLLM RayActor engines and the weight-update RPC + - Run RFT training with DeepSpeed (supports single- and multi-process) + - Synchronize weights via update_weight_cuda_ipc (the most direct route for multiple processes on the same machine and GPUs) + """ + + def __init__(self, cfg: PriorZeroOpenRLHFLLMConfig, tb_logger, exp_name, instance_name='rft_llm'): + self.cfg = cfg + self.lr = cfg.lr + self.weight_decay = cfg.weight_decay + self.cfg.local_rank = int(os.environ.get("LOCAL_RANK", -1)) + self._logger, _ = build_logger( + path=f'./{exp_name}/log/{instance_name}', name=instance_name, need_tb=False + ) + self._tb_logger = tb_logger # may be None; log_state_to_tb() checks before writing + self.rft_log = {} + self.train_samples_cnt = 0 + + if not ray.is_initialized(): + ray.init() + + self.use_cuda_ipc = True + + self.strategy = get_strategy(self.cfg) + self.strategy.setup_distributed() # distributed init; tokenizer / model / optimizer / deepspeed.initialize are set up below + + self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, trust_remote_code=True, padding_side="left") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + cfg.model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16, + device_map=None, + ) + + optim =
self.strategy.create_optimizer( + model, + lr=self.lr, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=self.weight_decay, + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optim, + T_max=100000, + eta_min=self.lr * 0.1 + ) + self.model_engine, self.optim, self.scheduler = self.strategy.prepare( + (model, optim, scheduler), + is_rlhf=False, + ) + + self.ref_model = None + if cfg.rft_kl_coef > 0.0: + self.ref_model = copy.deepcopy(model).eval().to(self.model_engine.device) + for p in self.ref_model.parameters(): + p.requires_grad_(False) + + self.vllm_engines = None + if cfg.enable_vllm: + self.vllm_engines = create_vllm_engines( + num_engines=cfg.vllm_num_engines, + tensor_parallel_size=cfg.vllm_tensor_parallel_size, + pretrain=cfg.model_name_or_path, + seed=cfg.seed, + full_determinism=False, + enable_prefix_caching=cfg.enable_prefix_caching, + enforce_eager=False, + gpu_memory_utilization=cfg.gpu_memory_utilization, + max_model_len=cfg.prompt_max_len + cfg.generate_max_len, + ) + self.llm_prior_generator = SamplesGenerator(vllm_engines=self.vllm_engines, + strategy=self.strategy, + tokenizer=self.tokenizer, + prompt_max_len=cfg.prompt_max_len, + temperature=cfg.temperature, + top_p=cfg.top_p) + + self._logger.info(f"✓ Loaded LLM model from {cfg.model_name_or_path}") + + def build_samples( + self, + raw_obs_list: List[List[str]], + history_obs_list: List[List[List[Tuple[str, str, float]]]], + action_logprob_list: Optional[List[List[Any]]] = None, + target_values: Optional[torch.Tensor] = None, # return targets G_t, shape [B, T-1] + ) -> List[Dict[str, Any]]: + samples: List[Dict[str, Any]] = [] + B = len(raw_obs_list) + if B == 0: + return samples + T = len(raw_obs_list[0]) + + for b in range(B): + for t in range(T - 1): + current_obs = raw_obs_list[b][t] + current_hist = history_obs_list[b][t] + next_hist = history_obs_list[b][t + 1] + + _, true_action, reward_value = next_hist[-1] + if not true_action: + continue + + instruction = build_llm_prompt( + current_obs=current_obs, + history=current_hist, + use_cot=self.cfg.use_cot, + ) + prompt = self.tokenizer.apply_chat_template( + [{"role": "user", "content": instruction}], + tokenize=False, + add_generation_prompt=True, + ) + + old_logprob = None + if action_logprob_list is not None: + old_logprob = action_logprob_list[b][t + 1][true_action] + + target_value = None + if target_values is not None: + target_value = float(target_values[b][t].item()) + + samples.append( + { + "prompt": prompt, + "target": f"{true_action}{self.tokenizer.eos_token}", + "reward": float(reward_value) if reward_value is not None else 0.0, + "target_value": target_value, + "old_logprob": old_logprob, # required for the Reinforce++ importance ratio + } + ) + return samples + + def log_state_to_tb(self): + if self._tb_logger is not None: + for k, v in self.rft_log.items(): + self._tb_logger.add_scalar(f'learner_llm_iter/{k}', np.mean(v) if v is not None else 0.0, self.train_samples_cnt) + + self.rft_log = {} + + def _log_state(self, x, name='none'): + if name in self.rft_log: + self.rft_log[name].append(x) + else: + self.rft_log[name] = [x] + + def train_rft_from_priorzero_batch( + self, + data: Tuple[torch.Tensor] + ) -> Dict[str, float]: + + current_batch, target_batch = data + obs_batch_ori, action_batch, target_action_batch, mask_batch, batch_index_tensor, weights, make_time, timestep_batch, raw_obs_list, history_obs_list, action_logprob_list = current_batch + target_reward, target_value, target_policy = target_batch + + samples = self.build_samples(raw_obs_list, history_obs_list,
action_logprob_list, target_value) + if len(samples) == 0: + return {"rft_loss": 0.0} + + micro_train_batch_size = self.strategy.micro_train_batch_size + gradient_accumulation_steps = self.strategy.accumulated_gradient + clip_eps = self.cfg.rft_clip_epsilon + kl_coef = self.cfg.rft_kl_coef + loss_type = self.cfg.rft_loss_type.lower() + + self.model_engine.train() + total_loss = 0.0 + + for i in range(0, len(samples), micro_train_batch_size): + chunk = samples[i:i + micro_train_batch_size] + full_texts = [s["prompt"] + s["target"] for s in chunk] + prompts_only = [s["prompt"] for s in chunk] + + inputs = self.tokenizer( + full_texts, + padding=True, + truncation=True, + max_length=self.cfg.prompt_max_len, + return_tensors="pt", + ).to(self.model_engine.device) + + labels = inputs.input_ids.clone() + labels[inputs.attention_mask == 0] = -100 + + for row, ptxt in enumerate(prompts_only): + pad_len = int((inputs.attention_mask[row] == 0).sum().item()) + p_ids = self.tokenizer.encode(ptxt, add_special_tokens=False) + p_len = len(p_ids) + real_prompt_len = pad_len + p_len + labels[row, :real_prompt_len] = -100 + + outputs = self.model_engine(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) + logits = outputs.logits[:, :-1, :].contiguous() + shifted_labels = labels[:, 1:].contiguous() + + token_logp = -F.cross_entropy(logits.transpose(1, 2), shifted_labels, reduction="none") + mask = (shifted_labels != -100).float() + token_logp = token_logp * mask + seq_logp = token_logp.sum(dim=-1) / (mask.sum(dim=-1) + 1e-8) # mean log-prob over target tokens (consistent with the current implementation) + self._log_state(x=seq_logp.mean().item(), name='rft_logprob') + + gt = torch.tensor([s["target_value"] if s["target_value"] is not None else s["reward"] for s in chunk], + device=self.model_engine.device, dtype=torch.float32) + + if loss_type == "reinforce": + adv = gt + self._log_state(x=adv.mean().item(), name='rft_advantage') + + loss = -(adv * seq_logp).mean() + else: + adv = (gt - gt.mean()) / (gt.std() + 1e-8) + self._log_state(x=adv.mean().item(), name='rft_advantage') + + old_lp = torch.tensor([s["old_logprob"] for s in chunk], + device=self.model_engine.device, dtype=torch.float32) + ratio = torch.exp(seq_logp - old_lp) + clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) + surrogate1 = ratio * adv + surrogate2 = clipped * adv + + used_ratio = torch.where(surrogate1 <= surrogate2, ratio, clipped) + self._log_state(x=used_ratio.mean().item(), name='rft_ratio_used') + + loss = -(torch.min(surrogate1, surrogate2)).mean() + + # optional KL(pi || ref) + if kl_coef > 0.0 and self.ref_model is not None: + with torch.no_grad(): + ref_out = self.ref_model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) + ref_logits = ref_out.logits[:, :-1, :].contiguous() + ref_token_logp = -F.cross_entropy(ref_logits.transpose(1, 2), shifted_labels, reduction="none") + ref_token_logp = (ref_token_logp * mask) + ref_seq_logp = ref_token_logp.sum(dim=-1) / (mask.sum(dim=-1) + 1e-8) + kl_per_seq = compute_approx_kl(seq_logp, ref_seq_logp, kl_estimator='k2') + kl_loss = kl_per_seq.mean() + + self._log_state(x=kl_loss.item(), name='rft_kl') + + loss = loss + kl_coef * kl_loss + + total_loss += loss.item() + self.strategy.backward(loss, self.model_engine, self.optim) + self.strategy.optimizer_step(self.optim, self.model_engine, self.scheduler) + + self._log_state(x=total_loss/gradient_accumulation_steps, name='rft_loss') + self.train_samples_cnt += len(samples) + + if self.vllm_engines is not None: + self._broadcast_to_vllm() +
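# after an RFT update the new weights are pushed to the vLLM engines above (_broadcast_to_vllm); then flush the accumulated RFT metrics to TensorBoard +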
self.log_state_to_tb() + + def _broadcast_to_vllm(self): + use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) + cache_reset_refs = [] + if use_prefix_cache and torch.distributed.get_rank() == 0: + # clear prefix cache + for engine in self.vllm_engines: + cache_reset_refs.append(engine.reset_prefix_cache.remote()) + + torch.cuda.empty_cache() + model = self.model_engine.module + count, num_params = 0, len(list(model.named_parameters())) + + def _broadcast_param(param, count, num_params): + use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote(name, dtype=param.dtype, shape=shape, empty_cache=count == num_params) + for engine in self.vllm_engines + ] + + if use_ray: + import ray.util.collective as collective + + collective.broadcast(param.data, 0, group_name=self._model_update_group) + else: + self._model_update_group.broadcast(param.data, src=0, stream=torch.cuda.current_stream()) + ray.get(refs) + + def _handle_cuda_ipc(param, count, num_params): + from torch.multiprocessing.reductions import reduce_tensor + + weight = param.data.clone() + ipc_handle = reduce_tensor(weight) + + ipc_handle = {get_physical_gpu_id(): ipc_handle} + ipc_handle_list = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(ipc_handle_list, ipc_handle) + + if torch.distributed.get_rank() == 0: + ipc_handles = {} + for d in ipc_handle_list: + ipc_handles.update(d) + + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape + refs = [ + engine.update_weight_cuda_ipc.remote( + name, + dtype=param.dtype, + shape=shape, + ipc_handles=ipc_handles, + empty_cache=count == num_params, + ) + for engine in self.vllm_engines + ] + ray.get(refs) + torch_dist_barrier_and_cuda_sync() + + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + + # broadcast + if not self.use_cuda_ipc: + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + if self.strategy.args.ds_tensor_parallel_size > 1: + with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True): + _broadcast_param(param, count, num_params) + else: + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + _broadcast_param(param, count, num_params) + # CUDA IPC + else: + if self.strategy.args.ds_tensor_parallel_size > 1: + with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True): + _handle_cuda_ipc(param, count, num_params) + else: + with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + _handle_cuda_ipc(param, count, num_params) + + if cache_reset_refs: + ray.get(cache_reset_refs) + torch.cuda.empty_cache() + torch_dist_barrier_and_cuda_sync() + + \ No newline at end of file diff --git a/zoo/jericho/priorzero/priorzero_orz_complete.py b/zoo/jericho/priorzero/priorzero_orz_complete.py deleted file mode 100644 index f0daf5958..000000000 --- a/zoo/jericho/priorzero/priorzero_orz_complete.py +++ /dev/null @@ -1,965 +0,0 @@ -""" -PriorZero-ORZ Complete Integration -完整可执行版本 with ORZ RayPPOTrainer - -This version includes: -1. Fixed vLLM None handling -2. Fixed asyncio scope issue -3. Complete ORZ RayPPOTrainer integration -4. 
Robust error handling - -Usage: - DEBUG_MODE=True python -m zoo.jericho.priorzero.priorzero_orz_complete - -Author: PriorZero Team -Date: 2025-10-21 -""" - -import asyncio -import os -import sys -import re -from pathlib import Path -from functools import partial -from typing import Optional, List, Dict, Any, Callable, Awaitable, Tuple -import time -import json - -# ============================================================================== -# Ensure local LightZero is used -# ============================================================================== -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - -import torch -import numpy as np -from ding.config import compile_config -from ding.envs import create_env_manager, get_vec_env_setting -from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank -from ding.worker import BaseLearner -from tensorboardX import SummaryWriter -from loguru import logger - -# PriorZero imports -from priorzero_config import get_priorzero_config_for_quick_test, get_priorzero_config -from priorzero_collector import PriorZeroCollector -from priorzero_evaluator import PriorZeroEvaluator -import priorzero_policy # noqa: F401 -from lzero.mcts.buffer.game_buffer_priorzero import PriorZeroGameBufferOptimized - -# vLLM imports (optional) -try: - from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs - VLLM_AVAILABLE = True -except ImportError: - VLLM_AVAILABLE = False - logger.warning("vLLM not available - LLM inference will be disabled") - -# Try to import ORZ -ORZ_AVAILABLE = False -ORZ_PATH = Path("/mnt/nfs/zhangjinouwen/puyuan/Open-Reasoner-Zero") - -try: - if ORZ_PATH.exists() and str(ORZ_PATH) not in sys.path: - sys.path.insert(0, str(ORZ_PATH)) - - from orz.ppo import RayPPOTrainer, PromptDataset - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp, BasePPOExpConfig - from orz.ppo.utils import get_strategy - from transformers import AutoTokenizer - import ray - ORZ_AVAILABLE = True - logger.info("✅ ORZ available - will use ORZ RayPPOTrainer for LLM training") -except ImportError as e: - logger.warning(f"⚠️ ORZ not available ({e}) - will use PriorZero's built-in LLM training") - - -# ============================================================================== -# Configuration -# ============================================================================== - -DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") == "True" - - -class HybridTrainingConfig: - """ - Hybrid training configuration combining PriorZero and ORZ settings. 
- """ - def __init__(self): - # Get base PriorZero config - if DEBUG_MODE: - self.priorzero_cfg, self.priorzero_create_cfg = get_priorzero_config_for_quick_test( - env_id='zork1.z5', - seed=0, - debug_mode=True - ) - else: - self.priorzero_cfg, self.priorzero_create_cfg = get_priorzero_config( - env_id='zork1.z5', - seed=0, - enable_llm=True, - enable_rft=True, - debug_mode=False - ) - - # Hybrid-specific settings - self.wm_training_mode = "parallel" - self.wm_train_freq = 1 - self.llm_train_freq = 5 - self.use_orz_trainer = ORZ_AVAILABLE - - # vLLM settings - self.use_vllm = VLLM_AVAILABLE - self.vllm_required = False # Set to True if vLLM is required - - # ORZ-specific settings (only used if ORZ_AVAILABLE) - if ORZ_AVAILABLE: - self.orz_rollout_batch_size = 32 if DEBUG_MODE else 128 - self.orz_train_batch_size = 8 if DEBUG_MODE else 32 - self.orz_actor_lr = 1e-6 - self.orz_critic_lr = 5e-6 - self.orz_num_episodes = 2 if DEBUG_MODE else 10 - - -# ============================================================================== -# ORZ Data Adapter and Dataset -# ============================================================================== - -class GameSegmentToORZAdapter: - """ - Convert PriorZero game_segments to ORZ-compatible format. - """ - - @staticmethod - def convert_segments_to_prompts(game_segments: List[Any], tokenizer) -> List[Dict]: - """ - Convert game_segments to ORZ prompt format. - - Args: - game_segments: List of GameSegment from PriorZero - tokenizer: HuggingFace tokenizer - - Returns: - List of ORZ-compatible prompt dictionaries - """ - prompts = [] - - for segment in game_segments: - # Extract raw observations if available - if hasattr(segment, 'raw_obs_segment') and segment.raw_obs_segment: - for i, (obs, action) in enumerate(zip( - segment.raw_obs_segment, - segment.action_segment - )): - # Create ORZ format prompt - prompt_dict = { - "prompt": [{"value": obs}], - "final_answer": action, - "file_name": f"segment_{id(segment)}_step_{i}" - } - prompts.append(prompt_dict) - - return prompts - - @staticmethod - def extract_training_data(game_segments: List[Any]) -> Dict[str, List]: - """ - Extract training data from game_segments for ORZ. - - Returns: - Dictionary containing: - - states: List of state descriptions - - actions: List of actions taken - - rewards: List of rewards received - - mcts_policies: List of MCTS visit distributions - """ - training_data = { - 'states': [], - 'actions': [], - 'rewards': [], - 'mcts_policies': [] - } - - for segment in game_segments: - # Extract raw observations (states) - if hasattr(segment, 'raw_obs_segment'): - training_data['states'].extend(segment.raw_obs_segment) - - # Extract actions - if hasattr(segment, 'action_segment'): - training_data['actions'].extend(segment.action_segment) - - # Extract rewards - if hasattr(segment, 'reward_segment'): - training_data['rewards'].extend(segment.reward_segment) - - # Extract MCTS policies - if hasattr(segment, 'mcts_policy_segment'): - training_data['mcts_policies'].extend(segment.mcts_policy_segment) - - return training_data - - -# Only define dataset classes if ORZ is available -if ORZ_AVAILABLE: - from jinja2 import Template - - class JerichoPromptDataset(PromptDataset): - """ - Custom dataset for Jericho text adventure games in ORZ format. - Adapts PriorZero game_segments to ORZ PPO training format. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def process_dialogue(self, dialogue: dict): - """ - Process a single dialogue (observation + action pair) into ORZ format. - - Args: - dialogue: Dict with 'prompt', 'final_answer', 'file_name' - - Returns: - prompt: Formatted prompt string - extra: Dict with answer and metadata - """ - # Template for Jericho text adventure prompts - prompt_template_jinja = """\ -{{bos_token}}A conversation between User and Assistant. The User is playing a text adventure game \ -and needs to decide the next action. The Assistant carefully analyzes the current game state, \ -considers the available actions, and recommends the best action to take. \ -The reasoning process is enclosed within tags, and the recommended action \ -is enclosed within tags. For example: \ - The player is in a dark room and needs light. The lamp is available. \ - take lamp . User: {{prompt}} -Assistant: \ -""" - - prompt_instruction_template_jinja = """\ -Current game state: -{{prompt}} - -What is the best action to take? Put your answer inside tags. -""" - - # Validate dialogue format - assert isinstance(dialogue, dict), "dialogue must be a dict" - assert "prompt" in dialogue, "dialogue must contain prompt" - assert "final_answer" in dialogue, "dialogue must contain final_answer" - - # Build prompt - prompt_instruction_template = Template(prompt_instruction_template_jinja) - prompt_instruction = prompt_instruction_template.render( - prompt=dialogue["prompt"][0]["value"] - ) - - prompt_template = Template(prompt_template_jinja) - if self.tokenizer.bos_token_id is None: - bos_token = "" - else: - bos_token = self.tokenizer.decode([self.tokenizer.bos_token_id]) - - prompt = prompt_template.render( - bos_token=bos_token, - prompt=prompt_instruction - ) - - extra = { - "answer": dialogue["final_answer"], - "file_name": dialogue.get("file_name", "unknown") - } - - return prompt, extra - - -# ============================================================================== -# Main Training Function -# ============================================================================== - -async def train_priorzero_orz_complete( - cfg: dict, - create_cfg: dict, - hybrid_cfg: HybridTrainingConfig, - seed: int = 0, - max_train_iter: int = 10000, - max_env_step: Optional[int] = int(1e10), - enable_save: bool = True, -): - """ - Main hybrid training function with complete ORZ integration. - """ - # ================================================================== - # 1. Compile Configuration - # ================================================================== - cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) - - # ================================================================== - # 2. 
Create vLLM Engine (optional) - Based on priorzero_entry.py - # ================================================================== - vllm_engine = None - - if hybrid_cfg.use_vllm and VLLM_AVAILABLE: - logger.info("Creating vLLM engine...") - - # [ROBUST FIX] Handle shared GPU environment - # Solution: Use alternative initialization method with fallback - tensor_parallel = cfg.policy.llm_policy_cfg.vllm_tensor_parallel_size - distributed_backend = "ray" if tensor_parallel > 1 else None - - # [ROBUST FIX] Lower GPU memory utilization in shared environment - gpu_mem_util = cfg.policy.llm_policy_cfg.gpu_memory_utilization - if gpu_mem_util > 0.85: - gpu_mem_util = 0.75 # More conservative - logger.info(f"✓ Adjusted GPU memory utilization to {gpu_mem_util} for stability") - - # [ROBUST FIX] Use vLLM V0 engine for stability (as in priorzero_entry.py) - use_v1_env = os.environ.get('VLLM_USE_V1', None) - if use_v1_env is None: - # Only set if not already set by user - os.environ['VLLM_USE_V1'] = '0' - logger.info("✓ Using vLLM V0 engine for stability") - - # Fix tokenizers parallelism warning - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - - try: - from vllm.engine.arg_utils import AsyncEngineArgs - - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util, - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=False, - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created (backend: {distributed_backend or 'default'})") - - except (ValueError, RuntimeError) as e: - if "VLLM_USE_V1" in str(e) or "memory profiling" in str(e): - # Fallback: Try without V1 env var or with eager mode - logger.warning(f"⚠️ Initial vLLM initialization failed: {e}") - logger.info("Retrying with alternative configuration...") - - if 'VLLM_USE_V1' in os.environ: - del os.environ['VLLM_USE_V1'] - - try: - engine_args = AsyncEngineArgs( - model=cfg.policy.llm_policy_cfg.pretrain_llm_path, - tensor_parallel_size=tensor_parallel, - gpu_memory_utilization=gpu_mem_util * 0.9, # Even more conservative - distributed_executor_backend=distributed_backend, - trust_remote_code=True, - enable_prefix_caching=False, - enforce_eager=True, # Force eager mode as fallback - ) - vllm_engine = AsyncLLMEngine.from_engine_args(engine_args) - logger.info(f"✓ vLLM Engine created with fallback configuration") - except Exception as e2: - logger.error(f"❌ Failed to create vLLM engine with fallback: {e2}") - if hybrid_cfg.vllm_required: - raise - logger.warning("Continuing without vLLM (LLM prior will be disabled)") - else: - logger.error(f"❌ Failed to create vLLM engine: {e}") - import traceback - logger.error(f"Full traceback:\n{traceback.format_exc()}") - if hybrid_cfg.vllm_required: - raise - logger.warning("Continuing without vLLM (LLM prior will be disabled)") - else: - logger.info("vLLM disabled or not available - continuing without LLM inference") - - # ================================================================== - # 3. 
Create Environments - # ================================================================== - logger.info("Creating environments...") - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - - collector_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in collector_env_cfg] - ) - evaluator_env = create_env_manager( - cfg.env.manager, - [partial(env_fn, cfg=c) for c in evaluator_env_cfg] - ) - - # Seed environments - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=True) - logger.info(f"✓ Environments created and seeded (seed={seed})") - - # ================================================================== - # 4. Create Policy, Buffer, and Components - # ================================================================== - logger.info("Creating policy, buffer, and components...") - - # Create policy - policy = create_policy( - cfg.policy, - enable_field=['learn', 'collect', 'eval'] - ) - logger.info("✓ Policy created") - - # Create TensorBoard logger - os.makedirs(f'./{cfg.exp_name}/log/', exist_ok=True) - tb_logger = SummaryWriter( - os.path.join(f'./{cfg.exp_name}/log/', 'serial') - ) if get_rank() == 0 else None - logger.info(f"✓ TensorBoard logger: ./{cfg.exp_name}/log/") - - # Create learner (for world model training) - learner = BaseLearner( - cfg.policy.learn.learner, - policy.learn_mode, - tb_logger, - exp_name=cfg.exp_name - ) - logger.info("✓ BaseLearner created") - - # Create replay buffer - replay_buffer = PriorZeroGameBufferOptimized(cfg.policy) - logger.info("✓ PriorZero replay buffer created") - - # Create collector - collector = PriorZeroCollector( - env=collector_env, - policy=policy.collect_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, # May be None - policy_config=cfg.policy, - debug_mode=cfg.get('debug_mode', False), - ) - logger.info("✓ Collector created") - - # Create evaluator - evaluator = PriorZeroEvaluator( - eval_freq=cfg.policy.eval_freq, - n_evaluator_episode=cfg.env.n_evaluator_episode, - stop_value=cfg.env.stop_value, - env=evaluator_env, - policy=policy.eval_mode, - tb_logger=tb_logger, - exp_name=cfg.exp_name, - vllm_engine=vllm_engine, # May be None - ) - logger.info("✓ Evaluator created") - - # Call learner's before_run hook - learner.call_hook('before_run') - - # ================================================================== - # 5. 
Initialize ORZ Trainer (if available) - # ================================================================== - orz_trainer = None - orz_adapter = GameSegmentToORZAdapter() - orz_tokenizer = None - orz_strategy = None - - if hybrid_cfg.use_orz_trainer and ORZ_AVAILABLE: - logger.info("="*80) - logger.info("Initializing ORZ RayPPOTrainer for LLM training...") - logger.info("="*80) - - try: - # Initialize Ray if not already running - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - logger.info("✓ Ray initialized") - - # Create ORZ tokenizer - orz_tokenizer = AutoTokenizer.from_pretrained( - cfg.policy.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True - ) - if orz_tokenizer.pad_token is None: - orz_tokenizer.pad_token = orz_tokenizer.eos_token - logger.info("✓ ORZ tokenizer created") - - # Create ORZ strategy (DeepSpeed config) - from orz.ppo.utils import get_strategy - orz_strategy = get_strategy({ - 'zero_stage': 2, - 'bf16': True, - 'gradient_checkpointing': True, - }) - logger.info("✓ ORZ strategy created") - - # Create ORZ configuration (matching ORZ's PPOExpConfig pattern) - from dataclasses import dataclass, field - from omegaconf.listconfig import ListConfig - - @dataclass - class ORZConfig: - """Simplified ORZ config for PriorZero integration""" - # Resource settings (simplified for single-node) - total_num_nodes: int = 1 - ref_num_nodes: int = 1 - ref_num_gpus_per_node: int = 1 - actor_num_nodes: int = 1 - actor_num_gpus_per_node: int = 1 - critic_num_nodes: int = 1 - critic_num_gpus_per_node: int = 1 - colocate_all: bool = True - colocate_critic_reward: bool = True - colocate_actor_ref: bool = True - vllm_num_engines: int = 1 - vllm_tensor_parallel_size: int = 1 - zero_stage: int = 2 - adam_offload: bool = False - - # Model paths - pretrain: str = cfg.policy.llm_policy_cfg.pretrain_llm_path - reward_pretrain: Optional[str] = None - critic_pretrain: Optional[str] = cfg.policy.llm_policy_cfg.pretrain_llm_path - - # Save/log paths - save_interval: int = 50 - ckpt_path: str = f'./{cfg.exp_name}/orz_ckpt' - save_path: str = f'./{cfg.exp_name}/orz_save' - tensorboard_log_dir: str = f'./{cfg.exp_name}/orz_log' - - # Training settings - actor_learning_rate: float = hybrid_cfg.orz_actor_lr if hasattr(hybrid_cfg, 'orz_actor_lr') else 1e-6 - critic_learning_rate: float = hybrid_cfg.orz_critic_lr if hasattr(hybrid_cfg, 'orz_critic_lr') else 5e-6 - num_warmup_steps: int = 50 - prompt_max_len: int = 2048 - enable_prefix_caching: bool = False - update_ref_every_epoch: bool = True - advantage_normalize: bool = True - - # Episode settings - num_episodes: int = hybrid_cfg.orz_num_episodes if hasattr(hybrid_cfg, 'orz_num_episodes') else 2 - rollout_batch_size: int = hybrid_cfg.orz_rollout_batch_size if hasattr(hybrid_cfg, 'orz_rollout_batch_size') else 32 - n_samples_per_prompt: int = 8 if DEBUG_MODE else 32 - micro_rollout_batch_size: int = 2 - policy_update_steps: int = 1 - critic_update_steps: int = 1 if DEBUG_MODE else 12 - micro_train_batch_size: int = 1 - micro_forward_batch_size: int = 1 - freezing_actor_steps: int = -1 - - # KL settings - init_kl_coef: float = 0 - kl_loss_coef: float = 0.0 - use_kl_loss: bool = False - use_kl_estimator_k3: bool = True - - # Eval settings - enable_eval: bool = False # Disable ORZ eval (use PriorZero's) - eval_interval: int = 100 - - # Generation settings - packing_max_len: int = 8192 - generate_max_len: int = cfg.policy.llm_policy_cfg.generate_max_len - max_len: int = 4096 - temperature: float = 1.0 - top_p: float = 1.0 - top_k: int 
= -1 - stop: ListConfig = field(default_factory=lambda: ListConfig([""])) - - # GRPO settings - use_grpo: bool = False - gamma: float = 1.0 - lambd: float = 1.0 - - # vLLM settings - gpu_memory_utilization: float = 0.3 - - # Custom settings for compute_reward_fn - use_compute_reward_fn: bool = True - use_orm_score: bool = False - - orz_cfg = ORZConfig() - - # Create directories for ORZ - os.makedirs(orz_cfg.ckpt_path, exist_ok=True) - os.makedirs(orz_cfg.save_path, exist_ok=True) - os.makedirs(orz_cfg.tensorboard_log_dir, exist_ok=True) - - logger.info("✓ ORZ config created") - logger.info(f" - Model: {orz_cfg.pretrain}") - logger.info(f" - Rollout batch: {orz_cfg.rollout_batch_size}") - logger.info(f" - Episodes: {orz_cfg.num_episodes}") - - # Note: Full RayPPOTrainer initialization requires: - # 1. Creating vLLM engines for distributed inference - # 2. Creating initial dataset from game_segments - # 3. Initializing Ray actors (will be done lazily on first training call) - # - # We defer full initialization until we have actual game_segments to train on - logger.info("✓ ORZ trainer components ready") - logger.info(" (Full RayPPOTrainer will be initialized on first training iteration)") - - except Exception as e: - logger.error(f"❌ ORZ trainer initialization failed: {e}") - import traceback - logger.error(traceback.format_exc()) - logger.warning("Falling back to PriorZero's built-in LLM training") - hybrid_cfg.use_orz_trainer = False - - # ================================================================== - # 6. Main Training Loop - # ================================================================== - logger.info("="*80) - logger.info("Starting PriorZero-ORZ Complete Training") - logger.info("="*80) - logger.info(f"Experiment: {cfg.exp_name}") - logger.info(f"Max iterations: {max_train_iter}") - logger.info(f"Training mode: {hybrid_cfg.wm_training_mode}") - logger.info(f"Use ORZ trainer: {hybrid_cfg.use_orz_trainer}") - logger.info(f"Use vLLM: {vllm_engine is not None}") - logger.info(f"LLM model: {cfg.policy.llm_policy_cfg.pretrain_llm_path}") - logger.info(f"World model: UniZero") - logger.info("="*80) - - # Training state - best_eval_reward = -float('inf') - total_game_segments_collected = 0 - - try: - while learner.train_iter < max_train_iter and collector.envstep < max_env_step: - current_iter = learner.train_iter - - # ============================================================== - # Step 1: Evaluation (if needed) - # ============================================================== - if current_iter > 0 and evaluator.should_eval(current_iter): - logger.info(f"\n{'='*60}") - logger.info(f"[Iter {current_iter}] Evaluating...") - logger.info(f"{'='*60}") - - eval_result = await evaluator.eval( - save_ckpt_fn=learner.save_checkpoint if enable_save else None, - train_iter=current_iter, - envstep=collector.envstep - ) - - if eval_result is not None: - stop, eval_reward_dict = eval_result - mean_reward = eval_reward_dict.get('reward_mean', 0) - logger.info(f"✓ Evaluation: reward_mean={mean_reward:.2f}") - - if mean_reward > best_eval_reward: - best_eval_reward = mean_reward - logger.info(f"🎯 New best reward: {best_eval_reward:.2f}") - - if stop: - logger.info(f"🎉 Training converged! 
(reward >= {cfg.env.stop_value})") - break - - # ============================================================== - # Step 2: Collect Data using MCTS - # ============================================================== - logger.info(f"\n[Iter {current_iter}] Collecting data...") - - collect_kwargs = { - 'temperature': 0.25, - 'epsilon': 0.0 - } - - try: - new_data = await collector.collect( - train_iter=current_iter, - policy_kwargs=collect_kwargs - ) - except Exception as e: - logger.error(f"❌ Collection failed: {e}") - logger.warning("Skipping this iteration...") - continue - - # Add to replay buffer - from lzero.entry.utils import calculate_update_per_collect - update_per_collect = calculate_update_per_collect(cfg, new_data, world_size=1) - - # Update buffer - replay_buffer.push_game_segments(new_data) - logger.info( - f"✓ Collected {len(new_data)} segments " - f"(total: {replay_buffer.get_num_of_game_segments()} segments, " - f"{replay_buffer.get_num_of_transitions()} transitions)" - ) - - total_game_segments_collected += len(new_data) - - # ============================================================== - # Step 3: World Model Training - # ============================================================== - if current_iter % hybrid_cfg.wm_train_freq == 0: - if replay_buffer.get_num_of_transitions() >= cfg.policy.batch_size: - logger.info(f"[Iter {current_iter}] Training world model...") - - # Sample and train - for _ in range(update_per_collect): - train_data = replay_buffer.sample( - cfg.policy.batch_size, - policy - ) - - # Train (includes both WM and LLM in PriorZero) - log_dict = learner.train(train_data, collector.envstep) - - # Log to TensorBoard - if tb_logger and get_rank() == 0: - for k, v in log_dict.items(): - tb_logger.add_scalar(f'train/{k}', v, collector.envstep) - - logger.info( - f"✓ WM training done - " - f"wm_loss: {log_dict.get('wm_total_loss', 0):.4f}, " - f"llm_sft_loss: {log_dict.get('llm_sft_loss', 0):.4f}" - ) - else: - logger.info(f"Skipping training - not enough data yet") - - # ============================================================== - # Step 4: LLM Training with ORZ (if enabled) - # ============================================================== - if (hybrid_cfg.use_orz_trainer and orz_trainer is not None and - current_iter % hybrid_cfg.llm_train_freq == 0 and - current_iter > 0): - logger.info(f"[Iter {current_iter}] Training LLM with ORZ...") - - try: - # Extract game_segments from recent collections - training_data = orz_adapter.extract_training_data(new_data) - num_samples = len(training_data['states']) - - if num_samples > 0: - logger.info(f" Extracted {num_samples} training samples for ORZ") - - # Initialize ORZ trainer on first use (lazy initialization) - if orz_trainer is None: - logger.info(" Initializing ORZ RayPPOTrainer...") - - # Convert game_segments to ORZ dataset format - dialogues = orz_adapter.convert_segments_to_prompts( - new_data, - orz_tokenizer - ) - - # Create ORZ dataset - orz_dataset = JerichoPromptDataset( - dialogues, - orz_tokenizer, - orz_cfg.prompt_max_len, - orz_strategy, - pretrain_mode=False, - num_processors=1 - ) - - # Create custom reward trainer - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp - - class JerichoRewardTrainer(RayPPOTrainer): - """Custom reward trainer for Jericho text adventures""" - - async def custom_reward_fn( - self, - prompts: List[str], - outputs: List[Any], - extras: List[dict], - reward_model_fn, - ): - """ - Compute rewards for Jericho actions. 
- Reward is 1.0 if action matches ground truth, else 0.0 - """ - import torch - scores = [] - responses = [] - - for output, extra in zip(outputs, extras): - response = output["response"] - responses.append(response) - - # Extract action from response - # Look for ... tags - import re - pattern = re.compile(r"(.*?)", re.DOTALL) - matches = re.findall(pattern, response) - predicted_action = matches[-1].strip() if matches else "" - - # Ground truth action - true_action = extra["answer"] - - # Simple exact match for now - # TODO: Could use fuzzy matching or LLM-based similarity - score = 1.0 if predicted_action.lower() == true_action.lower() else 0.0 - scores.append(score) - - # Log statistics - avg_score = sum(scores) / len(scores) if scores else 0.0 - logger.info(f" ORZ reward - avg: {avg_score:.3f}, samples: {len(scores)}") - - # Create score tensors (reward only on last token) - output_tokens = self._tokenize(responses, self.cfg.generate_max_len, padding=False)["input_ids"] - score_tensors = [] - for score, output_token in zip(scores, output_tokens): - score_tensor = torch.zeros(len(output_token)) - if len(output_token) > 0: - score_tensor[-1] = score - score_tensors.append(score_tensor) - - # Remove empty responses - res_prompts, res_responses, res_score_tensors = [], [], [] - for prompt, response, score_tensor in zip(prompts, responses, score_tensors): - if len(response) > 0: - res_prompts.append(prompt) - res_responses.append(response) - res_score_tensors.append(score_tensor) - - return res_prompts, res_responses, res_score_tensors - - # Create vLLM engines for ORZ - logger.info(" Creating vLLM inference engines for ORZ...") - from orz.exps.examples.ppo.ppo_base_exp import BasePPOExp - - # Use BasePPOExp helper to create engines - class TempExp(BasePPOExp): - def __init__(self): - self.cfg = orz_cfg - self.tokenizer = orz_tokenizer - self.strategy = orz_strategy - - temp_exp = TempExp() - vllm_engines = temp_exp.create_inference_engine() - logger.info(f" ✓ Created {len(vllm_engines)} vLLM engines") - - # Get colocate placement groups if needed - colocate_pg = temp_exp.get_colocate_pg if orz_cfg.colocate_all else None - - # Create ORZ trainer - orz_trainer = JerichoRewardTrainer( - cfg=orz_cfg, - strategy=orz_strategy, - tokenizer=orz_tokenizer, - train_dataset=orz_dataset, - eval_dataset=None, # No separate eval for now - vllm_engines=vllm_engines, - colocate_pg=colocate_pg - ) - - logger.info(" ✓ ORZ RayPPOTrainer initialized") - - # Run ORZ training for one episode - logger.info(f" Running ORZ PPO training (episode {current_iter // hybrid_cfg.llm_train_freq})...") - - # Train using ORZ's fit_episode method - # Note: This will do full PPO update with actor/critic training - await orz_trainer.fit_episode() - - logger.info(f" ✓ ORZ training completed for iteration {current_iter}") - - else: - logger.warning(" No training samples extracted from game_segments") - - except Exception as e: - logger.error(f" ✗ ORZ training failed: {e}") - import traceback - logger.error(traceback.format_exc()) - logger.warning(" Continuing with PriorZero LLM training only") - - # ============================================================== - # Step 5: Logging and Checkpointing - # ============================================================== - if current_iter % 10 == 0: - logger.info(f"\n{'='*60}") - logger.info(f"Progress Summary (Iter {current_iter})") - logger.info(f"{'='*60}") - logger.info(f"Env steps: {collector.envstep}") - logger.info(f"Game segments collected: {total_game_segments_collected}") - 
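# ==============================================================================
# Illustrative sketch (assumptions noted below): the reward pattern implemented
# by custom_reward_fn above — exact-match scoring of the extracted action, with
# the scalar reward placed only on the last response token. The <action> tag
# name and the whitespace "tokenizer" are stand-ins for illustration, not ORZ's
# actual tag format or tokenizer.
# ==============================================================================
import re
from typing import List

import torch


def exact_match_scores(responses: List[str], answers: List[str]) -> List[float]:
    """Score 1.0 when the last tagged action matches the ground truth, else 0.0."""
    pattern = re.compile(r"<action>(.*?)</action>", re.DOTALL)  # assumed tag name
    scores = []
    for response, answer in zip(responses, answers):
        matches = pattern.findall(response)
        predicted = matches[-1].strip() if matches else ""
        scores.append(1.0 if predicted.lower() == answer.lower() else 0.0)
    return scores


def last_token_score_tensors(responses: List[str], scores: List[float]) -> List[torch.Tensor]:
    """Sparse terminal reward: zeros everywhere except the final token."""
    tensors = []
    for response, score in zip(responses, scores):
        tokens = response.split()  # stand-in for a real tokenizer
        t = torch.zeros(len(tokens))
        if len(tokens) > 0:
            t[-1] = score
        tensors.append(t)
    return tensors


if __name__ == "__main__":
    responses = ["I should grab it.\n<action>take key</action>", "go north"]
    answers = ["take key", "open door"]
    s = exact_match_scores(responses, answers)  # -> [1.0, 0.0]
    print([t.tolist() for t in last_token_score_tensors(responses, s)])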
logger.info(f"Buffer size: {replay_buffer.get_num_of_transitions()} transitions") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info(f"{'='*60}\n") - - # Save checkpoint periodically - if enable_save and current_iter % 100 == 0 and current_iter > 0: - logger.info(f"[Iter {current_iter}] Saving checkpoint...") - learner.save_checkpoint(collector.envstep) - logger.info("✓ Checkpoint saved") - - except KeyboardInterrupt: - logger.info("\n⚠️ Training interrupted by user") - except Exception as e: - logger.error(f"\n❌ Training failed with error: {e}") - import traceback - traceback.print_exc() - raise - finally: - # ============================================================== - # Cleanup - # ============================================================== - logger.info("\nCleaning up...") - - # Save final checkpoint - if enable_save: - logger.info("Saving final checkpoint...") - try: - learner.save_checkpoint(collector.envstep) - except Exception as e: - logger.error(f"Failed to save checkpoint: {e}") - - # Close environments - try: - collector_env.close() - evaluator_env.close() - except Exception as e: - logger.error(f"Failed to close environments: {e}") - - # Close loggers - if tb_logger: - try: - tb_logger.close() - except Exception as e: - logger.error(f"Failed to close tensorboard: {e}") - - logger.info("✓ Cleanup complete") - logger.info("="*80) - logger.info("Training finished!") - logger.info(f"Total iterations: {learner.train_iter}") - logger.info(f"Total env steps: {collector.envstep}") - logger.info(f"Best eval reward: {best_eval_reward:.2f}") - logger.info("="*80) - - -# ============================================================================== -# Entry Point -# ============================================================================== - -async def main(): - """Main entry point.""" - # Create hybrid configuration - hybrid_cfg = HybridTrainingConfig() - - # Run training - await train_priorzero_orz_complete( - cfg=hybrid_cfg.priorzero_cfg, - create_cfg=hybrid_cfg.priorzero_create_cfg, - hybrid_cfg=hybrid_cfg, - seed=0, - max_train_iter=10000 if not DEBUG_MODE else 100, - enable_save=True, - ) - - -if __name__ == "__main__": - logger.info("="*80) - logger.info("PriorZero-ORZ Complete Training Pipeline") - logger.info("="*80) - logger.info(f"Debug mode: {DEBUG_MODE}") - logger.info(f"ORZ available: {ORZ_AVAILABLE}") - logger.info(f"vLLM available: {VLLM_AVAILABLE}") - logger.info("="*80) - - # Run async training - asyncio.run(main()) diff --git a/zoo/jericho/priorzero/priorzero_policy.py b/zoo/jericho/priorzero/priorzero_policy.py index 26e50e060..0403423a2 100644 --- a/zoo/jericho/priorzero/priorzero_policy.py +++ b/zoo/jericho/priorzero/priorzero_policy.py @@ -1,151 +1,36 @@ -# priorzero_policy.py -""" -[PRIORZERO] PriorZero Policy Implementation - -This module implements the PriorZero policy that combines: -1. UniZero world model for planning in latent space -2. 
LLM policy model for providing high-quality action priors - -Key Features: -- Dual-model training: world model + LLM policy -- LLM-guided MCTS: inject LLM priors into MCTS root node -- SFT + RFT: supervised fine-tuning with MCTS policies + reinforcement fine-tuning with environment rewards -- Full alignment with UniZero implementation - -Author: PriorZero Team -Date: 2025-01-20 -""" - +import asyncio import copy +import inspect import re import sys +import time +import cProfile import logging +from contextlib import contextmanager from pathlib import Path from typing import List, Dict, Any, Tuple, Union, Optional -# [CRITICAL] Ensure local LightZero is used -from ensure_local_lightzero import ensure_local_lightzero -ensure_local_lightzero() - import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F from ding.utils import POLICY_REGISTRY from ding.model import model_wrap from transformers import AutoTokenizer, AutoModelForCausalLM from peft import get_peft_model, LoraConfig, TaskType +import os # Import from local LightZero from lzero.policy.unizero import UniZeroPolicy as OriginalUniZeroPolicy -from lzero.policy import ( - phi_transform, - InverseScalarTransform, - scalar_transform, # [PRIORZERO] Added for reward/value transformation - DiscreteSupport, # [PRIORZERO] Added for categorical distribution support - to_torch_float_tensor, - mz_network_output_unpack -) +from lzero.policy import phi_transform, InverseScalarTransform, scalar_transform, DiscreteSupport +from lzero.policy import to_torch_float_tensor,mz_network_output_unpack, prepare_obs from lzero.policy.utils import select_action from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.entry.utils import initialize_zeros_batch -# Import UniZeroModel to ensure it's registered in MODEL_REGISTRY -import lzero.model.unizero_model # noqa: F401 - - -# ============================================================================== -# Helper Functions for LLM Prior Processing -# ============================================================================== - -def parse_llm_action_ranking( - text: str, - action_map: Dict[str, int], - action_space_size: int, - fallback_to_uniform: bool = True -) -> np.ndarray: - """ - [PRIORZERO-NEW] - Parse LLM generated action ranking text into a policy distribution. - - Args: - text: LLM generated text with ranked actions (e.g., "1. take key\\n2. go north") - action_map: Mapping from action text to action index - action_space_size: Size of the action space - fallback_to_uniform: If True, return uniform distribution when no valid action found - - Returns: - policy: Probability distribution over actions (shape: [action_space_size]) - """ - # Extract ranked actions using regex - # Supports formats: "1. 
action", "1) action", "1: action" - ranked_actions = re.findall(r'(?:^|\n)\s*\d+[\.\):\s]+(.+?)(?=\n|$)', text, re.MULTILINE) - - policy = np.zeros(action_space_size, dtype=np.float32) - found_count = 0 - - for rank, action_text in enumerate(ranked_actions): - action_text = action_text.strip().lower() - - # Try exact match first - if action_text in action_map: - action_idx = action_map[action_text] - # Assign decreasing weights (higher rank = higher weight) - policy[action_idx] = len(ranked_actions) - rank - found_count += 1 - else: - # Try fuzzy matching (find best substring match) - best_match_score = 0 - best_action_idx = None - for candidate_text, candidate_idx in action_map.items(): - if candidate_text in action_text or action_text in candidate_text: - score = len(set(candidate_text.split()) & set(action_text.split())) - if score > best_match_score: - best_match_score = score - best_action_idx = candidate_idx - - if best_action_idx is not None: - policy[best_action_idx] = len(ranked_actions) - rank - found_count += 1 - - # Normalize to probability distribution - if policy.sum() > 0: - policy /= policy.sum() - elif fallback_to_uniform: - # If LLM didn't generate any valid actions, return uniform distribution - policy = np.ones(action_space_size, dtype=np.float32) / action_space_size - - return policy - - -def format_mcts_policy_to_text( - mcts_policy: np.ndarray, - action_inv_map: Dict[int, str], - top_k: int = 5 -) -> str: - """ - [PRIORZERO-NEW] - Convert MCTS policy vector into ranked action text for SFT training. - - Args: - mcts_policy: MCTS visit count distribution (shape: [action_space_size]) - action_inv_map: Mapping from action index to action text - top_k: Number of top actions to include - - Returns: - Formatted text with ranked actions (e.g., "1. take key\\n2. go north\\n...") - """ - # Sort actions by policy probability (descending) - sorted_indices = np.argsort(mcts_policy)[::-1] - - output_lines = [] - rank = 1 - for idx in sorted_indices: - if mcts_policy[idx] > 0 and rank <= top_k: - action_text = action_inv_map.get(idx, f"action_{idx}") - output_lines.append(f"{rank}. {action_text}") - rank += 1 - - return "\n".join(output_lines) if output_lines else "No valid actions found." +import lzero.model.unizero_model +from ding.utils import build_logger +from priorzero_utils import compute_approx_kl def build_llm_prompt( current_obs: str, @@ -155,7 +40,14 @@ def build_llm_prompt( ) -> str: """ [PRIORZERO-NEW] - Build a high-quality prompt for LLM to generate action ranking. + Build a high-quality prompt for LLM to generate the next action. + + When use_cot is True, the model should: + - First output its reasoning inside + - Then output the SINGLE best next action inside + + When use_cot is False, the model should: + - Output ONLY the SINGLE best next action inside Args: current_obs: Current observation text @@ -168,18 +60,19 @@ def build_llm_prompt( """ prompt_parts = [] - # System instruction prompt_parts.append( "You are an expert player in a text-based adventure game. " - "Your goal is to maximize the score by taking the best actions." + "Your goal is to maximize the score by choosing the best possible next action. " + "You must choose exactly ONE best next action." 
) - - # Add history if available - if history and len(history) > 0: + if history is not None and len(history) > 0: + history = list(history) prompt_parts.append("\n=== Recent History ===") - for i, (obs, action, reward) in enumerate(history[-5:]): # Last 5 steps - prompt_parts.append(f"Step {i+1}:") - prompt_parts.append(f" Observation: {obs[:100]}...") # Truncate long obs + + for i, (obs, action, reward) in enumerate(history, start=1): + obs_str = obs + prompt_parts.append(f"Step {i}:") + prompt_parts.append(f" Observation: {obs_str}") prompt_parts.append(f" Action: {action}") prompt_parts.append(f" Reward: {reward}") @@ -187,31 +80,36 @@ def build_llm_prompt( prompt_parts.append("\n=== Current Situation ===") prompt_parts.append(current_obs) - # Task instruction + # Available actions (if provided) + if action_descriptions: + prompt_parts.append("\n=== Available Actions ===") + prompt_parts.append( + "You MUST choose the best action from the list below. " + "Do not invent actions that are not in this list." + ) + for action_text, desc in action_descriptions.items(): + # action_text: should match exactly the string we want inside ... + prompt_parts.append(f"- {action_text}: {desc}") + + # Task + output format if use_cot: + # CoT mode: reason first, then give the answer prompt_parts.append( "\n=== Task ===\n" - "Think step-by-step:\n" - "1. Analyze the current situation and your goal\n" - "2. Consider what actions might help you progress\n" - "3. Rank the best actions in order of priority\n" - "\nProvide your analysis and then list the top 5 actions in this format:\n" - "1. [first action]\n" - "2. [second action]\n" - "..." + "Analyze the recent history and the current situation, and decide on the SINGLE best next action.\n\n" + "OUTPUT FORMAT:\n" + "- First, write your detailed reasoning inside ....\n" + "- Then, on a new line, output ONLY the chosen action text inside ....\n" + "Example:\nyour step-by-step reasoning here\nthe best action text here\n\n" + ) else: prompt_parts.append( "\n=== Task ===\n" - "List the top 5 best actions in order of priority:\n" - "1. [first action]\n" - "2. [second action]\n" - "..." + "Analyze the recent history and the current situation, and decide on the SINGLE best next action."
+ "Please keep the output concise, avoiding any other content.\n\n" ) - return "\n".join(prompt_parts) - # ============================================================================== # PriorZero Policy Class # ============================================================================== @@ -258,20 +156,26 @@ class PriorZeroPolicy(OriginalUniZeroPolicy): ), ) - def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None): + def __init__(self, cfg: Dict, model: torch.nn.Module = None, enable_field: List[str] = None, **kwargs): # [PRIORZERO-NEW] Initialize LLM-related attributes BEFORE super().__init__ - # because super().__init__ will call _init_learn which needs these attributes - self.llm_policy_model = None + # because super().__init__ will call _init_learn which needs these attributes self.llm_tokenizer = None - self._optimizer_llm = None self._lr_scheduler_llm = None + self._last_llm_grad_norm = 0.0 self.llm_policy_cfg = cfg.llm_policy_cfg # Set from cfg, not self._cfg (not set yet) + self.profile_cfg = getattr(cfg, 'profile_cfg', {}) + self._profile_enabled = bool(self.profile_cfg.get('enable_cprofile', False)) + self._profile_dir = f"./{kwargs['exp_name']}/log/profile" + self._profile_log_interval = int(self.profile_cfg.get('log_interval', 50)) + self._profile_stats = { 'train_world_model': {'count': 0, 'total': 0.0, 'max': 0.0}, + 'train_llm_sft': {'count': 0, 'total': 0.0, 'max': 0.0}, + 'train_llm_rft': {'count': 0, 'total': 0.0, 'max': 0.0} + } + self._profile_stats_file = f'{self._profile_dir}/train_time.log' + if self._profile_enabled: + os.makedirs(self._profile_dir, exist_ok=True) + self.vllm_engine = None - # Action mapping (will be set from config) - self.action_map = None # str -> int - self.action_inv_map = None # int -> str - - # Call parent init (this will trigger _init_learn, _init_collect, _init_eval) super().__init__(cfg, model, enable_field) def _init_learn(self) -> None: @@ -280,102 +184,37 @@ def _init_learn(self) -> None: Initialize both UniZero world model and LLM policy model with their optimizers. Align with UniZero implementation - use logging instead of self._logger. """ - import logging - - # ====================================================================== - # 1. Initialize UniZero World Model (from parent class) - # ====================================================================== super()._init_learn() logging.info("✓ UniZero World Model and optimizer initialized") - # [PRIORZERO-FIX] Ensure scalar transform handles are initialized - # These are normally initialized in UniZeroPolicy.__init__ but we need to ensure they exist - if not hasattr(self, 'value_support') or self.value_support is None: - self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) - if not hasattr(self, 'reward_support') or self.reward_support is None: - self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) - if not hasattr(self, 'value_inverse_scalar_transform_handle'): - self.value_inverse_scalar_transform_handle = InverseScalarTransform( - self.value_support, self._cfg.model.categorical_distribution - ) - if not hasattr(self, 'reward_inverse_scalar_transform_handle'): - self.reward_inverse_scalar_transform_handle = InverseScalarTransform( - self.reward_support, self._cfg.model.categorical_distribution - ) - logging.info("✓ Scalar transform handles verified/initialized") - - # ====================================================================== - # 2. 
[PRIORZERO-NEW] Initialize LLM Policy Model - # ====================================================================== - logging.info(f"Loading LLM from: {self.llm_policy_cfg.pretrain_llm_path}") - - # Load tokenizer - self.llm_tokenizer = AutoTokenizer.from_pretrained( - self.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True, - padding_side='left' # For batch generation - ) - if self.llm_tokenizer.pad_token is None: - self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token - - # Load LLM - self.llm_policy_model = AutoModelForCausalLM.from_pretrained( - self.llm_policy_cfg.pretrain_llm_path, - trust_remote_code=True, - torch_dtype=torch.bfloat16, # Use bfloat16 to save memory - device_map=None, # We'll manually move to device - ) - - # Apply LoRA if enabled - if self.llm_policy_cfg.use_lora: - logging.info("Applying LoRA for parameter-efficient fine-tuning") - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=self.llm_policy_cfg.lora_r, - lora_alpha=self.llm_policy_cfg.lora_alpha, - lora_dropout=self.llm_policy_cfg.lora_dropout, - target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Qwen-specific - ) - self.llm_policy_model = get_peft_model(self.llm_policy_model, lora_config) - self.llm_policy_model.print_trainable_parameters() - - self.llm_policy_model.to(self._cfg.device) - self.llm_policy_model.train() - - # ====================================================================== - # 3. [PRIORZERO-NEW] Initialize LLM Optimizer - # ====================================================================== - self._optimizer_llm = torch.optim.AdamW( - self.llm_policy_model.parameters(), - lr=self.llm_policy_cfg.llm_learning_rate, - weight_decay=self.llm_policy_cfg.llm_weight_decay, - betas=(0.9, 0.999), - ) - - # Optional: learning rate scheduler - self._lr_scheduler_llm = torch.optim.lr_scheduler.CosineAnnealingLR( - self._optimizer_llm, - T_max=100000, # Will be set from config - eta_min=self.llm_policy_cfg.llm_learning_rate * 0.1 - ) + @contextmanager + def _profile_block(self, name: str): + if not self._profile_enabled: + yield None + return + profiler = cProfile.Profile() + start_time = time.perf_counter() + profiler.enable() + try: + yield profiler + finally: + profiler.disable() + elapsed = time.perf_counter() - start_time + self._record_profile_time(name, elapsed) + + def _record_profile_time(self, name: str, elapsed: float) -> None: + log_every = max(1, self._profile_log_interval) + self._profile_stats[name]['count'] += 1 + self._profile_stats[name]['total'] += elapsed + self._profile_stats[name]['max'] = max(self._profile_stats[name]['max'], elapsed) + if self._profile_stats[name]['count'] % log_every == 0: + avg = self._profile_stats[name]['total'] / self._profile_stats[name]['count'] + with open(self._profile_stats_file, mode='a', encoding='utf-8') as f: + f.write( + f"{time.time():.3f}\tname={name}\tcount={self._profile_stats[name]['count']}\t" + f"total_s={self._profile_stats[name]['total']:.4f}\tavg_s={avg:.4f}\tmax_s={self._profile_stats[name]['max']:.4f}\n" + ) - logging.info(f"✓ LLM Policy Model ({self.llm_policy_cfg.pretrain_llm_path}) initialized") - logging.info(f" - LLM learning rate: {self.llm_policy_cfg.llm_learning_rate}") - logging.info(f" - LoRA enabled: {self.llm_policy_cfg.use_lora}") - - # ====================================================================== - # 4. 
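# ==============================================================================
# Illustrative sketch: the timing/profiling pattern of _profile_block and
# _record_profile_time above — a context manager wrapping a code block with
# cProfile and wall-clock timing, appending aggregate stats to a log file every
# N calls. The file name and interval below are made up for the example.
# ==============================================================================
import cProfile
import time
from contextlib import contextmanager

_stats = {"count": 0, "total": 0.0, "max": 0.0}


@contextmanager
def profile_block(log_file="profile_demo.log", log_interval=10):
    profiler = cProfile.Profile()
    start = time.perf_counter()
    profiler.enable()
    try:
        yield profiler
    finally:
        profiler.disable()
        elapsed = time.perf_counter() - start
        _stats["count"] += 1
        _stats["total"] += elapsed
        _stats["max"] = max(_stats["max"], elapsed)
        if _stats["count"] % log_interval == 0:
            avg = _stats["total"] / _stats["count"]
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"count={_stats['count']}\tavg_s={avg:.4f}\tmax_s={_stats['max']:.4f}\n")


with profile_block():
    sum(i * i for i in range(10_000))  # the block being profiled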
[PRIORZERO-NEW] Load Action Mappings - # ====================================================================== - if hasattr(self._cfg, 'action_map') and self._cfg.action_map is not None: - self.action_map = self._cfg.action_map - self.action_inv_map = {v: k for k, v in self.action_map.items()} - logging.info(f"✓ Action mappings loaded ({len(self.action_map)} actions)") - else: - logging.warning("⚠ Action mappings not found in config. Will use index-based actions.") - # Fallback: create dummy mappings - action_space_size = self._cfg.model.action_space_size - self.action_inv_map = {i: f"action_{i}" for i in range(action_space_size)} - self.action_map = {v: k for k, v in self.action_inv_map.items()} def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ @@ -394,551 +233,74 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in Returns: log_dict: Dictionary of training metrics """ - import logging - self._learn_model.train() - self.llm_policy_model.train() - - # Unpack data - # NOTE: game_segments is our custom GameSegment with mcts_policy_segment - # [FIX] Handle both 3-element (from buffer) and 4-element (with explicit train_iter) formats - if len(data) == 4: - # Format: [current_batch, target_batch, train_iter, game_segments] - # This is when learner explicitly adds train_iter - current_batch, target_batch, train_iter, game_segments = data - elif len(data) == 3: - # Format: [current_batch, target_batch, game_segments] - # This is the standard format from PriorZeroGameBuffer.sample() - current_batch, target_batch, game_segments = data - train_iter = self._train_iteration # Get from instance variable - import logging - logger = logging.getLogger(__name__) - logger.debug( - f"[PRIORZERO] Using 3-element format. 
game_segments: " - f"{type(game_segments)}, count: {len(game_segments) if game_segments else 0}" - ) - else: - raise ValueError(f"Unexpected data format: expected 3 or 4 elements, got {len(data)}") + self._target_model.train() - # ============================================================================== - # Part 1: UniZero World Model Training (Full Implementation) - # ============================================================================== + current_batch, target_batch, train_iter = data - # Unpack batches - (obs_batch_ori, action_batch, mask_batch, batch_index_tensor, - weights, make_time) = current_batch[:6] + obs_batch_ori, action_batch, target_action_batch, mask_batch, batch_index_tensor, weights, make_time, timestep_batch, raw_obs_list, history_obs_list, action_logprob_list = current_batch target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() + timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze( + -1).long() - # Handle optional timestep - if len(current_batch) > 6: - timestep_batch = current_batch[6] - else: - timestep_batch = None - - # Convert to tensors and move to device data_list = [mask_batch, target_reward, target_value, target_policy, weights] - (mask_batch, target_reward, target_value, - target_policy, weights) = to_torch_float_tensor(data_list, self._cfg.device) + (mask_batch, target_reward, target_value, target_policy, weights) = to_torch_float_tensor(data_list, self._cfg.device) - # Reshape targets batch_size = self._cfg.batch_size target_reward = target_reward.view(batch_size, -1) target_value = target_value.view(batch_size, -1) - # Apply scalar transform (for value and reward) - # [FIX] Use scalar_transform function (not self.scalar_transform) - # scalar_transform is a standalone function imported from lzero.policy transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) # Convert to categorical distribution (for distributional RL) - target_reward_categorical = phi_transform( - self.reward_support, transformed_target_reward - ) - target_value_categorical = phi_transform( - self.value_support, transformed_target_value - ) - - # Prepare batch for world model - # NOTE: This follows the exact format required by UniZero world model - # [FIX] Convert obs_batch_ori to tensor if needed - if not isinstance(obs_batch_ori, torch.Tensor): - # [DEBUG] Check obs_batch_ori shape - import logging - logger = logging.getLogger(__name__) - if isinstance(obs_batch_ori, np.ndarray): - logger.info(f"[DEBUG] obs_batch_ori type: numpy, shape: {obs_batch_ori.shape}, dtype: {obs_batch_ori.dtype}") - - # [FIX] Reshape if observations are flattened (2D instead of 3D) - # Expected: [batch_size, num_unroll_steps+1, obs_dim] (buffer includes next_obs) - # Got: [batch_size, (num_unroll_steps+1) * obs_dim] - if len(obs_batch_ori.shape) == 2: - # Infer num_unroll_steps and obs_dim - # For text: obs_dim should be max_seq_len (e.g., 512) - obs_dim = 512 # Standard max_seq_len for BERT - total_size = obs_batch_ori.shape[1] - if total_size % obs_dim == 0: - inferred_steps = total_size // obs_dim - # Simply reshape to [batch_size, inferred_steps, obs_dim] - # The truncation to match action_batch will happen later (like unizero.py line 675) - obs_batch_ori = obs_batch_ori.reshape(batch_size, inferred_steps, obs_dim) - logger.info(f"[RESHAPE] 
Reshaped obs_batch_ori from (batch_size, {total_size}) to {obs_batch_ori.shape}") - else: - logger.warning(f"[RESHAPE_ERROR] Cannot reshape: total_size ({total_size}) not divisible by obs_dim ({obs_dim})") - - # Check if it's an object array (inhomogeneous shapes) - if obs_batch_ori.dtype == np.object_: - logger.warning(f"[SHAPE_ISSUE] obs_batch_ori is object array - inhomogeneous shapes!") - logger.warning(f"[SHAPE_ISSUE] First element shape: {obs_batch_ori[0].shape if len(obs_batch_ori) > 0 else 'N/A'}") - if len(obs_batch_ori) > 1: - logger.warning(f"[SHAPE_ISSUE] Second element shape: {obs_batch_ori[1].shape}") - # Try to handle inhomogeneous array by padding/truncating - # For now, just raise a descriptive error - raise ValueError( - f"obs_batch_ori has inhomogeneous shapes. " - f"First element shape: {obs_batch_ori[0].shape}, " - f"Cannot directly convert to tensor. " - f"This suggests the replay buffer is storing observations with different sequence lengths." - ) - obs_batch_ori = torch.from_numpy(obs_batch_ori).to(self._cfg.device) - - # [FIX] Convert action_batch to tensor and handle shape correctly - if not isinstance(action_batch, torch.Tensor): - action_batch = torch.from_numpy(action_batch).to(self._cfg.device) - - if action_batch.shape[-1] == 1: - actions_processed = action_batch.squeeze(-1).long() - elif len(action_batch.shape) == 1: - actions_processed = action_batch.long() - else: - actions_processed = action_batch.long() - - if timestep_batch is not None: - # Convert timestep_batch to tensor if needed - if not isinstance(timestep_batch, torch.Tensor): - timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device) - - # Handle timestep_batch shape - if timestep_batch.shape[-1] == 1: - timestep_processed = timestep_batch.squeeze(-1).long() - elif len(timestep_batch.shape) == 1: - timestep_processed = timestep_batch.long() - else: - timestep_processed = timestep_batch.long() - - batch_for_gpt = { - 'observations': obs_batch_ori, - 'actions': actions_processed, - 'timestep': timestep_processed, - 'rewards': target_reward_categorical[:, :-1], - 'target_value': target_value_categorical[:, :-1], - 'target_policy': target_policy[:, :-1], - } - else: - batch_for_gpt = { - 'observations': obs_batch_ori, - 'actions': actions_processed, - 'rewards': target_reward_categorical[:, :-1], - 'target_value': target_value_categorical[:, :-1], - 'target_policy': target_policy[:, :-1], - } - - # [FIX] Following unizero.py lines 673-675 exactly: - # Convert mask_batch to boolean, then truncate to align with observations/rewards - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. Shape: (B, T) - - # [DEBUG] Log shapes before truncation - logger.info(f"[SHAPE_DEBUG] Before truncation: obs={batch_for_gpt['observations'].shape}, " - f"mask_padding={batch_for_gpt['mask_padding'].shape}, " - f"actions={batch_for_gpt['actions'].shape}") - - # [CRITICAL] Truncate observations to align with rewards/actions - # - observations from buffer include next_obs → shape (B, T+1, obs_dim) - # - mask_padding is already (B, T) from buffer - DO NOT truncate again! 
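# ==============================================================================
# Illustrative sketch: the idea behind scalar_transform / phi_transform used for
# the reward and value targets above — squash the scalar with a MuZero-style
# invertible transform, then spread it over the two nearest atoms of a discrete
# support ("two-hot") so it can be trained with cross-entropy. The support range
# and epsilon are illustrative; lzero's DiscreteSupport/phi_transform may differ
# in detail.
# ==============================================================================
import torch


def squash(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x


def two_hot(x: torch.Tensor, support_min: int = -10, support_max: int = 10) -> torch.Tensor:
    """Project scalars onto the two nearest atoms of an integer support."""
    size = support_max - support_min + 1
    x = x.clamp(support_min, support_max)
    low = x.floor().long()
    frac = (x - low.float()).unsqueeze(-1)
    dist = torch.zeros(*x.shape, size)
    dist.scatter_(-1, (low - support_min).unsqueeze(-1), 1.0 - frac)
    high = (low + 1).clamp(max=support_max)
    dist.scatter_add_(-1, (high - support_min).unsqueeze(-1), frac)
    return dist


target = squash(torch.tensor([[2.7, -0.5]]))  # shape (B=1, T=2)
print(two_hot(target).shape)                  # -> torch.Size([1, 2, 21])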
- # - After target processing: rewards[:, :-1] → (B, T-1) - # - So only observations need truncation - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # Shape: (B, T-1, obs_dim) - - # [FIX] Check if mask_padding needs truncation based on actual shape - if batch_for_gpt['mask_padding'].shape[1] > batch_for_gpt['observations'].shape[1]: - logger.warning(f"[SHAPE_FIX] Truncating mask_padding from {batch_for_gpt['mask_padding'].shape} to match obs") - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] - - logger.info(f"[SHAPE_DEBUG] After truncation: obs={batch_for_gpt['observations'].shape}, " - f"mask_padding={batch_for_gpt['mask_padding'].shape}") - - # [FIX] Add missing 'ends' field (following unizero.py line 676) - # 'ends' marks terminal states in the trajectory (0 = not terminal) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + batch_for_gpt = { + 'actions': action_batch.squeeze(-1), + 'timestep': timestep_batch.squeeze(-1), + 'rewards': target_reward_categorical[:, :-1], + 'target_value': target_value_categorical[:, :-1], + 'target_policy': target_policy[:, :-1], + } + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['mask_padding'] = mask_batch == 1.0 + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) - - # [FIX] Add 'scalar_target_value' field for priority calculation (following unizero.py line 681) batch_for_gpt['scalar_target_value'] = target_value - # [FIX] Log shapes for debugging - import logging - logger = logging.getLogger(__name__) - logger.info(f"[BATCH_SHAPES] obs: {batch_for_gpt['observations'].shape}, actions: {batch_for_gpt['actions'].shape}, rewards: {batch_for_gpt['rewards'].shape}, mask_padding: {batch_for_gpt['mask_padding'].shape}") - - # Compute world model loss - wm_losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, - self._target_model.world_model.tokenizer, - self.value_inverse_scalar_transform_handle, - ) - - # Weighted world model loss (for prioritized experience replay) - wm_total_loss = (weights * wm_losses.loss_total).mean() - - # ============================================================================== - # Part 2: [PRIORZERO-NEW] LLM Policy Training (SFT + RFT) - # ============================================================================== + with self._profile_block(name="train_world_model"): + wm_losses, pred_values = self._learn_model.world_model.compute_loss( + batch_for_gpt, + self._target_model.world_model.tokenizer, + self.value_inverse_scalar_transform_handle, + ) - llm_sft_loss = torch.tensor(0.0, device=self._cfg.device) - llm_rft_loss = torch.tensor(0.0, device=self._cfg.device) - num_sft_samples = 0 - num_rft_samples = 0 + wm_total_loss = (weights * wm_losses.loss_total).mean() - # [FIX] Only perform LLM training if 
game_segments available - # [DEBUG] Always log game_segments status - logger = logging.getLogger(__name__) - logger.info(f"[LLM Training] game_segments type: {type(game_segments)}, " - f"is None: {game_segments is None}, " - f"len: {len(game_segments) if game_segments is not None else 'N/A'}") - - # [DEBUG] Check first segment's data - if game_segments is not None and len(game_segments) > 0: - seg0 = game_segments[0] - logger.info(f"[LLM Training] First segment stats: " - f"mcts_policies={len(seg0.mcts_policy_segment) if hasattr(seg0, 'mcts_policy_segment') else 0}, " - f"raw_obs={len([x for x in (seg0.raw_obs_segment if hasattr(seg0, 'raw_obs_segment') else []) if x is not None])}/{len(seg0.raw_obs_segment) if hasattr(seg0, 'raw_obs_segment') else 0}, " - f"actions={len(seg0.action_segment) if hasattr(seg0, 'action_segment') else 0}") - - if game_segments is not None and len(game_segments) > 0: - # Collect training data from game segments - sft_prompts = [] - sft_targets = [] - rft_prompts = [] - rft_rewards = [] - - # [DEBUG] Log segment information - logger.info(f"[LLM Training] Processing {len(game_segments)} game segments") - - for seg_idx, segment in enumerate(game_segments): - # [FIX] Use action_segment length, not obs_segment - # obs_segment includes frame_stack + unroll_steps, while - # mcts_policy_segment only has entries for actual actions taken - segment_length = len(segment.action_segment) - - # [FIX] Ensure mcts_policy_segment has the same length - # It might be a list or numpy array depending on whether game_segment_to_array() was called - mcts_policy_length = len(segment.mcts_policy_segment) if hasattr(segment, 'mcts_policy_segment') else 0 - - # [DEBUG] Log segment lengths for debugging - if self._cfg.get('debug_segment_processing', False): - obs_len = len(segment.obs_segment) if hasattr(segment, 'obs_segment') else 0 - raw_obs_len = len(segment.raw_obs_segment) if hasattr(segment, 'raw_obs_segment') else 0 - logging.info( - f"[Segment {seg_idx}] action_len={segment_length}, " - f"mcts_policy_len={mcts_policy_length}, obs_len={obs_len}, raw_obs_len={raw_obs_len}" - ) - - # [SAFETY] Use the minimum of the two lengths to avoid IndexError - max_index = min(segment_length, mcts_policy_length) - - if max_index == 0: - if self._cfg.get('debug_segment_processing', False): - logging.warning(f"[Segment {seg_idx}] Empty segment, skipping") - continue # Skip empty segments - - for i in range(max_index): - # [FIX] Safe access to mcts_policy_segment with bounds check - try: - mcts_policy = segment.mcts_policy_segment[i] - except (IndexError, KeyError, TypeError) as e: - # Log detailed error information for debugging - if self._cfg.get('debug_segment_processing', False): - logging.error( - f"[Segment {seg_idx}, Index {i}] Failed to access mcts_policy_segment: {e}\n" - f" segment_length={segment_length}, mcts_policy_length={mcts_policy_length}\n" - f" mcts_policy_segment type: {type(segment.mcts_policy_segment)}" - ) - continue - - # Skip if no MCTS policy available - if mcts_policy is None: - continue - - # [FIX] Use raw_obs_segment for text observations - # PriorZero's GameSegment stores raw text in raw_obs_segment - raw_obs_text = None - if hasattr(segment, 'raw_obs_segment') and i < len(segment.raw_obs_segment): - raw_obs_text = segment.raw_obs_segment[i] - elif i < len(segment.obs_segment): - # Fallback to obs_segment if raw_obs_segment not available - raw_obs_text = str(segment.obs_segment[i]) - - # Skip if raw_obs_text is None - if raw_obs_text is None: - continue - - # Build 
history context - history = [] - for j in range(max(0, i - self.llm_policy_cfg.history_length), i): - # [FIX] Use raw_obs_segment for history as well - obs_text = None - if hasattr(segment, 'raw_obs_segment') and j < len(segment.raw_obs_segment): - obs_text = segment.raw_obs_segment[j] - elif j < len(segment.obs_segment): - obs_text = str(segment.obs_segment[j]) - - if obs_text is not None and j < len(segment.action_segment): - history.append(( - obs_text, - self.action_inv_map.get(segment.action_segment[j], f"action_{segment.action_segment[j]}"), - float(segment.reward_segment[j]) if j < len(segment.reward_segment) else 0.0 - )) - - # Build prompt - instruction = build_llm_prompt( - current_obs=raw_obs_text, - history=history, - use_cot=self.llm_policy_cfg.use_cot - ) - - # Apply chat template - prompt = self.llm_tokenizer.apply_chat_template( - [{"role": "user", "content": instruction}], - tokenize=False, - add_generation_prompt=True - ) - - # ============================================================ - # SFT: Supervised Fine-Tuning with MCTS Policy - # ============================================================ - if self.llm_policy_cfg.sft_target == 'mcts_policy': - # [FIX] Use the mcts_policy we already safely retrieved above - # Don't access segment.mcts_policy_segment[i] again to avoid IndexError - mcts_policy_vec = mcts_policy - - # Convert MCTS policy to ranked action text - target_text = format_mcts_policy_to_text( - mcts_policy_vec, - self.action_inv_map, - top_k=5 - ) - - sft_prompts.append(prompt) - sft_targets.append(target_text) - num_sft_samples += 1 - - # ============================================================ - # RFT: Reinforcement Fine-Tuning with Environment Reward - # ============================================================ - if self.llm_policy_cfg.enable_rft and i < len(segment.reward_segment): - env_reward = float(segment.reward_segment[i]) - - # TODO - # Only use transitions with non-zero reward for RFT - if abs(env_reward) > 1e-9: - rft_prompts.append(prompt) - rft_rewards.append(env_reward) - num_rft_samples += 1 - - # ============================================================ - # Train LLM with SFT (with gradient accumulation for memory efficiency) - # ============================================================ - # num_sft_samples=0 # TODO - if num_sft_samples > 0: - # [PRIORZERO-OOM-FIX] Use micro-batching with gradient accumulation - micro_batch_size = self.llm_policy_cfg.llm_micro_batch_size - num_micro_batches = (num_sft_samples + micro_batch_size - 1) // micro_batch_size - accumulation_steps = self.llm_policy_cfg.llm_gradient_accumulation_steps - - # Prepare full texts (prompt + target + eos) - full_texts = [ - p + t + self.llm_tokenizer.eos_token - for p, t in zip(sft_prompts, sft_targets) - ] - - # Process in micro-batches - accumulated_sft_loss = 0.0 - for micro_batch_idx in range(num_micro_batches): - start_idx = micro_batch_idx * micro_batch_size - end_idx = min((micro_batch_idx + 1) * micro_batch_size, num_sft_samples) - - # Get micro-batch - micro_batch_texts = full_texts[start_idx:end_idx] - micro_batch_prompts = sft_prompts[start_idx:end_idx] - - # Tokenize micro-batch - inputs = self.llm_tokenizer( - micro_batch_texts, - padding=True, - truncation=True, - max_length=self.llm_policy_cfg.prompt_max_len, - return_tensors="pt" - ).to(self._cfg.device) - - # Create labels (mask prompt tokens to only compute loss on target) - labels = inputs.input_ids.clone() - labels[labels == self.llm_tokenizer.pad_token_id] = -100 - - # Mask prompt 
tokens - for i, prompt in enumerate(micro_batch_prompts): - prompt_tokens = self.llm_tokenizer.encode(prompt, add_special_tokens=False) - prompt_len = len(prompt_tokens) - labels[i, :prompt_len] = -100 - - # Forward pass - llm_outputs = self.llm_policy_model( - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask, - labels=labels - ) - - # Scale loss by number of accumulation steps (for correct gradient magnitude) - micro_batch_loss = llm_outputs.loss / accumulation_steps - accumulated_sft_loss += micro_batch_loss.item() - - # Backward pass (accumulate gradients) - micro_batch_loss.backward() - - # Free memory - del inputs, labels, llm_outputs - torch.cuda.empty_cache() - - # Average loss for logging - llm_sft_loss = torch.tensor(accumulated_sft_loss, device=self._cfg.device) - - # ============================================================ - # Train LLM with RFT (Policy Gradient with gradient accumulation) - # ============================================================ - if num_rft_samples > 0 and self.llm_policy_cfg.enable_rft: - # [PRIORZERO-OOM-FIX] Use micro-batching with gradient accumulation - micro_batch_size = self.llm_policy_cfg.llm_micro_batch_size - num_micro_batches = (num_rft_samples + micro_batch_size - 1) // micro_batch_size - accumulation_steps = self.llm_policy_cfg.llm_gradient_accumulation_steps - - # Process in micro-batches - accumulated_rft_loss = 0.0 - for micro_batch_idx in range(num_micro_batches): - start_idx = micro_batch_idx * micro_batch_size - end_idx = min((micro_batch_idx + 1) * micro_batch_size, num_rft_samples) - - # Get micro-batch - micro_batch_prompts = rft_prompts[start_idx:end_idx] - micro_batch_rewards = rft_rewards[start_idx:end_idx] - - # Tokenize prompts - inputs = self.llm_tokenizer( - micro_batch_prompts, - padding=True, - truncation=True, - max_length=self.llm_policy_cfg.prompt_max_len, - return_tensors="pt" - ).to(self._cfg.device) - - # [FIX] Forward pass WITH gradient tracking (remove no_grad) - outputs = self.llm_policy_model( - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask - ) - - # Compute policy gradient loss (REINFORCE) - # Loss = -reward * log_prob(action) - logits = outputs.logits - log_probs = F.log_softmax(logits, dim=-1) - - # Get log probability of actual tokens - shifted_log_probs = log_probs[:, :-1, :].contiguous() - shifted_labels = inputs.input_ids[:, 1:].contiguous() - - # Gather log probs of actual tokens - token_log_probs = shifted_log_probs.gather( - dim=-1, - index=shifted_labels.unsqueeze(-1) - ).squeeze(-1) - - # Mask padding tokens - mask = (shifted_labels != self.llm_tokenizer.pad_token_id).float() - token_log_probs = token_log_probs * mask - - # Sum log probs per sequence - sequence_log_probs = token_log_probs.sum(dim=-1) / (mask.sum(dim=-1) + 1e-8) - - # Compute REINFORCE loss for micro-batch - rewards_tensor = torch.tensor( - micro_batch_rewards, - device=self._cfg.device, - dtype=torch.float32 - ) - - # Normalize rewards within micro-batch (important for stable training) - if len(micro_batch_rewards) > 1: - rewards_tensor = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8) - - micro_batch_rft_loss = -(rewards_tensor * sequence_log_probs).mean() / accumulation_steps - accumulated_rft_loss += micro_batch_rft_loss.item() - - # Backward pass (accumulate gradients) - micro_batch_rft_loss.backward() - - # Free memory - del inputs, outputs, logits, log_probs, rewards_tensor - torch.cuda.empty_cache() - - # Average loss for logging - llm_rft_loss = 
torch.tensor(accumulated_rft_loss, device=self._cfg.device) - - # ============================================================================== - # Part 3: Joint Optimization - # ============================================================================== - - # [PRIORZERO-OOM-FIX] Note: LLM gradients already accumulated via micro-batching above - # Only need to compute world model gradients here - - # Combine losses (for logging only - LLM loss already backpropagated) - llm_loss = ( - self.llm_policy_cfg.llm_loss_weight * llm_sft_loss + - self.llm_policy_cfg.rft_loss_weight * llm_rft_loss - ) - total_loss = wm_total_loss + llm_loss # For logging - - # Zero world model gradients only (LLM gradients already accumulated) self._optimizer_world_model.zero_grad() - - # Backward pass for world model only wm_total_loss.backward() - - # Gradient clipping for both models wm_grad_norm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value ) - llm_grad_norm = torch.nn.utils.clip_grad_norm_( - self.llm_policy_model.parameters(), - self._cfg.grad_clip_value - ) - - # Optimizer step for both models + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) self._optimizer_world_model.step() - self._optimizer_llm.step() # Apply accumulated LLM gradients - - # Zero LLM gradients after step (ready for next iteration) - self._optimizer_llm.zero_grad() - - # Learning rate scheduler step (optional) - if self._lr_scheduler_llm is not None: - self._lr_scheduler_llm.step() - - # Update target model (soft update) self._target_model.update(self._learn_model.state_dict()) - # ============================================================================== - # Part 4: Logging (Aligned with UniZero) - # ============================================================================== - - # Extract intermediate losses from world model (like UniZero) intermediate_losses = wm_losses.intermediate_losses obs_loss = intermediate_losses.get('loss_obs', torch.tensor(0.0)) reward_loss = intermediate_losses.get('loss_rewards', torch.tensor(0.0)) @@ -952,15 +314,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in middle_step_losses = intermediate_losses.get('middle_step_losses', {}) last_step_losses = intermediate_losses.get('last_step_losses', {}) - # Analysis metrics (dormant ratio, weight magnitude, etc.) 
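# ==============================================================================
# Illustrative sketch: the pattern of the removed RFT branch above — a
# REINFORCE-style policy-gradient loss computed over micro-batches whose scaled
# losses are back-propagated one by one, so gradients accumulate before a single
# clipped optimizer step. The model, sizes and pad_id are toy stand-ins for the
# actual LLM, tokenizer and config.
# ==============================================================================
import torch
import torch.nn.functional as F


def reinforce_loss(logits, input_ids, rewards, pad_id):
    """-reward * length-normalized log-prob of the realized token sequence."""
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)       # predict token t+1
    labels = input_ids[:, 1:]
    token_lp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    mask = (labels != pad_id).float()
    seq_lp = (token_lp * mask).sum(-1) / (mask.sum(-1) + 1e-8)
    if rewards.numel() > 1:                                    # normalize within batch
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    return -(rewards * seq_lp).mean()


vocab, pad_id, accum_steps, micro_bs = 11, 0, 2, 2
lm_head = torch.nn.Linear(8, vocab)                            # stand-in for the LLM
opt = torch.optim.AdamW(lm_head.parameters(), lr=1e-4)
hidden = torch.randn(4, 6, 8)                                  # (batch, seq, hidden)
ids = torch.randint(1, vocab, (4, 6))
rewards = torch.tensor([1.0, 0.0, 0.5, -0.5])

opt.zero_grad()
for i in range(accum_steps):
    sl = slice(i * micro_bs, (i + 1) * micro_bs)
    loss = reinforce_loss(lm_head(hidden[sl]), ids[sl], rewards[sl], pad_id) / accum_steps
    loss.backward()                                            # gradients accumulate
torch.nn.utils.clip_grad_norm_(lm_head.parameters(), 1.0)
opt.step()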
- dormant_ratio_encoder = intermediate_losses.get('dormant_ratio_encoder', 0.0) - dormant_ratio_transformer = intermediate_losses.get('dormant_ratio_transformer', 0.0) - dormant_ratio_head = intermediate_losses.get('dormant_ratio_head', 0.0) - avg_weight_mag_encoder = intermediate_losses.get('avg_weight_mag_encoder', 0.0) - avg_weight_mag_transformer = intermediate_losses.get('avg_weight_mag_transformer', 0.0) - avg_weight_mag_head = intermediate_losses.get('avg_weight_mag_head', 0.0) - e_rank_last_linear = intermediate_losses.get('e_rank_last_linear', 0.0) - e_rank_sim_norm = intermediate_losses.get('e_rank_sim_norm', 0.0) latent_state_l2_norms = intermediate_losses.get('latent_state_l2_norms', torch.tensor(0.0)) latent_action_l2_norms = intermediate_losses.get('latent_action_l2_norms', 0.0) @@ -989,16 +342,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Build comprehensive log dict (aligned with UniZero) log_dict = { # ============ Core Losses ============ - 'weighted_total_loss': wm_total_loss.item(), - 'obs_loss': obs_loss.item() if torch.is_tensor(obs_loss) else obs_loss, - 'reward_loss': reward_loss.item() if torch.is_tensor(reward_loss) else reward_loss, - 'policy_loss': policy_loss.item() if torch.is_tensor(policy_loss) else policy_loss, - 'value_loss': value_loss.item() if torch.is_tensor(value_loss) else value_loss, - 'latent_recon_loss': latent_recon_loss.item() if torch.is_tensor(latent_recon_loss) else latent_recon_loss, - 'perceptual_loss': perceptual_loss.item() if torch.is_tensor(perceptual_loss) else perceptual_loss, - 'orig_policy_loss': orig_policy_loss.item() if torch.is_tensor(orig_policy_loss) else orig_policy_loss, - 'policy_entropy': policy_entropy.item() if torch.is_tensor(policy_entropy) else policy_entropy, - 'target_policy_entropy': average_target_policy_entropy.item(), + 'wm_total_loss': wm_total_loss.item(), + 'wm_obs_loss': obs_loss.item() if torch.is_tensor(obs_loss) else obs_loss, + 'wm_reward_loss': reward_loss.item() if torch.is_tensor(reward_loss) else reward_loss, + 'wm_policy_loss': policy_loss.item() if torch.is_tensor(policy_loss) else policy_loss, + 'wm_value_loss': value_loss.item() if torch.is_tensor(value_loss) else value_loss, + 'wm_latent_recon_loss': latent_recon_loss.item() if torch.is_tensor(latent_recon_loss) else latent_recon_loss, + 'wm_perceptual_loss': perceptual_loss.item() if torch.is_tensor(perceptual_loss) else perceptual_loss, + 'wm_orig_policy_loss': orig_policy_loss.item() if torch.is_tensor(orig_policy_loss) else orig_policy_loss, + 'wm_policy_entropy': policy_entropy.item() if torch.is_tensor(policy_entropy) else policy_entropy, + 'wm_target_policy_entropy': average_target_policy_entropy.item(), # ============ Step-wise Losses ============ @@ -1018,14 +371,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'analysis/last_step_loss_obs': last_step_losses.get('loss_obs', torch.tensor(0.0)).item() if isinstance(last_step_losses.get('loss_obs'), torch.Tensor) else 0.0, # ============ Analysis Metrics ============ - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, - 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, - 'analysis/dormant_ratio_head': dormant_ratio_head, - 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, - 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, - 'analysis/avg_weight_mag_head': avg_weight_mag_head, - 'analysis/e_rank_last_linear': e_rank_last_linear, - 'analysis/e_rank_sim_norm': 
e_rank_sim_norm, 'analysis/latent_state_l2_norms': latent_state_l2_norms.item() if torch.is_tensor(latent_state_l2_norms) else latent_state_l2_norms, 'analysis/latent_action_l2_norms': latent_action_l2_norms, @@ -1043,66 +388,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'temperature_policy': temperature_policy, # ============ Targets ============ - 'target_reward': target_reward.mean().item(), - 'target_value': target_value.mean().item(), + 'wm_target_reward': target_reward.mean().item(), + 'wm_target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'value_priority': value_priority_np.mean().item(), 'value_priority_orig': value_priority_np, # ============ Gradient Norms ============ - 'total_grad_norm_before_clip_wm': wm_grad_norm.item(), - 'llm_grad_norm': llm_grad_norm.item(), + 'wm_grad_norm': wm_grad_norm.item(), + 'llm_grad_norm': self._last_llm_grad_norm, # ============ Learning Rates ============ 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'llm_lr': self._optimizer_llm.param_groups[0]['lr'], - - # ============ [PRIORZERO] LLM-specific Metrics ============ - 'llm_sft_loss': llm_sft_loss.item(), - 'llm_rft_loss': llm_rft_loss.item(), - 'llm_total_loss': llm_loss.item(), - 'num_sft_samples': float(num_sft_samples), - 'num_rft_samples': float(num_rft_samples), - 'total_loss': total_loss.item(), } - # ============================================================================== - # [PRIORZERO-NEW] WandB Logging (if enabled) - # ============================================================================== - if self._cfg.get('use_wandb', False): - try: - import wandb - if wandb.run is not None: - # Log all metrics to WandB with hierarchical naming - wandb.log({ - # World Model Metrics - 'train/wm/total_loss': log_dict['wm_total_loss'], - 'train/wm/value_loss': log_dict['wm_value_loss'], - 'train/wm/policy_loss': log_dict['wm_policy_loss'], - 'train/wm/reward_loss': log_dict['wm_reward_loss'], - 'train/wm/grad_norm': log_dict['wm_grad_norm'], - 'train/wm/learning_rate': log_dict['wm_lr'], - - # LLM Policy Metrics - 'train/llm/sft_loss': log_dict['llm_sft_loss'], - 'train/llm/rft_loss': log_dict['llm_rft_loss'], - 'train/llm/total_loss': log_dict['llm_total_loss'], - 'train/llm/grad_norm': log_dict['llm_grad_norm'], - 'train/llm/learning_rate': log_dict['llm_lr'], - 'train/llm/num_sft_samples': float(log_dict['num_sft_samples']), - 'train/llm/num_rft_samples': float(log_dict['num_rft_samples']), - - # Combined Metrics - 'train/total_loss': log_dict['total_loss'], - }, step=self._train_iteration) - except Exception as e: - # Don't fail training if wandb logging fails - import logging - logging.warning(f"WandB logging failed: {e}") - return log_dict - + def _monitor_vars_learn(self) -> List[str]: """ [PRIORZERO-MODIFIED] @@ -1115,60 +417,31 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ - # ============ LLM Loss Metrics ============ + # ============ LLM Loss Metrics ============ 'llm_sft_loss', # Supervised fine-tuning loss 'llm_rft_loss', # Reinforcement fine-tuning loss 'llm_total_loss', # Combined LLM loss 'llm_grad_norm', # LLM gradient norm 'llm_lr', # LLM learning rate - + 'rft_logprob_mean', + 'rft_seq_neglogprob_mean', + 'rft_advantage_mean', + 'rft_advantage_std', + 'rft_ratio_used_mean', + 'rft_kl_mean', # ============ LLM Training Statistics ============ - 
'num_sft_samples', # Number of SFT samples in batch - 'num_rft_samples', # Number of RFT samples in batch - + # 'num_sft_samples', # Number of SFT samples in batch + # 'num_rft_samples', # Number of RFT samples in batch # ============ Combined Metrics ============ 'total_loss', # Total loss (WM + LLM) 'wm_total_loss', # World model total loss 'wm_grad_norm', # World model gradient norm - 'wm_lr', # World model learning rate - # ============ World Model Component Losses ============ 'wm_value_loss', 'wm_policy_loss', 'wm_reward_loss', 'wm_obs_loss', - 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_transformer', - 'analysis/dormant_ratio_head', - - 'analysis/avg_weight_mag_encoder', - 'analysis/avg_weight_mag_transformer', - 'analysis/avg_weight_mag_head', - 'analysis/e_rank_last_linear', - 'analysis/e_rank_sim_norm', - - 'analysis/latent_state_l2_norms', - 'analysis/l2_norm_before', - 'analysis/l2_norm_after', - 'analysis/grad_norm_before', - 'analysis/grad_norm_after', - - 'analysis/first_step_loss_value', - 'analysis/first_step_loss_policy', - 'analysis/first_step_loss_rewards', - 'analysis/first_step_loss_obs', - - 'analysis/middle_step_loss_value', - 'analysis/middle_step_loss_policy', - 'analysis/middle_step_loss_rewards', - 'analysis/middle_step_loss_obs', - - 'analysis/last_step_loss_value', - 'analysis/last_step_loss_policy', - 'analysis/last_step_loss_rewards', - 'analysis/last_step_loss_obs', - 'adaptive_alpha', "adaptive_target_entropy_ratio", 'alpha_loss', @@ -1179,62 +452,53 @@ def _monitor_vars_learn(self) -> List[str]: 'collect_mcts_temperature', 'cur_lr_world_model', 'cur_lr_tokenizer', - - 'weighted_total_loss', - 'obs_loss', - 'policy_loss', - 'orig_policy_loss', - 'policy_entropy', - 'latent_recon_loss', - 'target_policy_entropy', - 'reward_loss', - 'value_loss', + + 'wm_orig_policy_loss', + 'wm_policy_entropy', + 'wm_latent_recon_loss', + 'wm_target_policy_entropy', 'consistency_loss', 'value_priority', - 'target_reward', - 'target_value', + 'wm_target_reward', + 'wm_target_value', 'total_grad_norm_before_clip_wm', # tokenizer 'commitment_loss', 'reconstruction_loss', - 'perceptual_loss', - - - "logits_value_mean", - "logits_value_max", - "logits_value_min", - "logits_policy_mean", - "logits_policy_max", - "logits_policy_min", - - "temperature_value", - "temperature_reward", - "temperature_policy", - "current_policy_label_eps", - 'adaptive_alpha', - "adaptive_target_entropy_ratio", + 'wm_perceptual_loss', + + "logits_value_mean", + "logits_value_max", + "logits_value_min", + "logits_policy_mean", + "logits_policy_max", + "logits_policy_min", + + "temperature_value", + "temperature_reward", + "temperature_policy", + "current_policy_label_eps", + 'adaptive_alpha', + "adaptive_target_entropy_ratio", 'alpha_loss', "current_encoder_clip_value", - - # ==================== [新增] 添加范数和中间张量监控变量 ==================== - # 模块总范数 - 'norm/encoder/_total_norm', - 'norm/transformer/_total_norm', - 'norm/head_value/_total_norm', - 'norm/head_reward/_total_norm', - 'norm/head_policy/_total_norm', - # 中间张量 x 的统计信息 - 'norm/x_token/mean', - 'norm/x_token/std', - 'norm/x_token/max', - 'norm/x_token/min', ] - # 注意:我们不把每一层的范数都加到这里,因为数量太多会导致日志混乱。 - # 在实践中,如果通过总范数发现问题,可以临时在TensorBoard中搜索特定层的范数, - # 或者在本地打印 `norm_log_dict` 来进行详细分析。 - # wandb等工具可以更好地处理大量的动态指标。 # ======================================================================== + def pad_to_fixed_length(self, data, target_len=55, pad_val=-1e9, dtype=torch.float32): + """ + data: List[Sequence[Number]],每个元素长度可以不一样(比如 3 或 4) + 返回: 
tensor, 形状 [B, target_len],多余部分全是 pad_val + """ + batch_size = len(data) + out = torch.full((batch_size, target_len), pad_val, dtype=dtype) + for i, seq in enumerate(data): + if isinstance(seq, np.ndarray): + seq = seq.tolist() + L = min(len(seq), target_len) + if L > 0: + out[i, :L] = torch.tensor(seq[:L], dtype=dtype) + return out def _forward_collect( self, @@ -1244,6 +508,7 @@ def _forward_collect( to_play: List[int] = None, epsilon: float = 0.0, ready_env_id: List[int] = None, + timestep: List = [0], **kwargs ) -> Dict[int, Dict[str, Any]]: """ @@ -1272,161 +537,85 @@ def _forward_collect( """ self._collect_model.eval() - # ====================================================================== - # [PRIORZERO-NEW] Get LLM Prior Outputs - # ====================================================================== - llm_prior_outputs = kwargs.pop('llm_prior_outputs', None) + llm_prior_logprob = kwargs.pop('llm_prior_logprob', None) + valid_actions_list = kwargs.get('valid_actions_list', None) - if llm_prior_outputs is None: - # If no LLM prior available, fall back to standard UniZero behavior + if llm_prior_logprob is None: logging.debug("No LLM priors provided, using standard UniZero MCTS") return super()._forward_collect( data, action_mask, temperature, to_play, epsilon, - ready_env_id=ready_env_id, **kwargs + ready_env_id=ready_env_id, timestep=timestep ) - - # ====================================================================== - # Parse LLM Outputs into Policy Priors - # ====================================================================== + policy_priors = [] - for output in llm_prior_outputs: - # Extract generated text - generated_text = output.outputs[0].text if hasattr(output, 'outputs') else str(output) - - # Parse into policy distribution - prior_policy = parse_llm_action_ranking( - generated_text, - self.action_map, - self._cfg.model.action_space_size, - fallback_to_uniform=True - ) - - # Convert to log probabilities (for compatibility with MCTS) - policy_logits = torch.log(torch.from_numpy(prior_policy) + 1e-9) - policy_priors.append(policy_logits) - - policy_priors = torch.stack(policy_priors).to(self._cfg.device) - + for idx, actions in enumerate(valid_actions_list): + prior = [] + for action in actions: + prior.append(llm_prior_logprob[idx][action]) + policy_priors.append(prior) + policy_priors = self.pad_to_fixed_length(data=policy_priors, target_len=self.cfg.model.action_space_size, pad_val=-1e9) # ====================================================================== # World Model Initial Inference # ====================================================================== + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} with torch.no_grad(): - # Run representation network to get latent state - network_output = self._collect_model.initial_inference(data) + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - # Unpack network outputs - latent_state_roots, reward_roots, pred_values, policy_logits_roots = \ - mz_network_output_unpack(network_output) - - # [PRIORZERO-KEY] Replace policy logits with LLM priors network_output.policy_logits = policy_priors - - # Prepare for MCTS if not self._cfg.mcts_ctree: - # Python 
implementation (not recommended for performance) raise NotImplementedError("Python MCTS not supported for PriorZero") # ====================================================================== # MCTS Search with LLM-Guided Priors # ====================================================================== - # This is the key part where LLM priors guide the search - - # [FIX] Align with UniZero: construct legal_actions from action_mask - active_collect_env_num = len(ready_env_id) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] - for j in range(active_collect_env_num)] - - # Get timestep if available - timestep = kwargs.get('timestep', None) - - # [FIX] Align with UniZero: transform values and prepare data pred_values_np = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots_np = latent_state_roots.detach().cpu().numpy() - # reward_roots_np = reward_roots.detach().cpu().numpy() - policy_logits_for_mcts = policy_priors.detach().cpu().numpy().tolist() - - # [FIX] Align with UniZero: Create MCTS roots with legal_actions (not action_space_size) - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - - # [FIX] Align with UniZero: noises based on number of valid actions per environment + policy_logits = policy_priors.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() - for j in range(active_collect_env_num) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) ] + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots_np, to_play, timestep=timestep) - # [FIX] Align with UniZero: prepare roots (note reward_roots_np, not list(pred_values_np)) - roots.prepare( - self._cfg.root_noise_weight, - noises, - reward_roots, - # reward_roots_np, - policy_logits_for_mcts, - to_play if to_play is not None else [-1] * active_collect_env_num, - ) - - # Run MCTS search - MCTSCtree(self._cfg).search( - roots, - self._collect_model, - latent_state_roots_np, - reward_roots, - to_play if to_play is not None else [-1] * latent_state_roots_np.shape[0], - ) - - # Extract search results roots_visit_count = roots.get_distributions() roots_values = roots.get_values() - # ====================================================================== - # [PRIORZERO] Get valid_actions_list for dynamic action mapping - # ====================================================================== - valid_actions_list = kwargs.get('valid_actions_list', None) - - # ====================================================================== - # Select Actions and Prepare Output (Aligned with UniZero) - # ====================================================================== - output = {} - + batch_action = [] for i, env_id in enumerate(ready_env_id): - # [FIX] Get visit count distribution (only contains legal actions) distributions = roots_visit_count[i] value = roots_values[i] - # [FIX] Use select_action from UniZero (aligns with UniZero line 1115-1117) - # NOTE: Only legal actions possess visit counts, so action_index_in_legal_action_set - # represents the index within the legal action set, not the entire action set action_index_in_legal_action_set, 
visit_count_distribution_entropy = select_action( distributions, - temperature=temperature if temperature is not None else self._collect_mcts_temperature, + temperature=self._collect_mcts_temperature, deterministic=False ) - # [FIX] Convert action_index_in_legal_action_set to the actual action in full action space - # (aligns with UniZero line 1119) legal_action_indices = np.where(action_mask[i] == 1.0)[0] action = legal_action_indices[action_index_in_legal_action_set] - # [PRIORZERO] Create dynamic action_inv_map for this specific state - # This maps action_index -> action_text using the current state's valid_actions - if valid_actions_list is not None and i < len(valid_actions_list): - dynamic_action_inv_map = { - idx: act_text - for idx, act_text in enumerate(valid_actions_list[i]) - } - else: - # Fallback to static mapping if valid_actions not available - dynamic_action_inv_map = self.action_inv_map - output[env_id] = { 'action': int(action), 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, 'searched_value': value, 'predicted_value': pred_values_np[i], - 'dynamic_action_inv_map': dynamic_action_inv_map, # [PRIORZERO] Include dynamic mapping + 'predicted_policy_logits': policy_logits[i], + 'timestep': timestep[i], } - + batch_action.append(action) + self.last_batch_obs = data + self.last_batch_action = batch_action return output def _state_dict_learn(self) -> Dict[str, Any]: @@ -1436,13 +625,6 @@ def _state_dict_learn(self) -> Dict[str, Any]: """ state_dict = super()._state_dict_learn() - # Add LLM model and optimizer - state_dict['llm_model'] = self.llm_policy_model.state_dict() - state_dict['optimizer_llm'] = self._optimizer_llm.state_dict() - - if self._lr_scheduler_llm is not None: - state_dict['lr_scheduler_llm'] = self._lr_scheduler_llm.state_dict() - return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: @@ -1451,16 +633,4 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: Load state dict for both world model and LLM. """ super()._load_state_dict_learn(state_dict) - - # Load LLM model and optimizer - if 'llm_model' in state_dict: - self.llm_policy_model.load_state_dict(state_dict['llm_model']) - logging.info("✓ LLM model state loaded") - - if 'optimizer_llm' in state_dict: - self._optimizer_llm.load_state_dict(state_dict['optimizer_llm']) - logging.info("✓ LLM optimizer state loaded") - - if 'lr_scheduler_llm' in state_dict and self._lr_scheduler_llm is not None: - self._lr_scheduler_llm.load_state_dict(state_dict['lr_scheduler_llm']) - logging.info("✓ LLM scheduler state loaded") + diff --git a/zoo/jericho/priorzero/priorzero_prompts.py b/zoo/jericho/priorzero/priorzero_prompts.py index 4ce7ef787..4c5830575 100644 --- a/zoo/jericho/priorzero/priorzero_prompts.py +++ b/zoo/jericho/priorzero/priorzero_prompts.py @@ -1,19 +1,3 @@ -""" -PriorZero LLM Prompts Module - -This module provides optimized prompt templates for PriorZero's LLM policy, -based on the successful prompt structure from Open-Reasoner-Zero. 
- -Key Features: -- Structured reasoning with and tags -- Clear role definitions (User/Assistant paradigm) -- Explicit format examples to guide the LLM -- Game-specific context integration - -Author: PriorZero Team -Date: 2025-10-21 -""" - from jinja2 import Template from typing import List, Dict, Any, Optional diff --git a/zoo/jericho/priorzero/priorzero_utils.py b/zoo/jericho/priorzero/priorzero_utils.py new file mode 100644 index 000000000..6e60afef1 --- /dev/null +++ b/zoo/jericho/priorzero/priorzero_utils.py @@ -0,0 +1,38 @@ +import torch + + +def compute_approx_kl( + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + kl_estimator: str = "k1", +) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + """ + + if kl_estimator == "k1": + log_ratio = log_probs.float() - log_probs_base.float() + + # The k2 estimator is the non negative kl approximation in + # http://joschu.net/blog/kl-approx.html + # The k2_loss is approximately equivalent to the + # one-step KL divergence penalty with the k1 estimator + # used in https://arxiv.org/pdf/2310.10505. + if kl_estimator == "k2": + log_ratio = log_probs.float() - log_probs_base.float() + log_ratio = log_ratio**2 / 2.0 + + # The k3 estimator is the non negative kl approximation in + # http://joschu.net/blog/kl-approx.html + if kl_estimator == "k3": + log_ratio = log_probs.float() - log_probs_base.float() + log_ratio = -log_ratio + log_ratio = log_ratio.exp() - 1 - log_ratio + + log_ratio = log_ratio.clamp(min=-10, max=10) + return log_ratio \ No newline at end of file diff --git a/zoo/jericho/priorzero/utils/generator.py b/zoo/jericho/priorzero/utils/generator.py new file mode 100644 index 000000000..7cad27a6d --- /dev/null +++ b/zoo/jericho/priorzero/utils/generator.py @@ -0,0 +1,93 @@ +from typing import List, Dict, Any, Optional, Tuple +import ray +import torch + +class SamplesGenerator: + def __init__(self, vllm_engines, strategy, tokenizer, prompt_max_len, temperature, top_p): + self.strategy = strategy + self.args = strategy.args + self.vllm_engines = vllm_engines + self.tokenizer = tokenizer + self.prompt_max_len = prompt_max_len + self.temperature = temperature + self.top_p = top_p + + @torch.no_grad() + def _generate_vllm(self, all_prompts: List[str], all_labels: List[str], reduction: str = "mean"): + """Generate samples using vLLM engine. 
+ + Args: + all_prompts: List of prompts to generate from + all_labels: List of labels corresponding to prompts + **kwargs: Additional arguments for generation + + Returns: + List of Experience objects containing generated samples + """ + from vllm import SamplingParams + assert reduction in ("mean", "sum") + assert len(all_prompts) == len(all_labels) + + llms = self.vllm_engines + + sampling_params = SamplingParams( + temperature=self.temperature, + top_p=self.top_p, + max_tokens=1, + include_stop_str_in_output=True, + logprobs=None, + prompt_logprobs=1 + ) + + all_context_texts = [] + for user_prompt in all_prompts: + context_text = self.tokenizer.apply_chat_template( + [{"role": "user", "content": user_prompt}], + tokenize=False, + add_generation_prompt=True, + ) + all_context_texts.append(context_text) + all_full_texts = [c + l + self.tokenizer.eos_token for c, l in zip(all_context_texts, all_labels)] + + full_prompt_token_ids = self.tokenizer(all_full_texts, add_special_tokens=False, max_length=self.prompt_max_len + 1, padding=False, truncation=True)["input_ids"] + context_token_ids = self.tokenizer(all_context_texts, add_special_tokens=False, max_length=self.prompt_max_len, padding=False, truncation=True)["input_ids"] + + prompt_lens = [len(x) for x in context_token_ids] + label_lens = [len(full_ids) - p_len for full_ids, p_len in zip(full_prompt_token_ids, prompt_lens)] + + + refs = [] + batch_size = (len(full_prompt_token_ids) + len(llms) - 1) // len(llms) + for i, llm in enumerate(llms): + full_prompt_token = full_prompt_token_ids[i * batch_size : (i + 1) * batch_size] + refs.append(llm.add_requests.remote(sampling_params=sampling_params, prompt_token_ids=full_prompt_token)) + ray.get(refs) + + all_output_refs = [] + for i, llm in enumerate(llms): + all_output_refs.append(llm.get_responses.remote()) + all_outputs = sum(ray.get(all_output_refs), []) + + scores = [] + for output, full_ids, p_len, l_len in zip(all_outputs, full_prompt_token_ids, prompt_lens, label_lens): + prompt_logprobs = getattr(output, "prompt_logprobs", None) + if prompt_logprobs is None: + scores.append(float("-inf")) + continue + + token_lps = [] + for idx in range(p_len, p_len + l_len): + label_token_id = full_ids[idx] + logprob_dict = prompt_logprobs[idx] + + token_lps.append(logprob_dict[label_token_id].logprob) + + if len(token_lps) == 0: + scores.append(float("-inf")) + continue + if reduction == "sum": + scores.append(sum(token_lps)) + else: + scores.append(sum(token_lps) / len(token_lps)) + + return scores diff --git a/zoo/jericho/priorzero/utils/vllm_engine.py b/zoo/jericho/priorzero/utils/vllm_engine.py new file mode 100644 index 000000000..16b9c9765 --- /dev/null +++ b/zoo/jericho/priorzero/utils/vllm_engine.py @@ -0,0 +1,248 @@ +import os +import queue +from typing import Any, List + +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +@ray.remote +def get_all_env_variables(): + return os.environ + + +class BaseLLMRayActor: + def __init__(self, *args, bundle_indices: list = None, **kwargs): + kwargs.pop("agent_func_path", None) + noset_visible_devices = ray_noset_visible_devices() + if kwargs.get("distributed_executor_backend") == "ray": + # a hack to make the script work. + # stop ray from manipulating *_VISIBLE_DEVICES + # at the top-level when the distributed_executor_backend is ray. 
+ os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ.pop("ROCR_VISIBLE_DEVICES", None) + os.environ.pop("HIP_VISIBLE_DEVICES", None) + elif noset_visible_devices: + # We need to set CUDA_VISIBLE_DEVICES to the ray assigned GPU + # when the distributed_executor_backend is not ray and + # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set. + os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0]) + + num_gpus = kwargs.pop("num_gpus") + if bundle_indices is not None: + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(num_gpus) + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) + print(f"creating LLM with bundle_indices={bundle_indices}") + + # Number of actors that will send prompt to this engine + self.requests = {} + self.response_queues = queue.Queue() + + full_determinism = kwargs.pop("full_determinism", False) + if full_determinism: + # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + + self.kwargs = kwargs + + import vllm + from packaging import version + + if version.parse(vllm.__version__) >= version.parse("0.9.0"): + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + +@ray.remote +class LLMRayActor(BaseLLMRayActor): + def __init__(self, *args, bundle_indices: list = None, **kwargs): + super().__init__(*args, bundle_indices=bundle_indices, **kwargs) + + import vllm + + self.llm = vllm.LLM(*args, **self.kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): + return self.llm.collective_rpc( + "init_process_group", + args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray), + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) + + def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False): + return self.llm.collective_rpc("update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)) + + def reset_prefix_cache(self): + self.llm.llm_engine.reset_prefix_cache() + + def sleep(self, level=1): + self.llm.sleep(level=level) + + def wake_up(self): + self.llm.wake_up() + + def add_requests(self, sampling_params, prompt_token_ids): + """ + Process requests from rank0 and generate responses. + Since only rank0 will send requests, we don't need to track actor ranks. 
+ """ + from vllm.inputs import TokensPrompt + + requests = [TokensPrompt(prompt_token_ids=r) for r in prompt_token_ids] + responses = self.llm.generate(prompts=requests, sampling_params=sampling_params) + self.response_queues.put(responses) + + def get_responses(self): + """ + Return the responses for the actor with the given rank + """ + return self.response_queues.get() + + +def create_vllm_engines( + num_engines: int, + tensor_parallel_size: int, + pretrain: str, + seed: int, + full_determinism: bool, + enable_prefix_caching: bool, + enforce_eager: bool, + max_model_len: int, + shared_pg=None, + gpu_memory_utilization=None, + vllm_enable_sleep=False, + llm_actor_cls=LLMRayActor, + logprobs_mode=None, + agent_func_path=None, +): + import vllm + from packaging import version + + assert version.parse(vllm.__version__) > version.parse("0.8.2"), "OpenRLHF only supports vllm > 0.8.2" + + vllm_engines = [] + distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray" + use_hybrid_engine = shared_pg is not None + num_gpus = int(tensor_parallel_size == 1) + if use_hybrid_engine and tensor_parallel_size == 1: + # every worker will use 0.2 GPU, so that we can schedule + # 2 instances on the same GPUs. + num_gpus = 0.2 + + if not use_hybrid_engine: + # Create a big placement group to ensure that all engines are packed + bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)] + shared_pg = placement_group(bundles, strategy="PACK") + ray.get(shared_pg.ready()) + + for i in range(num_engines): + bundle_indices = None + if tensor_parallel_size > 1: + bundle_indices = get_bundle_indices(shared_pg, i, tensor_parallel_size) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=shared_pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_indices[0] if bundle_indices else i, + ) + + additional_kwargs = {} + if logprobs_mode: + additional_kwargs["logprobs_mode"] = logprobs_mode + additional_kwargs["max_logprobs"] = 1 + assert version.parse(vllm.__version__) > version.parse( + "0.10.0" + ), "vLLM > 0.10.0 is required for logprobs_mode" + + vllm_engines.append( + llm_actor_cls.options( + num_cpus=num_gpus, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + model=pretrain, + enforce_eager=enforce_eager, + worker_extension_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap", + tensor_parallel_size=tensor_parallel_size, + seed=seed + i, + distributed_executor_backend=distributed_executor_backend, + max_model_len=max_model_len, + enable_prefix_caching=enable_prefix_caching, + dtype="bfloat16", + trust_remote_code=True, + full_determinism=full_determinism, + gpu_memory_utilization=gpu_memory_utilization, + bundle_indices=bundle_indices, + num_gpus=0.2 if use_hybrid_engine else 1, + enable_sleep_mode=vllm_enable_sleep, + agent_func_path=agent_func_path, + **additional_kwargs, + ) + ) + + return vllm_engines + + +def batch_vllm_engine_call(engines: List[Any], method_name: str, *args, rank_0_only: bool = True, **kwargs): + """ + Batch call a method on multiple vLLM engines. 
+ Args: + engines: List of vLLM engine instances + method_name: Name of the method to call + rank_0_only: Only execute on rank 0 if True + *args: Positional arguments to pass to the method + **kwargs: Keyword arguments to pass to the method + Returns: + List of results from ray.get() if on rank 0, None otherwise + """ + import torch + + if torch.distributed.is_initialized(): + if rank_0_only and torch.distributed.get_rank() != 0: + return None + + refs = [] + for engine in engines: + method = getattr(engine, method_name) + refs.append(method.remote(*args, **kwargs)) + + return ray.get(refs) + + +# Address https://github.com/ray-project/ray/issues/51117 +# This function is used to get the bundle indices of a placement group +# and ensure that the bundles placed on the same node are grouped together. +def get_bundle_indices(placement_group, index, length): + import ray + + pg_infos = ray.util.placement_group_table(placement_group) + + node_id_to_bundles = {} + for bundle, node_id in pg_infos["bundles_to_node_id"].items(): + node_id_to_bundles.setdefault(node_id, []).append(bundle) + + sorted_bundle_indices = sum(node_id_to_bundles.values(), []) + return sorted_bundle_indices[index * length : (index + 1) * length] + + +def ray_noset_visible_devices(env_vars=os.environ): + NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", + "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", + "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", + ] + return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) + + +def get_physical_gpu_id(): + import torch + + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return str(props.uuid)
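
Note on the new `compute_approx_kl` helper added in `zoo/jericho/priorzero/priorzero_utils.py`: it implements the k1/k2/k3 estimators from Schulman's KL-approximation note and returns a per-token tensor clamped to [-10, 10]. Below is a minimal usage sketch, not code from this patch: the tensors, the 0.01 penalty coefficient, and the import path (which assumes `zoo` is importable as a package) are illustrative assumptions; the patch itself only adds the helper and the `rft_kl_mean` monitor variable.

import torch
from zoo.jericho.priorzero.priorzero_utils import compute_approx_kl  # path is an assumption

# Per-token log-probs of the sampled action tokens under the trained LLM policy
# and under a frozen reference model (both shaped [batch, seq_len]); values are dummies.
policy_logprobs = torch.tensor([[-1.2, -0.7, -2.3]])
reference_logprobs = torch.tensor([[-1.0, -0.9, -2.0]])

# "k1" is the plain log-ratio, "k2" its squared half, "k3" the non-negative estimator.
kl_per_token = compute_approx_kl(policy_logprobs, reference_logprobs, kl_estimator="k3")

# Example of folding the mean KL into an RFT-style penalty and logging it.
rft_kl_mean = kl_per_token.mean()
kl_penalty = 0.01 * rft_kl_mean  # 0.01 is an illustrative coefficient, not taken from the PR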