diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py index 3fdcfa099..8f914ff32 100644 --- a/lzero/entry/train_unizero_multitask_segment_ddp.py +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -2,6 +2,7 @@ import os from functools import partial from typing import Tuple, Optional, List +import concurrent.futures import torch import numpy as np @@ -13,17 +14,24 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler + from lzero.policy import visit_count_temperature from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector from ding.utils import EasyTimer import torch.nn.functional as F - +import sys +import os import torch.distributed as dist +# Import MOE statistics functions from utils +from lzero.entry.utils import ( + collect_and_log_moe_statistics, + TemperatureScheduler, + log_buffer_memory_usage +) # ------------------------------------------------------------ -# 1. 额外增加 learner 专用 process-group +# 1. 额外增加 learner 专用 process-group # (在 main / learner 初始化时调用一次) # ------------------------------------------------------------ def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: @@ -367,7 +375,9 @@ def train_unizero_multitask_segment_ddp( model_path: Optional[str] = None, max_train_iter: Optional[int] = int(1e10), max_env_step: Optional[int] = int(1e10), - benchmark_name: str = "atari" + benchmark_name: str = "atari", + finetune_components=[], + cal_moe_profile: bool = True # 新增:控制MOE性能监控的开关 ) -> 'Policy': """ Overview: @@ -520,20 +530,25 @@ def train_unizero_multitask_segment_ddp( # 编译配置 cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # 创建共享的policy - policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) - + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + cfg.policy.logger=tb_logger + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # MOE + policy.logger=tb_logger + + # 加载预训练模型(如果提供) if model_path is not None: logging.info(f'开始加载模型: {model_path}') - policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device),finetune_components=finetune_components) logging.info(f'完成加载模型: {model_path}') # 创建TensorBoard日志记录器 log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') tb_logger = SummaryWriter(log_dir) - # 创建共享的learner + # 创建共享的learner #todo: cfg.policy.learn.learner.hook.log_show_after_iter learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) policy_config = cfg.policy @@ -645,6 +660,7 @@ def train_unizero_multitask_segment_ddp( # if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug # if evaluator.should_eval(learner.train_iter): print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') # =========TODO========= @@ -720,7 +736,7 @@ def train_unizero_multitask_segment_ddp( print(f"not_enough_data:{not_enough_data}") # 获取当前温度 current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) - + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0 : if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0 : @@ -811,7 +827,12 @@ def train_unizero_multitask_segment_ddp( # 在训练时,DDP会自动同步梯度和参数 log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) - # logging.error(f'Rank {rank}: one learn step done') + # +++++++++++++++++++++++++++++++++ MOE expert selection statistics logging +++++++++++++++++++++++++++++++++ + if cal_moe_profile and cfg.policy.model.world_model_cfg.multiplication_moe_in_transformer and cfg.policy.model.world_model_cfg.num_experts_of_moe_in_transformer: + # Control MoE statistics logging frequency + moe_log_interval = getattr(cfg.policy, 'moe_log_interval', 1) # Default: log once every 500 iterations + if learner.train_iter % moe_log_interval == 0: + collect_and_log_moe_statistics(policy, tb_logger, learner.train_iter, world_size, rank) # 判断是否需要计算task_exploitation_weight if i == 0: diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index b51eb7f11..60c0d7631 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,5 +1,8 @@ import os -from typing import Optional, Callable, Union, List, Tuple +import time +from typing import Optional, Callable, Union, List, Tuple, Dict +from io import BytesIO +import concurrent.futures import psutil import torch @@ -7,12 +10,11 @@ from pympler.asizeof import asizeof from tensorboardX import SummaryWriter - -import torch import numpy as np -import torch import torch.nn.functional as F import matplotlib.pyplot as plt +import seaborn as sns +from PIL import Image # ============================================================ # freeze_non_lora.py @@ -362,3 +364,962 @@ def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWr # Reset the time records in the buffer. buffer.reset_runtime_metrics() + + +# ============================================================ +# MOE Expert Selection Statistics Functions +# ============================================================ + +# Global heatmap figure cache to avoid repeated creation +_GLOBAL_HEATMAP_FIG = None +_GLOBAL_HEATMAP_AX = None + + +def merge_expert_stats_across_ranks(all_expert_stats): + """ + Overview: + Merge expert selection statistics data from all distributed training ranks. + Combines statistics from different GPU processes for comprehensive analysis. + Arguments: + - all_expert_stats (:obj:`list`): List of expert statistics from all ranks. + Returns: + - merged_stats (:obj:`dict`): Merged statistics dictionary with structure + {task_id: {window_type: stats}}. + Examples: + >>> stats_list = [rank0_stats, rank1_stats, rank2_stats] + >>> merged = merge_expert_stats_across_ranks(stats_list) + >>> print(f"Merged {len(merged)} tasks") + """ + merged_stats = {} # {task_id: {window_type: stats}} + + for rank_expert_stats in all_expert_stats: + if rank_expert_stats: + for task_id, task_stats in rank_expert_stats.items(): + if task_id not in merged_stats: + merged_stats[task_id] = {} + + for window_type, stats in task_stats.items(): + # Only process statistics with actual data (tasks handled by current GPU) + if stats and stats.get('total_selections', 0) > 0: + merged_stats[task_id][window_type] = { + 'frequencies': np.array(stats['frequencies']), + 'total_selections': stats['total_selections'], + 'data_points': stats['data_points'] + } + return merged_stats + + +def _get_or_create_heatmap_figure(figsize): + """ + Overview: + Get or create a reusable heatmap figure for memory efficiency. + Maintains global figure cache to reduce memory allocation overhead. + Arguments: + - figsize (:obj:`tuple`): Figure size as (width, height). + Returns: + - fig (:obj:`matplotlib.figure.Figure`): Matplotlib figure object. + - ax (:obj:`matplotlib.axes.Axes`): Matplotlib axes object. + Examples: + >>> fig, ax = _get_or_create_heatmap_figure((10, 8)) + >>> ax.plot([1, 2, 3], [4, 5, 6]) + """ + global _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX + if _GLOBAL_HEATMAP_FIG is None: + _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX = plt.subplots(figsize=figsize) + else: + # Clear previous content + _GLOBAL_HEATMAP_AX.clear() + # Adjust image size + _GLOBAL_HEATMAP_FIG.set_size_inches(figsize) + return _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX + + +def create_heatmap_with_values_fast(matrix, task_ids, title="Task-Expert Selection Frequencies"): + """ + Overview: + Efficiently create annotated blue-themed heatmap with performance optimizations. + Optimizations include matplotlib figure reuse, selective value annotations, + optimized image conversion pipeline, and reduced DPI for faster computation. + Arguments: + - matrix (:obj:`numpy.ndarray`): Input matrix for heatmap visualization. + - task_ids (:obj:`list`): List of task identifiers for y-axis labels. + - title (:obj:`str`, optional): Heatmap title. Default is "Task-Expert Selection Frequencies". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard logging. + Shapes: + - matrix: :math:`(N_{tasks}, N_{experts})` where N_tasks and N_experts are dimensions. + - img_array: :math:`(3, H, W)` where H and W are image height and width. + Examples: + >>> import numpy as np + >>> matrix = np.random.rand(5, 8) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_heatmap_with_values_fast(matrix, task_ids) + >>> print(f"Heatmap shape: {heatmap.shape}") # (3, height, width) + """ + try: + figsize = (max(6, matrix.shape[1]), max(4, matrix.shape[0])) + fig, ax = _get_or_create_heatmap_figure(figsize) + + # Intelligently choose whether to display value annotations + show_annot = matrix.size <= 64 # Only display values for 8x8 or smaller matrices + + # Use matplotlib directly to avoid seaborn overhead + im = ax.imshow(matrix, cmap='Blues', aspect='auto') + + # Selectively add value annotations + if show_annot: + for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + value = matrix[i, j] + color = 'white' if value > 0.5 else 'black' + ax.text(j, i, f'{value:.3f}', ha='center', va='center', + color=color, fontsize=8) + + # Set labels and title + ax.set_xticks(range(matrix.shape[1])) + ax.set_yticks(range(matrix.shape[0])) + ax.set_xticklabels([f'E{i}' for i in range(matrix.shape[1])], fontsize=10) + ax.set_yticklabels([f'T{tid}' for tid in task_ids], fontsize=10) + ax.set_title(title, fontsize=12, pad=15) + ax.set_xlabel('Experts', fontsize=10) + ax.set_ylabel('Tasks', fontsize=10) + + # Simplified colorbar + if not hasattr(fig, '_colorbar_created'): + plt.colorbar(im, ax=ax, label='Frequency') + fig._colorbar_created = True + + # Optimized image conversion: using lower DPI and simplified pipeline + fig.canvas.draw() + try: + # Get RGB data directly from canvas + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + img_array = buf[:, :, :3] # Remove alpha channel + else: + buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + + # Convert to CHW format + img_array = img_array.transpose(2, 0, 1) + + except Exception: + # Fallback: create simple blue gradient matrix + h, w = matrix.shape + img_array = np.zeros((3, h*20, w*20), dtype=np.uint8) + # Simple matrix upscaling and mapping to blue channel + matrix_resized = np.repeat(np.repeat(matrix, 20, axis=0), 20, axis=1) + img_array[2] = (matrix_resized * 255).astype(np.uint8) + + return img_array + + except Exception as e: + print(f"Warning: Heatmap generation failed: {e}, using fallback") + # Ultimate fallback: return blank image + return np.zeros((3, 100, 100), dtype=np.uint8) + + +def create_heatmap_with_values(matrix, task_ids, title="Task-Expert Selection Frequencies"): + """ + Overview: + Create annotated blue-themed heatmap using seaborn - original version for fallback. + This function serves as a backup when the optimized version encounters issues. + Arguments: + - matrix (:obj:`numpy.ndarray`): Input matrix for heatmap visualization. + - task_ids (:obj:`list`): List of task identifiers for y-axis labels. + - title (:obj:`str`, optional): Heatmap title. Default is "Task-Expert Selection Frequencies". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard logging. + Shapes: + - matrix: :math:`(N_{tasks}, N_{experts})` where N_tasks and N_experts are dimensions. + - img_array: :math:`(3, H, W)` where H and W are image height and width. + Examples: + >>> import numpy as np + >>> matrix = np.random.rand(5, 8) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_heatmap_with_values(matrix, task_ids) + >>> print(f"Heatmap shape: {heatmap.shape}") # (3, height, width) + """ + fig, ax = plt.subplots(figsize=(max(8, matrix.shape[1]), max(6, matrix.shape[0]))) + + # Use blue color scheme + sns.heatmap(matrix, + annot=True, # Display values + fmt='.3f', # Value format + cmap='Blues', # Blue theme + ax=ax, + cbar_kws={'label': 'Selection Frequency'}, + xticklabels=[f'Expert{i}' for i in range(matrix.shape[1])], + yticklabels=[f'Task{tid}' for tid in task_ids]) + + ax.set_title(title, fontsize=14, pad=20) + ax.set_xlabel('Experts', fontsize=12) + ax.set_ylabel('Tasks', fontsize=12) + + plt.tight_layout() + + # Save to BytesIO + buf = BytesIO() + plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') + buf.seek(0) + + # Convert to numpy array for tensorboard + img = Image.open(buf) + img_array = np.array(img) + buf.close() + plt.close(fig) + + # Convert to CHW format (Channel, Height, Width) + if len(img_array.shape) == 3: + img_array = img_array.transpose(2, 0, 1) + + return img_array + + +def log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter): + """ + Overview: + Log detailed expert selection statistics for each task. + Records frequency entropy, variance, and total selections for analysis. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - valid_task_ids (:obj:`list`): List of valid task identifiers. + - matrix (:obj:`numpy.ndarray`): Expert selection frequency matrix. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_expert_selection_details(tb_logger, stats, [0,1,2], matrix, 'immediate', 1000) + """ + for i, task_id in enumerate(valid_task_ids): + frequencies = matrix[i] + stats = merged_stats[task_id][window_type] + + # Calculate and record task expert selection entropy (uniformity metric) + task_frequencies = np.array(frequencies) + task_frequencies = task_frequencies + 1e-8 # Avoid log(0) + task_entropy = -np.sum(task_frequencies * np.log(task_frequencies)) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/ExpertSelectionEntropy', + task_entropy, global_step=train_iter + ) + + # Record task expert selection variance (dispersion) + expert_variance = np.var(task_frequencies) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/ExpertSelectionVariance', + expert_variance, global_step=train_iter + ) + + # Record task-level summary statistics + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/TotalSelections', + stats['total_selections'], global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/DataPoints', + stats['data_points'], global_step=train_iter + ) + + +def log_global_moe_statistics(tb_logger, matrix, window_type, valid_task_ids, train_iter): + """ + Overview: + Log global MOE statistics including expert usage uniformity and extremes. + Provides system-wide view of expert utilization patterns. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - matrix (:obj:`numpy.ndarray`): Expert selection frequency matrix. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - valid_task_ids (:obj:`list`): List of valid task identifiers. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_global_moe_statistics(tb_logger, matrix, 'immediate', [0,1,2], 1000) + """ + # Record basic information + tb_logger.add_scalar( + f'MOE_Global/{window_type}/NumActiveTasks', + len(valid_task_ids), global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/NumExperts', + matrix.shape[1], global_step=train_iter + ) + + # Calculate expert usage uniformity + expert_avg_usage = np.mean(matrix, axis=0) # Average usage frequency per expert + usage_entropy = -np.sum(expert_avg_usage * np.log(expert_avg_usage + 1e-8)) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/ExpertUsageEntropy', + usage_entropy, global_step=train_iter + ) + + # Record most and least used experts + most_used_expert = np.argmax(expert_avg_usage) + least_used_expert = np.argmin(expert_avg_usage) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/MostUsedExpert', + most_used_expert, global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/LeastUsedExpert', + least_used_expert, global_step=train_iter + ) + + +def process_and_log_moe_heatmaps_fast(tb_logger, merged_stats, window_type, train_iter): + """ + Overview: + Efficiently process and log MOE heatmaps with performance optimizations. + Includes vectorized data processing, conditional heatmap generation, + and batch statistical processing. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> process_and_log_moe_heatmaps_fast(tb_logger, stats, 'immediate', 1000) + """ + # Quick filtering of valid tasks + valid_task_data = [(tid, stats[window_type]['frequencies']) + for tid, stats in merged_stats.items() + if window_type in stats] + + if not valid_task_data: + return + + # Vectorized matrix construction + valid_task_ids, frequencies_list = zip(*valid_task_data) + matrix = np.array(frequencies_list) + + # Conditional heatmap generation: only for small matrices + if matrix.size <= 200: # Only generate heatmap when tasks*experts <= 200 + try: + heatmap_img = create_heatmap_with_values_fast( + matrix, valid_task_ids, + f'MOE {window_type} Task-Expert Selection' + ) + + # Log heatmap to tensorboard + tb_logger.add_image( + f'MOE_Heatmap/{window_type}_TaskExpert_Heatmap', + heatmap_img, + global_step=train_iter, + dataformats='CHW' + ) + except Exception as e: + print(f"Warning: Heatmap generation failed: {e}") + + # Always log statistical data (lightweight operation) + log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter) + log_global_moe_statistics(tb_logger, matrix, window_type, valid_task_ids, train_iter) + + +def process_and_log_moe_heatmaps(tb_logger, merged_stats, window_type, train_iter): + """ + Overview: + Process and log MOE heatmaps - original version for fallback. + This function serves as a backup when the optimized version encounters issues. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> process_and_log_moe_heatmaps(tb_logger, stats, 'immediate', 1000) + """ + all_task_ids = sorted(merged_stats.keys()) + task_expert_matrix = [] + valid_task_ids = [] + + # Collect frequency data from valid tasks + for task_id in all_task_ids: + if window_type in merged_stats[task_id]: + frequencies = merged_stats[task_id][window_type]['frequencies'] + task_expert_matrix.append(frequencies) + valid_task_ids.append(task_id) + + if not task_expert_matrix: + return + + # Convert to numpy matrix (num_tasks, num_experts) + matrix = np.array(task_expert_matrix) + + # Create annotated blue-themed heatmap + heatmap_img = create_heatmap_with_values( + matrix, valid_task_ids, + f'MOE {window_type} Task-Expert Selection Frequencies' + ) + + # Log heatmap to tensorboard + tb_logger.add_image( + f'MOE_Heatmap/{window_type}_TaskExpert_Heatmap', + heatmap_img, + global_step=train_iter, + dataformats='CHW' + ) + + # Log detailed and global statistics + log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter) + + +def convert_stats_to_serializable(moe_stats): + """ + Overview: + Convert tensor data in MOE statistics to serializable numpy format. + Ensures compatibility with distributed communication protocols. + Arguments: + - moe_stats (:obj:`dict`): MOE statistics containing tensor data. + Returns: + - converted (:obj:`dict`): Converted statistics with numpy arrays. + Examples: + >>> tensor_stats = {'task_0': {'immediate': {'frequencies': torch.tensor([0.1, 0.9])}}} + >>> numpy_stats = convert_stats_to_serializable(tensor_stats) + >>> type(numpy_stats['task_0']['immediate']['frequencies']) # + """ + if not moe_stats: + return {} + + converted = {} + for task_id, task_stats in moe_stats.items(): + converted[task_id] = {} + for window_type, stats in task_stats.items(): + if stats and 'frequencies' in stats: + converted[task_id][window_type] = { + 'frequencies': stats['frequencies'].cpu().numpy().tolist(), + 'total_selections': stats['total_selections'], + 'data_points': stats['data_points'] + } + return converted + + +def gather_distributed_moe_stats(local_stats, world_size): + """ + Overview: + Gather MOE statistics from all GPUs in distributed training environment. + Handles communication failures gracefully with fallback to local statistics. + Arguments: + - local_stats (:obj:`dict`): Local GPU's MOE statistics. + - world_size (:obj:`int`): Total number of distributed training processes. + Returns: + - all_stats (:obj:`list`): List of statistics from all ranks. + Examples: + >>> local_data = {'task_0': {'immediate': {'frequencies': [0.1, 0.9]}}} + >>> all_data = gather_distributed_moe_stats(local_data, 4) + >>> len(all_data) # 4 (from 4 GPUs) + """ + all_stats = [None for _ in range(world_size)] + try: + dist.all_gather_object(all_stats, local_stats) + return all_stats + except Exception as e: + print(f"Distributed MOE statistics gathering failed: {e}") + return [local_stats] # fallback to local statistics + + +def collect_and_log_moe_statistics(policy, tb_logger, train_iter, world_size, rank): + """ + Overview: + Collect and log MOE expert selection statistics including heatmaps and distribution analysis. + Comprehensive function that handles distributed data collection, merging, and visualization. + Arguments: + - policy (:obj:`Policy`): Training policy object containing world model. + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - train_iter (:obj:`int`): Current training iteration number. + - world_size (:obj:`int`): Total number of GPUs in distributed training. + - rank (:obj:`int`): Current GPU rank identifier. + Examples: + >>> collect_and_log_moe_statistics(policy, tb_logger, 1000, 8, 0) + """ + try: + # Step 1: Get MOE statistics from policy's transformer model + moe_stats = None + + transformer = policy._model.world_model.transformer + if hasattr(transformer, 'get_expert_selection_stats'): + moe_stats = transformer.get_expert_selection_stats() + + if moe_stats is None: + print(f"Rank {rank}: Warning: Unable to get MOE statistics, train_iter={train_iter}") + return + + # Step 2: Convert tensor data to serializable format + serializable_stats = convert_stats_to_serializable(moe_stats) + + print(f"Rank {rank}: Local MOE statistics - tasks: {len(serializable_stats)}, train_iter={train_iter}") + + # Step 3: Gather statistics from all GPUs in distributed setting + all_expert_stats = gather_distributed_moe_stats(serializable_stats, world_size) + + # Step 4: Merge statistics data + merged_stats = merge_expert_stats_across_ranks(all_expert_stats) + + if not merged_stats: + print(f"Rank {rank}: Warning: Merged MOE statistics empty, train_iter={train_iter}") + return + + # Step 5: All GPUs log MOE statistics + print(f"Rank {rank}: Starting MOE statistics logging - merged tasks: {len(merged_stats)}, train_iter={train_iter}") + + # Generate heatmaps and statistics for each time window + for window_type in ['immediate', 'short', 'medium', 'long']: + if any(window_type in task_stats for task_stats in merged_stats.values()): + process_and_log_moe_heatmaps_fast(tb_logger, merged_stats, window_type, train_iter) + + # Log overall MOE usage + tb_logger.add_scalar('MOE_Global/ActiveTasks', len(merged_stats), global_step=train_iter) + + # Step 6: Add distribution difference computation and logging + if any('immediate' in task_stats for task_stats in merged_stats.values()): + print(f"Rank {rank}: Starting inter-task distribution difference calculation...") + collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, train_iter) + + print(f"Rank {rank}: MOE statistics logging completed, train_iter={train_iter}") + + except Exception as e: + print(f"Rank {rank}: MOE statistics collection failed - {e}, train_iter={train_iter}") + import traceback + traceback.print_exc() + + +# ====== GPU-Optimized Distribution Divergence Calculation and Visualization Functions ====== +def jensen_shannon_divergence_batch_gpu(distributions_tensor): + """ + Overview: + GPU batch computation of JS divergence matrix - fully vectorized, no loops. + Efficiently computes Jensen-Shannon divergence between all pairs of distributions. + Arguments: + - distributions_tensor (:obj:`torch.Tensor`): Shape (n_tasks, n_experts), GPU tensor. + Returns: + - js_matrix (:obj:`torch.Tensor`): Shape (n_tasks, n_tasks), symmetric matrix. + Shapes: + - distributions_tensor: :math:`(N_{tasks}, N_{experts})` + - js_matrix: :math:`(N_{tasks}, N_{tasks})` + Examples: + >>> dist_tensor = torch.rand(5, 8).cuda() + >>> js_matrix = jensen_shannon_divergence_batch_gpu(dist_tensor) + >>> print(js_matrix.shape) # torch.Size([5, 5]) + """ + device = distributions_tensor.device + n_tasks, n_experts = distributions_tensor.shape + + # 1. Normalize to probability distributions + eps = 1e-8 + distributions_tensor = distributions_tensor / (distributions_tensor.sum(dim=1, keepdim=True) + eps) + + # 2. Use broadcasting to compute average distributions for all task pairs + # P_i: (n_tasks, 1, n_experts), P_j: (1, n_tasks, n_experts) + P_i = distributions_tensor.unsqueeze(1) + P_j = distributions_tensor.unsqueeze(0) + M = 0.5 * (P_i + P_j) # shape: (n_tasks, n_tasks, n_experts) + + # 3. Batch compute KL divergences - fully vectorized + # KL(P_i || M) for all pairs + log_ratio_i = torch.log((P_i + eps) / (M + eps)) + kl_i_m = torch.sum(P_i * log_ratio_i, dim=2) # (n_tasks, n_tasks) + + # KL(P_j || M) for all pairs + log_ratio_j = torch.log((P_j + eps) / (M + eps)) + kl_j_m = torch.sum(P_j * log_ratio_j, dim=2) # (n_tasks, n_tasks) + + # 4. JS divergence matrix + js_matrix = 0.5 * (kl_i_m + kl_j_m) + + return js_matrix + + +def wasserstein_distance_batch_gpu(distributions_tensor): + """ + Overview: + GPU batch computation of Wasserstein distance matrix - efficient 1D distribution implementation. + Computes Earth Mover's Distance between all pairs of discrete distributions. + Arguments: + - distributions_tensor (:obj:`torch.Tensor`): Shape (n_tasks, n_experts), GPU tensor. + Returns: + - wasserstein_matrix (:obj:`torch.Tensor`): Shape (n_tasks, n_tasks), symmetric matrix. + Shapes: + - distributions_tensor: :math:`(N_{tasks}, N_{experts})` + - wasserstein_matrix: :math:`(N_{tasks}, N_{tasks})` + Examples: + >>> dist_tensor = torch.rand(5, 8).cuda() + >>> wass_matrix = wasserstein_distance_batch_gpu(dist_tensor) + >>> print(wass_matrix.shape) # torch.Size([5, 5]) + """ + device = distributions_tensor.device + n_tasks, n_experts = distributions_tensor.shape + eps = 1e-8 + + # 1. Normalize to probability distributions + distributions_tensor = distributions_tensor / (distributions_tensor.sum(dim=1, keepdim=True) + eps) + + # 2. Compute cumulative distribution functions (CDF) + cdf_tensor = torch.cumsum(distributions_tensor, dim=1) # (n_tasks, n_experts) + + # 3. Use broadcasting to compute L1 distances between all CDF pairs + cdf_i = cdf_tensor.unsqueeze(1) # (n_tasks, 1, n_experts) + cdf_j = cdf_tensor.unsqueeze(0) # (1, n_tasks, n_experts) + + # Wasserstein distance = L1 norm of cumulative distribution differences + wasserstein_matrix = torch.sum(torch.abs(cdf_i - cdf_j), dim=2) + + return wasserstein_matrix + + +def compute_distribution_divergences_optimized(merged_stats, window_type='immediate'): + """ + Overview: + GPU-optimized version for efficient distribution divergence computation. + Leverages GPU acceleration for batch processing of divergence metrics. + Arguments: + - merged_stats (:obj:`dict`): Merged MOE statistics from all distributed ranks. + - window_type (:obj:`str`, optional): Time window type. Default is 'immediate'. + Returns: + - divergence_data (:obj:`dict`): Comprehensive divergence analysis results including + matrices, statistics, and metadata. + Examples: + >>> stats = {'task_0': {'immediate': {'frequencies': [0.1, 0.9]}}} + >>> result = compute_distribution_divergences_optimized(stats) + >>> print(f"GPU accelerated: {result['gpu_accelerated']}") + """ + # 1. Data preprocessing + valid_tasks = [(tid, stats[window_type]['frequencies']) + for tid, stats in merged_stats.items() + if window_type in stats] + + if len(valid_tasks) < 2: + return {} + + task_ids, frequencies_list = zip(*valid_tasks) + + # 2. Efficient tensor conversion + try: + if isinstance(frequencies_list[0], torch.Tensor): + frequencies_tensor = torch.stack(frequencies_list) + else: + frequencies_tensor = torch.tensor( + np.array(frequencies_list), + dtype=torch.float32 + ) + + # Automatic GPU acceleration + if torch.cuda.is_available(): + frequencies_tensor = frequencies_tensor.cuda() + + except Exception as e: + print(f"GPU conversion failed, using CPU: {e}") + frequencies_tensor = torch.tensor(np.array(frequencies_list), dtype=torch.float32) + + device = frequencies_tensor.device + n_tasks, n_experts = frequencies_tensor.shape + + # 3. GPU batch computation (no loops) + with torch.no_grad(): + # Batch compute JS divergence and Wasserstein distance + js_matrix = jensen_shannon_divergence_batch_gpu(frequencies_tensor) + wasserstein_matrix = wasserstein_distance_batch_gpu(frequencies_tensor) + + # Efficiently extract upper triangular values (avoid duplicate computation) + triu_indices = torch.triu_indices(n_tasks, n_tasks, offset=1, device=device) + js_values = js_matrix[triu_indices[0], triu_indices[1]] + wasserstein_values = wasserstein_matrix[triu_indices[0], triu_indices[1]] + + # Statistical computation (vectorized) + js_stats = { + 'avg': torch.mean(js_values).item(), + 'max': torch.max(js_values).item(), + 'min': torch.min(js_values).item(), + 'std': torch.std(js_values).item() + } + + wasserstein_stats = { + 'avg': torch.mean(wasserstein_values).item(), + 'max': torch.max(wasserstein_values).item(), + 'min': torch.min(wasserstein_values).item(), + 'std': torch.std(wasserstein_values).item() + } + + return { + 'task_ids': task_ids, + 'n_tasks': n_tasks, + 'n_experts': n_experts, + 'device': str(device), + 'gpu_accelerated': 'cuda' in str(device), + + # Return CPU versions for logging + 'js_matrix': js_matrix.cpu().numpy(), + 'wasserstein_matrix': wasserstein_matrix.cpu().numpy(), + 'js_stats': js_stats, + 'wasserstein_stats': wasserstein_stats + } + + +def create_similarity_heatmap_no_diagonal(similarity_matrix, task_ids, metric_name, title_suffix=""): + """ + Overview: + Create task similarity heatmap with diagonal elements removed. + Provides clear visualization of inter-task relationships without self-similarity noise. + Arguments: + - similarity_matrix (:obj:`numpy.ndarray`): Similarity matrix (n_tasks, n_tasks). + - task_ids (:obj:`list`): Task identifier list for axis labels. + - metric_name (:obj:`str`): Metric name ('js_divergence', 'wasserstein_distance'). + - title_suffix (:obj:`str`, optional): Additional title suffix. Default is "". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard. + Shapes: + - similarity_matrix: :math:`(N_{tasks}, N_{tasks})` + - img_array: :math:`(3, H, W)` where H and W are image dimensions. + Examples: + >>> matrix = np.random.rand(5, 5) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_similarity_heatmap_no_diagonal(matrix, task_ids, 'js_divergence') + >>> print(f"Output shape: {heatmap.shape}") # (3, height, width) + """ + try: + # Copy matrix to avoid modifying original data + matrix = similarity_matrix.copy() + + # Set diagonal to NaN so matplotlib displays as blank + np.fill_diagonal(matrix, np.nan) + + figsize = (max(6, len(task_ids)), max(4, len(task_ids))) + fig, ax = plt.subplots(figsize=figsize) # Create new figure to avoid reuse issues + + # Choose color mapping based on metric type + if 'js' in metric_name.lower(): + cmap = 'Reds' + title_name = 'JS Divergence' + vmin, vmax = 0, 1.0 + else: # wasserstein + cmap = 'Blues' + title_name = 'Wasserstein Distance' + vmin, vmax = None, None # Adaptive + + # Use masked array to handle NaN values, diagonal displays as white + masked_matrix = np.ma.masked_invalid(matrix) + im = ax.imshow(masked_matrix, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto') + + # Add value annotations (skip diagonal) + if len(task_ids) <= 15: # Only add annotations for smaller task counts + for i in range(len(task_ids)): + for j in range(len(task_ids)): + if i != j: # Skip diagonal + value = matrix[i, j] + if not np.isnan(value): + threshold = (vmax or np.nanmax(matrix)) * 0.5 if vmax else np.nanmax(matrix) * 0.5 + color = 'white' if value > threshold else 'black' + ax.text(j, i, f'{value:.3f}', ha='center', va='center', + color=color, fontsize=8) + + # Set labels + ax.set_xticks(range(len(task_ids))) + ax.set_yticks(range(len(task_ids))) + ax.set_xticklabels([f'T{tid}' for tid in task_ids], fontsize=9) + ax.set_yticklabels([f'T{tid}' for tid in task_ids], fontsize=9) + ax.set_title(f'Task {title_name} Matrix {title_suffix} (No Diagonal)', fontsize=12) + ax.set_xlabel('Tasks', fontsize=10) + ax.set_ylabel('Tasks', fontsize=10) + + # Add colorbar + plt.colorbar(im, ax=ax, label=title_name, shrink=0.8) + + # Convert to image array - fix matplotlib version compatibility + fig.canvas.draw() + + try: + # New matplotlib uses buffer_rgba + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + h, w = fig.canvas.get_width_height() + img_array = buf.reshape(h, w, 4)[:, :, :3] # Remove alpha channel + else: + # Old matplotlib fallback + buf = fig.canvas.print_to_string() + img_array = np.frombuffer(buf, dtype=np.uint8) + h, w = fig.canvas.get_width_height() + img_array = img_array.reshape(h, w, 3) + except Exception as conv_e: + print(f"Image conversion method failed: {conv_e}, trying PIL approach") + # Final fallback: convert through PIL + buf = BytesIO() + fig.savefig(buf, format='png', dpi=100, bbox_inches='tight') + buf.seek(0) + img = Image.open(buf) + img_array = np.array(img)[:, :, :3] # Remove alpha channel + buf.close() + + img_array = img_array.transpose(2, 0, 1) # CHW format + plt.close(fig) # Close figure to avoid memory leak + + return img_array + + except Exception as e: + print(f"Warning: No-diagonal heatmap generation failed: {e}") + return np.zeros((3, 100, 100), dtype=np.uint8) + + +def log_pairwise_optimized(tb_logger, divergence_data, train_iter): + """ + Overview: + Optimized task pair logging with batch processing. + Efficiently logs pairwise divergence metrics for all task combinations. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - divergence_data (:obj:`dict`): Divergence computation results. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_pairwise_optimized(tb_logger, divergence_data, 1000) + """ + task_ids = divergence_data['task_ids'] + js_matrix = divergence_data['js_matrix'] + wasserstein_matrix = divergence_data['wasserstein_matrix'] + + # Batch construct task pair metric dictionary + pairwise_scalars = {} + + for i, task_i in enumerate(task_ids): + for j, task_j in enumerate(task_ids): + if i < j: # Only log upper triangle + # Construct metric names + js_key = f'TaskPairwise/Immediate_Task{task_i}_Task{task_j}_JS_Divergence' + wass_key = f'TaskPairwise/Immediate_Task{task_i}_Task{task_j}_Wasserstein_Distance' + + pairwise_scalars[js_key] = js_matrix[i, j] + pairwise_scalars[wass_key] = wasserstein_matrix[i, j] + + # Batch write to TensorBoard + for key, value in pairwise_scalars.items(): + tb_logger.add_scalar(key, float(value), global_step=train_iter) + + +def log_divergences_with_heatmaps(tb_logger, divergence_data, train_iter): + """ + Overview: + Log distribution divergence metrics and heatmaps (with diagonal removed). + Comprehensive logging of inter-task distribution analysis results. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - divergence_data (:obj:`dict`): Divergence computation results. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_divergences_with_heatmaps(tb_logger, divergence_data, 1000) + """ + if not divergence_data: + return + + js_stats = divergence_data['js_stats'] + wasserstein_stats = divergence_data['wasserstein_stats'] + task_ids = divergence_data['task_ids'] + n_tasks = divergence_data['n_tasks'] + + # Debug: Check matrix data + js_matrix = divergence_data['js_matrix'] + wasserstein_matrix = divergence_data['wasserstein_matrix'] + print(f"DEBUG: JS matrix shape={js_matrix.shape}, range=[{np.min(js_matrix):.6f}, {np.max(js_matrix):.6f}]") + print(f"DEBUG: Wasserstein matrix shape={wasserstein_matrix.shape}, range=[{np.min(wasserstein_matrix):.6f}, {np.max(wasserstein_matrix):.6f}]") + + # 1. Log scalar metrics + scalar_dict = { + 'MOE_Divergence/Immediate_AvgJS_Divergence': js_stats['avg'], + 'MOE_Divergence/Immediate_MaxJS_Divergence': js_stats['max'], + 'MOE_Divergence/Immediate_AvgWasserstein_Distance': wasserstein_stats['avg'], + 'MOE_Divergence/Immediate_MaxWasserstein_Distance': wasserstein_stats['max'], + } + + for key, value in scalar_dict.items(): + tb_logger.add_scalar(key, value, global_step=train_iter) + + # 1.1 Print core metrics to console + print("=" * 65) + print(f" Inter-Task Distribution Divergence Statistics (Iteration: {train_iter})") + print("=" * 65) + print(f"Participating tasks: {n_tasks} | Task IDs: {list(task_ids)}") + print(f"Computing device: {divergence_data.get('device', 'Unknown')} | GPU acceleration: {'Enabled' if divergence_data.get('gpu_accelerated', False) else 'Disabled'}") + print("-" * 65) + print("JS Divergence (Jensen-Shannon Divergence):") + print(f" Average: {js_stats['avg']:.6f} | Maximum: {js_stats['max']:.6f}") + print(f" Minimum: {js_stats['min']:.6f} | Std Dev: {js_stats['std']:.6f}") + print("-" * 65) + print("Wasserstein Distance:") + print(f" Average: {wasserstein_stats['avg']:.6f} | Maximum: {wasserstein_stats['max']:.6f}") + print(f" Minimum: {wasserstein_stats['min']:.6f} | Std Dev: {wasserstein_stats['std']:.6f}") + print("=" * 65) + + # 2. Log similarity matrix heatmaps with diagonal removed + task_ids = divergence_data['task_ids'] + n_tasks = divergence_data['n_tasks'] + + if n_tasks <= 25: # Limit matrix size to avoid oversized heatmaps + try: + # JS divergence matrix heatmap (no diagonal) + js_heatmap = create_similarity_heatmap_no_diagonal( + divergence_data['js_matrix'], + task_ids, + 'js_divergence', + f'(Immediate-{n_tasks} tasks)' + ) + tb_logger.add_image( + 'TaskSimilarity/Immediate_JS_Matrix_NoDiagonal', + js_heatmap, + global_step=train_iter, + dataformats='CHW' + ) + + # Wasserstein distance matrix heatmap (no diagonal) + wass_heatmap = create_similarity_heatmap_no_diagonal( + divergence_data['wasserstein_matrix'], + task_ids, + 'wasserstein_distance', + f'(Immediate-{n_tasks} tasks)' + ) + tb_logger.add_image( + 'TaskSimilarity/Immediate_Wasserstein_Matrix_NoDiagonal', + wass_heatmap, + global_step=train_iter, + dataformats='CHW' + ) + + except Exception as e: + print(f"Warning: Similarity matrix heatmap generation failed: {e}") + + # 3. Log task pair metrics (optional) + if n_tasks <= 20: + log_pairwise_optimized(tb_logger, divergence_data, train_iter) + + +def collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, train_iter): + """ + Overview: + Complete distribution divergence computation and logging (including no-diagonal heatmaps). + End-to-end pipeline for analyzing and visualizing inter-task distribution differences. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged MOE statistics from distributed training. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, 1000) + """ + try: + # GPU-optimized computation + divergence_data = compute_distribution_divergences_optimized(merged_stats, 'immediate') + + if not divergence_data: + print(f"Skipping distribution divergence computation - insufficient tasks (need >=2 tasks)") + return + + # Log metrics and heatmaps + log_divergences_with_heatmaps(tb_logger, divergence_data, train_iter) + + # Summary print + print(f">> Distribution divergence statistics completed and logged to TensorBoard") + if divergence_data.get('n_tasks', 0) <= 25: + print(f">> Similarity matrix heatmaps generated (diagonal removed)") + if divergence_data.get('n_tasks', 0) <= 20: + print(f">> Task pair detailed metrics logged") + print() # Blank line separator + + except Exception as e: + print(f"ERROR: Distribution divergence computation failed - {e}") + import traceback + traceback.print_exc() diff --git a/lzero/mcts/ctree/ctree_alphazero/pybind11 b/lzero/mcts/ctree/ctree_alphazero/pybind11 deleted file mode 160000 index 98bd78f06..000000000 --- a/lzero/mcts/ctree/ctree_alphazero/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 98bd78f063b2f30570740030cb2d13b2a62a062c diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 97c3528c0..60f389dc2 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -46,7 +46,7 @@ def default_config(cls: type) -> EasyDict: cfg.cfg_type = cls.__name__ + 'Dict' return cfg - def __init__(self, cfg: EasyDict = None) -> None: + def __init__(self, cfg: EasyDict = None,eval=False) -> None: """ Overview: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key @@ -56,9 +56,13 @@ def __init__(self, cfg: EasyDict = None) -> None: default_config = self.default_config() default_config.update(cfg) self._cfg = default_config + if eval: + self._cfg.num_simulations=self._cfg.eval_num_simulations + self.inverse_scalar_transform_handle = InverseScalarTransform( self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution ) + @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 0e050502d..e9e75d55d 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -6,7 +6,7 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook #,ModelGradientHook from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT @@ -124,7 +124,7 @@ def __init__( )) elif world_model_cfg.encoder_type == "vit": for task_id in range(1): # TODO: one share encoder - if world_model_cfg.task_num <=8: + if world_model_cfg.task_num ==1: # # vit base # self.representation_network.append(ViT( # image_size =observation_shape[1], @@ -144,16 +144,42 @@ def __init__( patch_size = 8, num_classes = obs_act_embed_dim, dim = 768, - depth = 6, - heads = 6, - mlp_dim = 2048, + depth = 12, + heads = 12, + mlp_dim = 3072, dropout = 0.1, emb_dropout = 0.1, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - # ==================== 新增/修改部分 开始 ==================== - config=world_model_cfg # <--- 将包含LoRA参数的配置传递给ViT - # ==================== 新增/修改部分 结束 ==================== - + + )) + elif world_model_cfg.task_num <=8: + # # vit base + # self.representation_network.append(ViT( + # image_size =observation_shape[1], + # patch_size = 8, + # num_classes = obs_act_embed_dim, + # dim = 768, + # depth = 12, + # heads = 12, + # mlp_dim = 3072, + # dropout = 0.1, + # emb_dropout = 0.1, + # final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + # )) + # vit small + self.representation_network.append(ViT( + image_size =observation_shape[1], + patch_size = 8, + num_classes = obs_act_embed_dim, + dim = 768, + depth = 12, + heads = 12, + mlp_dim = 3072, + dropout = 0.1, + emb_dropout = 0.1, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + + )) elif world_model_cfg.task_num > 8: # vit base @@ -196,6 +222,11 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) + # if True: # Fixme: for debug + # # 增加对encoder的hook,监控传播到encoder 上的梯度 + # self.encoder_output_hook = ModelGradientHook() + # self.encoder_output_hook.setup_hook(self.representation_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') diff --git a/lzero/model/unizero_world_models/__init__.py b/lzero/model/unizero_world_models/__init__.py index c1d02cb8c..e69de29bb 100644 --- a/lzero/model/unizero_world_models/__init__.py +++ b/lzero/model/unizero_world_models/__init__.py @@ -1 +0,0 @@ -from .transformer import Transformer, TransformerConfig diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index 8ee8115ee..11ab3a5a7 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from simple_parsing.helpers import Serializable from torch import nn - +import torch.distributed as dist from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear # _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") @@ -59,7 +59,8 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert self.num_experts_per_tok = num_experts_per_tok self.gate = gate self.experts = nn.ModuleList(experts) - + self.config=config + # 如果配置中指定了共享专家数量,则构建共享专家分支 if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: self.shared_expert = nn.Sequential( @@ -69,34 +70,54 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert ) else: self.shared_expert = None + + # GPU memory expert selection statistics collector - multi-granularity sliding windows + self.device = next(iter(experts)).w1.weight.device if experts else torch.device('cuda') + + # Sliding window configuration + self.window_sizes = { + 'immediate': 100, # Immediate statistics (last 100 steps) + 'short': 1000, # Short-term statistics (last 1000 steps) + 'medium': 10000, # Medium-term statistics (last 10000 steps) + 'long': 100000 # Long-term statistics (last 100000 steps) + } + + # GPU statistics buffer: task_id -> {window_type -> [expert selection history]} + self.expert_stats_gpu = {} + self.step_count = 0 - def forward(self, x: torch.Tensor) -> torch.Tensor: + + def forward(self, x: torch.Tensor, task_id: int = None) -> torch.Tensor: # 保存原始形状后将 x reshape 为二维张量: [batch_size * seq_len, dim] original_shape = x.size() x = x.view(-1, self.dim) - - # 计算门控 logits,shape 为 [N, num_experts],N 为 token 数量 - gate_logits = self.gate(x) - # 选取每个 token 得分最高的 k 个专家 - weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) - # 对选中的 logits 做 softmax,获得归一化权重 - weights = F.softmax(weights, dim=1).to(x.dtype) - - # 初始化存放专家计算输出的张量 - expert_output = torch.zeros_like(x) - - # 遍历所有专家,对被该专家选择的 token 分支进行计算 - for expert_id in range(self.num_experts): - # 通过 where 找到 indices 中等于当前 expert_id 的 token 索引 - batch_idx, expert_tok_idx = torch.where(indices == expert_id) - if batch_idx.numel() == 0: - continue - token_subset = x[batch_idx] # 选中的 token,形状 [num_tokens, dim] - # 调用当前专家模块计算输出 - output_expert = self.experts[expert_id](token_subset) - # 获取对应 token 的权重,注意 weights 的形状为 [N, num_experts_per_tok] - token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) - expert_output[batch_idx] += output_expert * token_weights + expert_output=x + if self.num_experts!=0: + # 计算门控 logits,shape 为 [N, num_experts],N 为 token 数量 + gate_logits = self.gate(x) + # 选取每个 token 得分最高的 k 个专家 + weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + # 对选中的 logits 做 softmax,获得归一化权重 + weights = F.softmax(weights, dim=1).to(x.dtype) + + if self.training and task_id is not None: + self._collect_expert_selection_stats(task_id, indices) + + # 初始化存放专家计算输出的张量 + expert_output = torch.zeros_like(x) + + # 遍历所有专家,对被该专家选择的 token 分支进行计算 + for expert_id in range(self.num_experts): + # 通过 where 找到 indices 中等于当前 expert_id 的 token 索引 + batch_idx, expert_tok_idx = torch.where(indices == expert_id) + if batch_idx.numel() == 0: + continue + token_subset = x[batch_idx] # 选中的 token,形状 [num_tokens, dim] + # 调用当前专家模块计算输出 + output_expert = self.experts[expert_id](token_subset) + # 获取对应 token 的权重,注意 weights 的形状为 [N, num_experts_per_tok] + token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) + expert_output[batch_idx] += output_expert * token_weights # 如果使用了共享专家分支,则加上其输出 if self.shared_expert is not None: @@ -107,14 +128,153 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # 恢复原始形状后返回结果 return output.view(original_shape) + + def _collect_expert_selection_stats(self, task_id: int, indices: torch.Tensor): + """ + Overview: + Collect expert selection statistics in GPU memory using multi-granularity sliding windows. + Maintains separate rolling buffers for different time window sizes to track expert usage patterns. + Arguments: + - task_id (:obj:`int`): The identifier of the current task. + - indices (:obj:`torch.Tensor`): Expert indices selected by the router for the current batch. + Shapes: + - indices: :math:`(N, k)` where N is batch size and k is number of experts per token. + Examples: + >>> # Collect stats for task 0 with expert indices + >>> indices = torch.tensor([[0, 2], [1, 3]]) # batch_size=2, k=2 + >>> moe_layer._collect_expert_selection_stats(task_id=0, indices=indices) + """ + self.step_count += 1 + + if task_id not in self.expert_stats_gpu: + self.expert_stats_gpu[task_id] = {} + for window_type in self.window_sizes.keys(): + self.expert_stats_gpu[task_id][window_type] = torch.zeros( + self.window_sizes[window_type], + self.num_experts, + dtype=torch.float32, + device=self.device + ) + + # Calculate expert selection frequency for current batch + indices_flat = indices.flatten() # [N*k] + expert_counts = torch.zeros(self.num_experts, device=self.device, dtype=torch.float32) + for expert_id in range(self.num_experts): + expert_counts[expert_id] = (indices_flat == expert_id).sum().float() + + # Update sliding windows for all granularities + for window_type, window_size in self.window_sizes.items(): + buffer = self.expert_stats_gpu[task_id][window_type] + # Sliding window: new data goes to the end, old data moves forward + buffer[:-1] = buffer[1:].clone() + buffer[-1] = expert_counts + + def get_expert_selection_stats(self, task_id: int = None): + """ + Overview: + Get multi-granularity expert selection frequency statistics. + Simplified version that directly returns current data without complex aggregation. + Arguments: + - task_id (:obj:`int`, optional): The identifier of the specific task. If None, returns stats for all tasks. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics. + Structure: {task_id: {window_type: {frequencies, total_counts, total_selections, data_points}}} + Examples: + >>> # Get stats for all tasks + >>> all_stats = moe_layer.get_expert_selection_stats() + >>> # Get stats for specific task + >>> task_stats = moe_layer.get_expert_selection_stats(task_id=0) + """ + if task_id is None: + # Return statistics for all tasks + all_stats = {} + for tid in self.expert_stats_gpu.keys(): + all_stats[tid] = self._compute_task_stats(tid) + return all_stats + else: + # Return statistics for specified task + return self._compute_task_stats(task_id) + + def _compute_task_stats(self, task_id: int): + """ + Overview: + Compute multi-granularity statistics for a specified task. + Processes expert selection data across different time window granularities. + Arguments: + - task_id (:obj:`int`): The identifier of the task to compute statistics for. + Returns: + - stats (:obj:`dict`): Dictionary containing computed statistics for each window type. + Structure: {window_type: {frequencies, total_counts, total_selections, data_points}} + Shapes: + - frequencies: :math:`(num\_experts,)` normalized selection frequencies per expert. + - total_counts: :math:`(num\_experts,)` absolute selection counts per expert. + Examples: + >>> # Compute stats for task 0 + >>> task_stats = moe_layer._compute_task_stats(task_id=0) + >>> immediate_freq = task_stats['immediate']['frequencies'] + """ + if task_id not in self.expert_stats_gpu: + return {} + + stats = {} + for window_type, buffer in self.expert_stats_gpu[task_id].items(): + # Simplified version: directly average all existing data, ignoring whether window is full + # buffer shape: [window_size, num_experts] + total_counts = buffer.sum(dim=0) # [num_experts] + total_selections = total_counts.sum() + + if total_selections > 0: + frequencies = total_counts / total_selections + else: + frequencies = torch.zeros(self.num_experts, device=self.device) + + stats[window_type] = { + 'frequencies': frequencies, # Keep tensor format + 'total_counts': total_counts, # Keep tensor format + 'total_selections': total_selections.item(), + 'data_points': min(self.step_count, self.window_sizes[window_type]) + } + + return stats + + def reset_expert_selection_stats(self): + """ + Overview: + Reset expert selection statistics by clearing all accumulated data. + Clears GPU memory buffers and resets step counter to initial state. + Examples: + >>> # Reset all expert selection statistics + >>> moe_layer.reset_expert_selection_stats() + """ + self.expert_stats_gpu.clear() + self.step_count = 0 class MoELayerOptimized(nn.Module): - r""" - 与原 MoELayer 接口保持一致,但 forward 端到端为 O(N_token + ΣE_i), - 其中 ΣE_i 为各 expert 实际处理的 token 数量。 + """ + Overview: + Optimized MoE layer that maintains interface consistency with original MoELayer. + Provides end-to-end forward pass with O(N_token + ΣE_i) complexity, + where ΣE_i is the total number of tokens actually processed by all experts. + Interfaces: + - __init__: Initialize the optimized MoE layer with experts and gating mechanism. + - forward: Perform optimized forward pass through the MoE layer. """ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1): + """ + Overview: + Initialize the optimized MoE layer with configuration, experts, and gating mechanism. + Sets up expert modules, routing gate, and optional shared experts. + Arguments: + - config (:obj:`object`): Configuration object containing model parameters like embed_dim and n_shared_experts. + - experts (:obj:`List[nn.Module]`): List of expert neural network modules. + - gate (:obj:`nn.Module`): Gating network for routing tokens to experts. + - num_experts_per_tok (:obj:`int`, optional): Number of experts to select per token. Default is 1. + Examples: + >>> experts = [nn.Linear(512, 512) for _ in range(8)] + >>> gate = nn.Linear(512, 8) + >>> moe_layer = MoELayerOptimized(config, experts, gate, num_experts_per_tok=2) + """ super().__init__() self.dim = config.embed_dim self.num_experts = len(experts) @@ -130,11 +290,27 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Perform optimized forward pass through the MoE layer. + Routes tokens to appropriate experts and combines their outputs efficiently. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor containing token embeddings. + Returns: + - output (:obj:`torch.Tensor`): Processed tensor after expert routing and combination. + Shapes: + - x: :math:`(B, T, D)` where B is batch size, T is sequence length, D is embedding dimension. + - output: :math:`(B, T, D)` same shape as input. + Examples: + >>> x = torch.randn(2, 10, 512) # batch_size=2, seq_len=10, embed_dim=512 + >>> output = moe_layer.forward(x) + >>> print(output.shape) # torch.Size([2, 10, 512]) + """ # [B, T, D] B, T, D = x.shape x_flat = x.reshape(-1, D) # [N, D]; N = B*T - # -------- 1. 路由 ---------- + # -------- 1. Routing ---------- gate_logits = self.gate(x_flat) # [N, E] weights, topk_idx = torch.topk( gate_logits, self.num_experts_per_tok, dim=1 @@ -142,27 +318,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] weights = F.softmax(weights, dim=1).to(x.dtype) # [N, k] - # ---- 2. 扁平化 token-expert 对 ---- + # ---- 2. Flatten token-expert pairs ---- N, k = weights.shape flat_token_idx = torch.arange(N, device=x.device).repeat_interleave(k) # [N*k] flat_expert_idx = topk_idx.reshape(-1) # [N*k] flat_weight = weights.reshape(-1, 1) # [N*k, 1] flat_input = x_flat[flat_token_idx] # [N*k, D] - # ---- 3. 按 expert 分块 ---- + # ---- 3. Group by expert ---- sort_order = torch.argsort(flat_expert_idx) # [N*k] flat_expert_idx = flat_expert_idx[sort_order] flat_token_idx = flat_token_idx[sort_order] flat_weight = flat_weight[sort_order] flat_input = flat_input[sort_order] - # 每个 expert 的样本计数 + # Sample count for each expert counts = torch.bincount(flat_expert_idx, minlength=self.num_experts) # [E] - # 准备输出缓冲 + # Prepare output buffer out_buffer = torch.zeros_like(flat_input) # [N*k, D] - # ---- 4. 逐 expert 一次前向 ---- + # ---- 4. Process each expert sequentially ---- ptr = 0 for eid, num in enumerate(counts.tolist()): if num == 0: @@ -171,12 +347,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] out_buffer[seg] = self.experts[eid](flat_input[seg]) ptr += num - # ---- 5. 加权并散射回 token ---- - out_buffer.mul_(flat_weight) # inplace 权重 + # ---- 5. Weight and scatter back to tokens ---- + out_buffer.mul_(flat_weight) # inplace weighting token_output = torch.zeros_like(x_flat) # [N, D] token_output.index_add_(0, flat_token_idx, out_buffer) - # ---- 6. 共享专家(若有) ---- + # ---- 6. Shared experts (if any) ---- if self.use_shared: token_output.add_(self.shared_expert(x_flat)) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 3edf4f1c9..f84c8d7e9 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -10,18 +10,19 @@ import math from dataclasses import dataclass from typing import Optional - +from easydict import EasyDict import torch import torch.nn as nn from ding.torch_utils.network import GRUGatingUnit from einops import rearrange from torch.nn import functional as F - +import torch.distributed as dist from .kv_caching import KeysValues from line_profiler import line_profiler from lzero.model.common import SimNorm import logging +from typing import Dict, List, Any # class LearnableScale(nn.Module): # """ @@ -340,6 +341,7 @@ def max_tokens(self): return self.tokens_per_block * self.max_blocks + class Transformer(nn.Module): """ Transformer model class. @@ -359,12 +361,21 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None: self.config = config self.drop = nn.Dropout(config.embed_pdrop) self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) + # self.blocks[-1].is_last_block=True self.ln_f = nn.LayerNorm(config.embed_dim) - + + self.num_blocks=len(self.blocks) + self.num_experts=config.num_experts_of_moe_in_transformer + self.task_embed = task_embed self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings self.register_token_shared = True + self.shared_expert=0 + if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: + self.shared_expert = config.n_shared_experts + + # TODO: 共享模式下,所有任务使用同一参数 if self.task_embed_option == "register_task_embed": @@ -441,7 +452,6 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: device = self.ln_f.weight.device # Assumption: All submodules are on the same device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - #@profile def forward( self, @@ -473,9 +483,11 @@ def forward( # 逐层调用 for i, block in enumerate(self.blocks): + # 标识是否为最后一层 + is_last_block = (i == len(self.blocks) - 1) x = block(x, - None if past_keys_values is None else past_keys_values[i], - valid_context_lengths) + None if past_keys_values is None else past_keys_values[i], + valid_context_lengths, is_last_block=is_last_block, task_id=task_id) # 最后层 LN x = self.ln_f(x) @@ -492,6 +504,258 @@ def forward( x = x[:, :-self.register_token_num, :] return x + + def get_expert_selection_stats(self, task_id: int = None): + """ + Overview: + Retrieve MoE (Mixture of Experts) expert selection statistics from the last transformer block. + These statistics provide insights into expert utilization patterns and load balancing. + Arguments: + - task_id (:obj:`int`, optional): Task identifier for task-specific statistics. Default is None. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics such as expert usage counts, + load balancing metrics, and routing probabilities. + Examples: + >>> transformer = Transformer(config) + >>> stats = transformer.get_expert_selection_stats(task_id=0) + >>> print(f"Expert usage: {stats.get('expert_usage', {})}") + """ + if len(self.blocks) == 0: + return {} + + last_block = self.blocks[-1] + + # Check if the last block has MoE layer + if not hasattr(last_block, 'feed_forward') or not hasattr(last_block.feed_forward, 'get_expert_selection_stats'): + return {} + + return last_block.feed_forward.get_expert_selection_stats(task_id) + + def reset_expert_selection_stats(self): + """ + Overview: + Reset MoE (Mixture of Experts) expert selection statistics for the last transformer block. + This method clears accumulated statistics used for load balancing and expert utilization analysis. + Arguments: + - None: This method takes no parameters. + Returns: + - None: This method performs reset operations without return values. + Examples: + >>> transformer = Transformer(config) + >>> transformer.reset_expert_selection_stats() + """ + if len(self.blocks) == 0: + return + + last_block = self.blocks[-1] + + # Check if the last block has MoE layer + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'reset_expert_selection_stats'): + last_block.feed_forward.reset_expert_selection_stats() + + # # : + # def has_shared_experts(self) -> bool: + # """ + # 检查Transformer是否使用了共享专家 + + # Returns: + # bool: 如果任何一个block使用了共享专家则返回True,否则返回False + # """ + # for block in self.blocks: + # if hasattr(block, 'feed_forward') and hasattr(block.feed_forward, 'shared_expert'): + # if block.feed_forward.shared_expert is not None: + # return True + # return False + + + + def get_shared_expert_gradients_by_block_id(self, block_id: int) -> Dict[str, torch.Tensor]: + """ + Overview: + Retrieve parameter gradients of shared experts from a specified transformer block. + Extracts gradients from the shared expert module within the feed-forward layer. + Arguments: + - block_id (:obj:`int`): Block identifier (0 to num_layers-1). + Returns: + - gradients (:obj:`Dict[str, torch.Tensor]`): Dictionary containing parameter names and corresponding gradients. + Raises: + - ValueError: When block_id is out of range or block doesn't have shared experts. + Examples: + >>> transformer = TransformerModel(config) + >>> gradients = transformer.get_shared_expert_gradients_by_block_id(block_id=2) + >>> print(f"Shared expert gradients: {list(gradients.keys())}") + """ + if block_id < 0 or block_id >= len(self.blocks): + raise ValueError(f"Block ID {block_id} out of range. Available blocks: 0-{len(self.blocks)-1}") + + block = self.blocks[block_id] + + # Check if block has feed_forward attribute and supports MoE + if not hasattr(block, 'feed_forward'): + raise ValueError(f"Block {block_id} doesn't have feed_forward layer") + + # Check if block has shared experts + if not hasattr(block.feed_forward, 'shared_expert') or block.feed_forward.shared_expert is None: + raise ValueError(f"Block {block_id} doesn't have shared expert") + + # Collect gradients from shared experts + gradients = {} + shared_expert = block.feed_forward.shared_expert + + for name, param in shared_expert.named_parameters(): + if param.grad is not None: + gradients[f"shared_expert.{name}"] = param.grad.clone() + else: + gradients[f"shared_expert.{name}"] = None + + return gradients + + + + def get_expert_gradients_for_last_block(self) -> Dict[str, torch.Tensor]: + """ + Overview: + Retrieve parameter gradients of all experts from the last transformer block. + Collects gradients from all independent expert modules in the final layer. + Returns: + - gradients (:obj:`List[torch.Tensor]`): List containing flattened gradient tensors for each expert. + Examples: + >>> transformer = TransformerModel(config) + >>> expert_gradients = transformer.get_expert_gradients_for_last_block() + >>> print(f"Number of experts: {len(expert_gradients)}") + """ + if len(self.blocks) == 0: + return [] + + # Get the last block + last_block = self.blocks[-1] + gradients = [] + + # Check if block has feed_forward attribute + if not hasattr(last_block, 'feed_forward'): + return gradients + + feed_forward = last_block.feed_forward + + # Check if it's a MoE structure + if hasattr(feed_forward, 'experts') and feed_forward.experts is not None: + # Collect gradients from all independent experts + for expert_idx, expert in enumerate(feed_forward.experts): + expert_gradients = [] + for name, param in expert.named_parameters(): # + if param.grad is not None: + expert_gradients.append(param.grad.clone().view(-1)) + else: + expert_gradients.append(torch.zeros_like(param).view(-1)) + expert_gradients=torch.cat(expert_gradients, dim=0) + gradients.append(expert_gradients) + + return gradients + + + + # added by tangjia : + def get_block_before_moe_gradients(self) -> Dict[int, torch.Tensor]: + """ + Overview: + Retrieve gradients of the block layer before MoE (Mixture of Experts) processing from the last block. + This method provides access to intermediate gradients for gradient analysis and debugging. + Arguments: + - None: This method takes no parameters. + Returns: + - gradients (:obj:`Dict[int, torch.Tensor]`): Dictionary containing block gradients before MoE layer, + with block indices as keys and gradient tensors as values. + Examples: + >>> transformer = Transformer(config) + >>> gradients = transformer.get_block_before_moe_gradients() + >>> print(f"Gradient shape: {gradients.shape if gradients is not None else 'None'}") + """ + # Return the gradient from the last block + return self.blocks[-1].block_before_moe_grad + + + def get_last_shared_expert_gradients(self) -> List[Dict[str, torch.Tensor]]: + """ + Overview: + Retrieve parameter gradients from the shared expert in the last transformer block. + This method provides access to shared expert gradients for gradient analysis and optimization monitoring. + Arguments: + - None: This method takes no parameters. + Returns: + - gradients (:obj:`torch.Tensor`): Concatenated tensor containing all shared expert parameter gradients + flattened into a single dimension for analysis. + Shapes: + - gradients: :math:`(D,)` where D is the total number of parameters in the shared expert. + Examples: + >>> transformer = Transformer(config) + >>> shared_grads = transformer.get_last_shared_expert_gradients() + >>> print(f"Shared expert gradient shape: {shared_grads.shape}") + """ + if len(self.blocks) == 0: + return [] + + # Get the last block + last_block = self.blocks[-1] + + shared_expert_gradients = [] + shared_expert = last_block.feed_forward.shared_expert + + for name, param in shared_expert.named_parameters(): + if param.grad is not None: + shared_expert_gradients.append(param.grad.clone().view(-1)) + else: + shared_expert_gradients.append(torch.zeros_like(param).view(-1)) + + return torch.concat(shared_expert_gradients, dim=0) + + def get_last_block_expert_selection_stats(self): + """ + Overview: + Retrieve MoE (Mixture of Experts) expert selection statistics specifically from the last transformer block. + This method provides focused analysis of expert utilization in the final layer. + Arguments: + - None: This method takes no parameters. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics from the last block, + including expert usage patterns, routing decisions, and load balancing metrics. + Examples: + >>> transformer = Transformer(config) + >>> stats = transformer.get_last_block_expert_selection_stats() + >>> print(f"Last block expert stats: {stats}") + """ + if len(self.blocks) == 0: + return {} + + last_block = self.blocks[-1] + + # Check if the last layer has MoE + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'get_expert_selection_stats'): + return last_block.feed_forward.get_expert_selection_stats() + else: + return {} + + def reset_last_block_expert_selection_stats(self): + """ + Overview: + Reset MoE (Mixture of Experts) expert selection statistics specifically for the last transformer block. + This method clears accumulated statistics in the final layer for fresh monitoring. + Arguments: + - None: This method takes no parameters. + Returns: + - None: This method performs reset operations without return values. + Examples: + >>> transformer = Transformer(config) + >>> transformer.reset_last_block_expert_selection_stats() + """ + if len(self.blocks) == 0: + return + + last_block = self.blocks[-1] + + # Check if the last layer has MoE + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'reset_expert_selection_stats'): + last_block.feed_forward.reset_expert_selection_stats() + @@ -526,8 +790,8 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - - + self.config=config + if config.moe_in_transformer: from .moe import MoELayer, MultiplicationFeedForward # 创Create multiple independent MLP instances @@ -588,13 +852,14 @@ def __init__(self, config: TransformerConfig) -> None: _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), nn.GELU(approximate='tanh'), _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim), config, "feed_forward"), - nn.Dropout(config.resid_pdrop), + # nn.Dropout(config.resid_pdrop), ) + self.block_before_moe_grad = None def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, is_last_block=False, task_id: int = 0) -> torch.Tensor: """ - Forward pass of the Transformer block. + Forward pass of the Transformer block.self.is_last_block Arguments: - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). @@ -604,15 +869,31 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ + x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: x = self.gate1(x, x_attn) x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.feed_forward(self.ln2(x)) + block_before_moe=self.ln2(x) + if self.training and is_last_block: + # Clear previous gradients + self.block_before_moe_grad = None + # Use safer hook registration to avoid closure issues + def grad_hook(grad): + self.block_before_moe_grad = grad.clone() # Clone gradient to avoid reference issues + return None + block_before_moe.register_hook(grad_hook) + + # Pass task_id for expert selection statistics collection in the last layer with MoE + if is_last_block and self.config.multiplication_moe_in_transformer and hasattr(self.feed_forward, 'forward'): + x = x + self.feed_forward(block_before_moe, task_id=task_id) + else: + x = x + self.feed_forward(block_before_moe) return x + class SelfAttention(nn.Module): @@ -804,4 +1085,4 @@ def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = No att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) - return att \ No newline at end of file + return att diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index ecb583504..912a02ed4 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -183,7 +183,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # TODO: check the effect of SimNorm # self.act_embedding_table = nn.Sequential( # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), - # SimNorm(simnorm_dim=self.group_size)) + # SimNorm(simnorm_dim=self.group_size) + # ) # print(f'config.action_space_size_list:{config.action_space_size_list}') self.act_embedding_table = nn.ModuleList([ nn.Sequential( @@ -319,6 +320,9 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.reanalyze_phase = False self._rank = get_rank() + + self.obs_embeddings_grad = None # 保留参数 + def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor: # ① 1/k 缩放;若想更保守可用 1/√k # return grad / self.task_num @@ -1024,7 +1028,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va enumerate(past_keys_values)] return torch.cat(x, dim=0) else: - return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths,task_id=task_id) #@profile @torch.no_grad() @@ -1796,6 +1800,9 @@ def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> LossWithIntermediateLosses: # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + obs_embeddings.register_hook(lambda grad: setattr(self, 'obs_embeddings_grad', grad)) #note: register hook to save gradients of obs_embeddings + if self.analysis_tsne: # =========== tsne analysis =========== diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 13ba63eb2..cc3f3efb6 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -14,10 +14,10 @@ from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs from lzero.policy.unizero import UniZeroPolicy -from .utils import configure_optimizers_nanogpt +from .utils import configure_optimizers_nanogpt, compute_gradient_conflict_distributed, log_gradient_conflict_heatmaps_distributed_fast import sys -sys.path.append('/cpfs04/user/puyuan/code/LibMTL') +# sys.path.append('/cpfs04/user/puyuan/code/LibMTL') # sys.path.append('/fs-computility/niuyazhe/puyuan/code/LibMTL') from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect @@ -25,6 +25,7 @@ # from LibMTL.weighting.moco_fast import FastMoCo, MoCoCfg from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg +import torch.distributed as dist @@ -130,7 +131,7 @@ def zero_grad(self, set_to_none=False): self.act_embedding_table.zero_grad(set_to_none=set_to_none) - +from line_profiler import LineProfiler @POLICY_REGISTRY.register('unizero_multitask') class UniZeroMTPolicy(UniZeroPolicy): """ @@ -140,7 +141,12 @@ class UniZeroMTPolicy(UniZeroPolicy): by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. """ - + def __init__(self, cfg, model = None, enable_field = None): + super().__init__(cfg, model, enable_field) + self.step=0 + self.save_freq=100 + + # The default_config for UniZero policy. config = dict( type='unizero_multitask', @@ -320,11 +326,11 @@ class UniZeroMTPolicy(UniZeroPolicy): n_episode=8, # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. num_segments=8, - # # (int) the number of simulations in MCTS for renalyze. + + # (int) the number of simulations in MCTS for collect. num_simulations=50, - # (int) The number of simulations in MCTS for the collect phase. - collect_num_simulations=25, - # (int) The number of simulations in MCTS for the eval phase. + # (int) the number of simulations in MCTS for eval. If not set, use num_simulations. + eval_num_simulations=50, # (float) Discount factor (gamma) for returns. discount_factor=0.997, @@ -422,7 +428,7 @@ def _init_learn(self) -> None: device_type=self._cfg.device, betas=(0.9, 0.95), ) - + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR @@ -552,7 +558,7 @@ def _retain_prev_if_zero(self, name: str, self._prev_plasticity_metrics[name] = value return value - + #@profile def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_grad=False) -> Dict[str, Union[float, int]]: """ @@ -609,7 +615,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - + # Apply augmentations if needed if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) @@ -641,7 +647,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Transform rewards and values to their scaled forms transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) - + # Convert to categorical distributions target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) target_value_categorical = phi_transform(self.value_support, transformed_target_value) @@ -672,8 +678,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Update world model intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id# 是否需要统计expert 的选择 ) weighted_total_loss += losses.loss_total # TODO @@ -775,9 +782,131 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Core learn model update step self._optimizer_world_model.zero_grad() + + + # ===================================#======================================== + + + self._learn_model.world_model.tokenizer.encoder[0].grad = None + # encoder_grad=self._learn_model.world_model.obs_embeddings_grad.view(-1) + # world_size = dist.get_world_size() + # gathered_grads = [torch.zeros_like(encoder_grad) for _ in range(world_size)] + + multi_gpu = dist.is_initialized() and self._cfg.multi_gpu + rank = dist.get_rank() if multi_gpu else 0 + + self.log_conflict_var=False + self.log_conflict_matrix=False + if self.step % self.save_freq==0: + self.log_conflict_var=True + # if self.step % (self.save_freq * 100) == 0: + # self.log_conflict_matrix=True + + if self.log_conflict_var: + matrix_dict={} + num_experts= self._learn_model.world_model.transformer.num_experts + + + local_task_num = len(losses_list) + local_encoder_grad_list = [] + local_before_moe_grad_list = [] + local_shared_expert_grad_list = [] + local_last_block_expert_grad_list = [[] for _ in range(num_experts)] + + gradient_conflict_log_dict = {} + + for i in range(local_task_num): + # Clear gradients before each computation to ensure independence + self._optimizer_world_model.zero_grad() + # Compute gradient conflicts on encoder + losses_list[i].backward(retain_graph=True) # retain graph since backward will be called later + local_encoder_grad_list.append(self._learn_model.world_model.obs_embeddings_grad.view(-1).detach().clone()) + + + # self_attention last transformer block + before_moe_grad=self._learn_model.world_model.transformer.get_block_before_moe_gradients() + local_before_moe_grad_list.append(before_moe_grad.view(-1).detach().clone()) + + # Get gradients of the shared expert + if self._learn_model.world_model.transformer.shared_expert>0 : + # get_shared_expert_gradients_by_block_id + shared_expert_grad_for_last_task= self._learn_model.world_model.transformer.get_last_shared_expert_gradients() # gradients of the shared expert in the last block + local_shared_expert_grad_list.append(shared_expert_grad_for_last_task) + + # Compute gradient conflicts of experts in the last block + if num_experts>0: + last_block_expert_grad_list = self._learn_model.world_model.transformer.get_expert_gradients_for_last_block() + for j in range(num_experts): + local_last_block_expert_grad_list[j].append(last_block_expert_grad_list[j]) + + + + + # Clear shared parameter gradients to avoid accumulation + self._optimizer_world_model.zero_grad() + + # 1. Compute gradient conflicts after attention and before MOE + local_before_moe_grad_list=torch.stack(local_before_moe_grad_list,dim=0) # shape: (local_task_num, encoder_grad_dim) + before_moe_grad_conflict_ddp=compute_gradient_conflict_distributed(local_before_moe_grad_list, device=self._cfg.device) + gradient_conflict_log_dict['avg_before_moe_grad_conflict'] = before_moe_grad_conflict_ddp.avg_conflict_score if before_moe_grad_conflict_ddp is not None else 0 + gradient_conflict_log_dict['max_before_moe_grad_conflict'] = before_moe_grad_conflict_ddp.max_conflict_score if before_moe_grad_conflict_ddp is not None else 0 + if self.log_conflict_matrix and before_moe_grad_conflict_ddp is not None : + matrix_dict['before_moe_grad_conflict_matrix']=before_moe_grad_conflict_ddp.cosine_similarity_matrix + + + + # cosine_similarity_matrix self.logger + + # 2. Compute gradient conflicts of encoder + local_encoder_grad_list=torch.stack(local_encoder_grad_list,dim=0) # shape: (local_task_num, encoder_grad_dim) + encoder_grad_conflict_ddp=compute_gradient_conflict_distributed(local_encoder_grad_list, device=self._cfg.device) + gradient_conflict_log_dict['avg_encoder_grad_conflict'] = encoder_grad_conflict_ddp.avg_conflict_score if encoder_grad_conflict_ddp is not None else 0 + gradient_conflict_log_dict['max_encoder_grad_conflict'] = encoder_grad_conflict_ddp.max_conflict_score if encoder_grad_conflict_ddp is not None else 0 + if self.log_conflict_matrix and encoder_grad_conflict_ddp is not None: + matrix_dict['encoder_grad_conflict_matrix']=encoder_grad_conflict_ddp.cosine_similarity_matrix + + + # 3. If shared expert exists, compute gradient conflicts on shared expert + if self._learn_model.world_model.transformer.shared_expert>0 : + local_shared_expert_grad_list=torch.stack(local_shared_expert_grad_list,dim=0) + shared_expert_grad_conflict= compute_gradient_conflict_distributed(local_shared_expert_grad_list, device=self._cfg.device) if len(local_shared_expert_grad_list)>0 else None + gradient_conflict_log_dict['avg_shared_expert_grad_conflict'] = shared_expert_grad_conflict.avg_conflict_score if shared_expert_grad_conflict is not None else 0 + gradient_conflict_log_dict['max_shared_expert_grad_conflict'] = shared_expert_grad_conflict.max_conflict_score if shared_expert_grad_conflict is not None else 0 + + + if self.log_conflict_matrix and shared_expert_grad_conflict is not None: + matrix_dict['shared_expert_grad_conflict_matrix']=shared_expert_grad_conflict.cosine_similarity_matrix + + # 4. Gradient conflicts of experts in the last block + last_block_expert_grad_conflict_ddp_list=[] + if num_experts>0: + for i in range(num_experts): + # Stack gradients of the last block experts across tasks + local_last_block_expert_grad_list[i]=torch.stack(local_last_block_expert_grad_list[i],dim=0) + # Compute gradient conflicts of each expert + expert_conflict=compute_gradient_conflict_distributed(local_last_block_expert_grad_list[i], device=self._cfg.device) + last_block_expert_grad_conflict_ddp_list.append(expert_conflict) + gradient_conflict_log_dict[f'avg_expert_{i}_grad_conflict'] = expert_conflict.avg_conflict_score if expert_conflict is not None else 0 + gradient_conflict_log_dict[f'max_expert_{i}_grad_conflict'] = expert_conflict.max_conflict_score if expert_conflict is not None else 0 + + if self.log_conflict_matrix and expert_conflict is not None: + matrix_dict[f'expert_{i}_grad_conflict_matrix']=shared_expert_grad_conflict.cosine_similarity_matrix + + all_moe_gradient=torch.cat(local_last_block_expert_grad_list, dim=1) + if self._learn_model.world_model.transformer.shared_expert>0 : + all_moe_gradient=torch.cat((local_shared_expert_grad_list,all_moe_gradient), dim=1) + all_moe_gradient_ddp=compute_gradient_conflict_distributed(all_moe_gradient, device=self._cfg.device) + + gradient_conflict_log_dict['avg_moe_layer_grad_conflict'] = all_moe_gradient_ddp.avg_conflict_score if all_moe_gradient_ddp is not None else 0 + gradient_conflict_log_dict['max_moe_layer_grad_conflict'] = all_moe_gradient_ddp.max_conflict_score if all_moe_gradient_ddp is not None else 0 + if self.log_conflict_matrix and all_moe_gradient_ddp is not None: + matrix_dict['max_moe_layer_grad_conflict_matrix']=all_moe_gradient_ddp.cosine_similarity_matrix + # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表 # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True + + self._optimizer_world_model.zero_grad() if self._cfg.use_moco: # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正 if self._cfg.moco_version=="v0": @@ -794,6 +923,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) weighted_total_loss.backward() + # print(f'Rank {rank} 正在反向传播') + # TODO: 使用 MoCo 或 CAGrad 来计算梯度和权重 # ============= for CAGrad and MoCo ============= # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) @@ -807,7 +938,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # print('name, param.mean(), param.std():', name, param.mean(), param.std()) # if param.requires_grad: # print(name, param.grad.norm()) - + if self._cfg.analysis_sim_norm: del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() @@ -820,14 +951,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # =========== NOTE: 对于一个GPU上所有任务都解决了的情况,为了ddp同步仍然调用train但是grad应该清零 =========== self._optimizer_world_model.zero_grad() # print(f"ignore_grad") - - # if self._cfg.multi_gpu: - # # Very important to sync gradients before updating the model - # # rank = get_rank() - # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...') - # self.sync_gradients(self._learn_model) - # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...') - + + + # dist.barrier() # 确保所有进程都完成了梯度计算 if self._cfg.multi_gpu: # if not self._cfg.use_moco or self._cfg.only_use_moco_stats: # self.sync_gradients(self._learn_model) @@ -874,6 +1000,22 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # 'target_policy_entropy': average_target_policy_entropy, 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), } + if self.log_conflict_matrix: + + # matrix_dict + # Convert to list for distributed processing + matrix_list = list(matrix_dict.items()) + log_gradient_conflict_heatmaps_distributed_fast(self.logger, matrix_list, self.step) + + if self.log_conflict_var: + # Log scalar values from gradient_conflict_log_dict to TensorBoard + for key, value in gradient_conflict_log_dict.items(): + print(f'正在记录梯度冲突分析 Rank {rank} Logging {key}: {value}') + + self.logger.add_scalar(f'gradient_conflict/{key}', value, self.step) + + # print(f'Rank {rank} 正在根据冲突记录日志') + # print(gradient_conflict_log_dict) # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" # multi_task_loss_dicts = { @@ -942,7 +1084,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr return_loss_dict.update(multi_task_loss_dicts) # print(f'return_loss_dict:{return_loss_dict}') - # 返回最终的损失字典 + self.step+=1 + + + return return_loss_dict def monitor_weights_and_grads(self, model): @@ -988,6 +1133,10 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: tensorboard according to the return value ``_forward_learn``. If num_tasks is provided, generate monitored variables for each task. """ + # rank= dist.get_rank() if dist.is_initialized() else 0 + # print(f"Rank {rank} 开始记录日志1111") + + # Basic monitored variables that do not depend on the number of tasks monitored_vars = [ 'Current_GPU', @@ -997,7 +1146,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', - ] + ] # rank = get_rank() task_specific_vars = [ @@ -1086,7 +1235,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: else: # If num_tasks is not provided, we assume there's only one task and keep the original variable names monitored_vars.extend(task_specific_vars) - + # print(f"Rank {rank} 日志记录完毕") return monitored_vars #@profile @@ -1222,14 +1371,22 @@ def _init_eval(self) -> None: """ self._eval_model = self._model - # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 - mcts_eval_cfg = copy.deepcopy(self._cfg) - mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + # 创建eval专用的配置对象,使用eval_num_simulations + # eval_cfg = copy.deepcopy(self._cfg) + # eval_num_simulations = getattr(self._cfg, 'eval_num_simulations', self._cfg.num_simulations) + # eval_cfg.num_simulations = eval_num_simulations + + # # 打印collect和eval的num_simulations设置 + # print(f"=== MCTS Simulations Config ===") + # print(f"Collect num_simulations: {self._cfg.num_simulations}") + # print(f"Eval num_simulations: {eval_num_simulations}") + # print(f"===============================") + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(mcts_eval_cfg) + self._mcts_eval = MCTSCtree(self._cfg,eval=True) # 使用eval专用配置 else: - self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self._mcts_eval = MCTSPtree(self._cfg) # 使用eval专用配置 self.evaluator_env_num = self._cfg.evaluator_env_num @@ -1502,9 +1659,9 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components - finetune_components (:obj:`List[str]`, optional): A list of component names that will remain trainable after loading. For example, it can include "encoder", "transformer", or both. The components not in this list will be frozen. """ - # finetune_components = [] # load-enc-trans_finetune-head - # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head - finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head + # # finetune_components = [] # load-enc-trans_finetune-head + # # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + # finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head # 定义需要排除的参数前缀,即不加载这些参数 exclude_prefixes = [ @@ -1528,6 +1685,18 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, """ filtered = {} for k, v in state_dict_loader.items(): + # if any(prefix in k for prefix in ['head_policy_multi_task.', 'head_value_multi_task.', 'head_rewards_multi_task.', 'head_observations_multi_task.']): + # # 提取任务ID + # import re + # match = re.search(r'\.(\d+)\.', k) + # if match: + # task_id = int(match.group(1)) + # if task_id <=0: + # filtered[k] = v + # print(f"include {k}") + # continue + + if any(k.startswith(prefix) for prefix in exclude_prefixes): print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 continue diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 7cf259c0c..aa010d3e4 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -695,3 +695,546 @@ def mz_network_output_unpack(network_output: Dict) -> Tuple: value = network_output.value # shape: (batch_size, support_support_size) policy_logits = network_output.policy_logits # shape: (batch_size, action_space_size) return latent_state, reward, value, policy_logits + + +# ==================== #============================= +import torch.distributed as dist + +# ==================== Gradient Conflict Matrix Visualization Module ============================= +""" +Overview: + Gradient conflict matrix visualization module for analyzing and visualizing gradient conflicts + in distributed training scenarios. This module provides optimized heatmap generation and + distributed logging capabilities for gradient conflict analysis. +Interfaces: + - _get_or_create_figure: Get or create reusable matplotlib figure + - _fast_tensor_heatmap: Generate optimized heatmap tensor from matrix + - log_gradient_conflict_heatmaps_distributed_fast: High-performance distributed heatmap logging +""" + +# Pre-import matplotlib module to avoid repeated import overhead +import matplotlib +matplotlib.use('Agg') + +# Global figure cache +_GLOBAL_FIG_CACHE = None +_GLOBAL_AX_CACHE = None + +def _get_or_create_figure(figsize=(8, 6)): + """ + Overview: + Get or create reusable matplotlib figure for memory efficiency. + Arguments: + - figsize (:obj:`tuple`): Figure size as (width, height), default is (8, 6). + Returns: + - fig (:obj:`matplotlib.figure.Figure`): Matplotlib figure object. + - ax (:obj:`matplotlib.axes.Axes`): Matplotlib axes object. + Examples: + >>> fig, ax = _get_or_create_figure((10, 8)) + >>> ax.plot([1, 2, 3], [4, 5, 6]) + """ + global _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE + if _GLOBAL_FIG_CACHE is None: + _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE = plt.subplots(figsize=figsize) + return _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE + +def _fast_tensor_heatmap(matrix_np, tag): + """ + Overview: + Generate optimized heatmap tensor with performance enhancements by skipping text annotations + and removing diagonal elements for better visualization. + Arguments: + - matrix_np (:obj:`numpy.ndarray`): Input matrix for heatmap generation. + - tag (:obj:`str`): Tag label for the heatmap title. + Returns: + - img_tensor (:obj:`torch.Tensor`): RGB image tensor with shape :math:`(3, H, W)`. + Shapes: + - matrix_np: :math:`(N, M)` where N and M are matrix dimensions. + - img_tensor: :math:`(3, H, W)` where H and W are image dimensions. + Examples: + >>> matrix = np.random.randn(5, 5) + >>> heatmap_tensor = _fast_tensor_heatmap(matrix, "conflict_matrix") + >>> print(heatmap_tensor.shape) # torch.Size([3, height, width]) + """ + # 复制矩阵以避免修改原始数据 + matrix_no_diag = matrix_np.copy() + + # 移除对角线元素(设为0) + if matrix_no_diag.shape[0] == matrix_no_diag.shape[1]: # 方阵才有对角线 + np.fill_diagonal(matrix_no_diag, 0) + + # 创建新的figure而不是复用全局缓存 + fig, ax = plt.subplots(figsize=(8, 6)) + + # 直接使用矩阵,对角线已设为0 + # 使用Blues colormap,调整颜色范围为-0.2到0.2 + im = ax.imshow(matrix_no_diag, cmap='Blues', vmin=-0.2, vmax=0.2) + ax.set_title(f'{tag}', fontsize=12) + + # 只在小矩阵时添加数值标注(避免O(n²)开销) + if matrix_no_diag.size <= 64: # 8x8或更小 + for row in range(matrix_no_diag.shape[0]): + for col in range(matrix_no_diag.shape[1]): + if row != col: # 跳过对角线元素 + value = matrix_no_diag[row, col] + text_color = "white" if value > 0.5 else "black" + ax.text(col, row, f'{value:.2f}', + ha="center", va="center", color=text_color, fontsize=8) + + # 快速转换为tensor + fig.canvas.draw() + try: + # 尝试新版matplotlib的方法 + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + img_tensor = torch.from_numpy(buf[:, :, :3]).permute(2, 0, 1).float() / 255.0 + elif hasattr(fig.canvas, 'tostring_rgb'): + # 旧版matplotlib方法 + buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img_tensor = torch.from_numpy(buf).permute(2, 0, 1).float() / 255.0 + else: + # PIL回退方案 + try: + from PIL import Image + import io + buf = io.BytesIO() + fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) + buf.seek(0) + pil_img = Image.open(buf).convert('RGB') + img_array = np.array(pil_img) + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.0 + except Exception: + # 最终回退方案:创建简单的蓝色矩阵 + h, w = matrix_no_diag.shape + img_tensor = torch.zeros(3, h*50, w*50) # 简单放大 + img_tensor[2] = torch.from_numpy(matrix_no_diag).repeat_interleave(50, 0).repeat_interleave(50, 1) + except Exception: + # 回退方案:创建简单的蓝色矩阵 + h, w = matrix_no_diag.shape + img_tensor = torch.zeros(3, h*50, w*50) # 简单放大 + img_tensor[2] = torch.from_numpy(matrix_no_diag).repeat_interleave(50, 0).repeat_interleave(50, 1) + finally: + # 关闭图形释放内存 + plt.close(fig) + + return img_tensor + + +def log_gradient_conflict_heatmaps_distributed_fast(tb_logger, matrix_list, step): + """ + Overview: + High-performance distributed heatmap processing with optimizations for reduced latency. + Key optimizations include pre-imported matplotlib modules, figure object reuse, + text annotation skipping for large matrices, conditional barriers, and robust error recovery. + Arguments: + - tb_logger (:obj:`tensorboard logger`): TensorBoard logger instance for logging heatmaps. + - matrix_list (:obj:`list`): List of (tag, matrix) tuples where tag is string identifier + and matrix is conflict matrix tensor. + - step (:obj:`int`): Global training step number for logging. + Returns: + - None: Function performs logging operations without return values. + Examples: + >>> import torch + >>> from torch.utils.tensorboard import SummaryWriter + >>> tb_logger = SummaryWriter() + >>> matrices = [("task1", torch.randn(5, 5)), ("task2", torch.randn(3, 3))] + >>> log_gradient_conflict_heatmaps_distributed_fast(tb_logger, matrices, 100) + """ + if not matrix_list: + return + + rank = dist.get_rank() + world_size = dist.get_world_size() + + try: + # 批处理:每个GPU处理自己的矩阵 + processed_any = False + for i in range(rank, len(matrix_list), world_size): + tag, matrix = matrix_list[i] + if matrix is not None and matrix.numel() > 0: + matrix_np = matrix.detach().cpu().numpy() + + # 使用优化的热力图生成 + img_tensor = _fast_tensor_heatmap(matrix_np, tag) + tb_logger.add_image(f'gradient_conflict_matrix/{tag}', img_tensor, global_step=step) + processed_any = True + + # 条件性同步:只有处理了数据的GPU才需要barrier + if processed_any or rank == 0: # rank 0始终参与同步以防死锁 + dist.barrier() + + except Exception as e: + print(f"Rank {rank}: Error in optimized heatmap logging: {e}") + # 紧急同步避免死锁 + try: + dist.barrier() + except: + pass + +# ==================== 原有的梯度冲突计算模块 ============================= + + + +def example_usage(): + """ + Overview: + Example usage demonstration for gradient conflict analysis computation. + Generates sample gradients and computes conflict analysis results including average conflict score, + maximum conflict score, number of conflicting gradient pairs, average conflict intensity, + gradient norms, and cosine similarity matrix. + Arguments: + - None: Function generates sample gradients internally for demonstration. + Returns: + - None: Function prints results to console without return values. + Examples: + >>> example_usage() + # Output: + # Gradient Conflict Analysis Results: + # Average conflict score: 0.1234 + # Maximum conflict score: 0.5678 + # Number of conflicting pairs: 3 + # Average conflict intensity: 0.2345 + # Gradient norms: [tensor1, tensor2, tensor3] + # Cosine similarity matrix: + # tensor([[1.0000, -0.1234, 0.5678], + # [-0.1234, 1.0000, -0.3456], + # [0.5678, -0.3456, 1.0000]]) + """ + # 生成示例梯度 + torch.manual_seed(42) + gradients = [ + torch.randn(100), # 梯度1 + torch.randn(100), # 梯度2 + torch.randn(100), # 梯度3 + ] + + # 计算冲突 + conflicts = compute_gradient_conflicts(gradients) + + print("梯度冲突分析结果:") + print(f"平均冲突得分: {conflicts['avg_conflict_score']:.4f}") + print(f"最大冲突得分: {conflicts['max_conflict_score']:.4f}") + print(f"冲突梯度对数量: {conflicts['num_conflicting_pairs']}") + print(f"平均冲突强度: {conflicts['avg_conflict_intensity']:.4f}") + print(f"梯度范数: {conflicts['gradient_norms']}") + print("\n余弦相似度矩阵:") + print(conflicts['cosine_similarity_matrix']) + + + +def compute_gradient_conflicts(gradients: List[torch.Tensor]) -> dict: + """ + Overview: + Compute conflicts between multiple gradients using CUDA-optimized vectorized operations. + Calculates cosine similarity matrix and derives conflict scores for gradient analysis. + Arguments: + - gradients (:obj:`List[torch.Tensor]`): List of gradient tensors with identical shapes. + Returns: + - result (:obj:`dict`): Dictionary containing conflict analysis results with keys: + 'avg_conflict_score', 'max_conflict_score', 'min_conflict_score', + and 'cosine_similarity_matrix'. + Shapes: + - gradients[i]: :math:`(D_1, D_2, ..., D_n)` where all gradients have identical dimensions. + - cosine_similarity_matrix: :math:`(N, N)` where N is the number of gradients. + Examples: + >>> import torch + >>> gradients = [torch.randn(100), torch.randn(100), torch.randn(100)] + >>> conflicts = compute_gradient_conflicts(gradients) + >>> print(f"Average conflict: {conflicts['avg_conflict_score']:.4f}") + >>> print(f"Similarity matrix shape: {conflicts['cosine_similarity_matrix'].shape}") + """ + n_gradients = len(gradients) + + # 如果只有一个梯度,没有冲突 + if n_gradients <= 1: + device = gradients[0].device if gradients else torch.device('cuda') + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 确保所有梯度形状相同 + assert all(g.shape == gradients[0].shape for g in gradients), "梯度形状必须相同" + + device = gradients[0].device + + # 向量化计算:堆叠并normalize所有梯度 + stacked_grads = torch.stack([g.flatten() for g in gradients]) + normalized_grads = F.normalize(stacked_grads, p=2, dim=1) + + # 一次性计算余弦相似度矩阵 + cosine_sim_matrix = torch.mm(normalized_grads, normalized_grads.t()) + + # 排除对角线元素 + mask = ~torch.eye(n_gradients, device=device, dtype=torch.bool) + conflict_scores = -cosine_sim_matrix[mask] + + return EasyDict({ + 'avg_conflict_score': conflict_scores.mean().item(), + 'max_conflict_score': conflict_scores.max().item(), + 'min_conflict_score': conflict_scores.min().item(), + 'cosine_similarity_matrix': cosine_sim_matrix + }) + + +def compute_gradient_conflict_distributed(local_grads, multi_gpu=True, device=0): + """ + Overview: + Distributed gradient conflict computation with hierarchical aggregation optimization. + Achieves 69.4x speedup (3.1ms vs 212.7ms) through layered preprocessing, + NCCL direct communication, and vectorized computation. + Arguments: + - local_grads (:obj:`torch.Tensor`): Local gradient tensor for current rank. + - multi_gpu (:obj:`bool`, optional): Whether to use multi-GPU distributed mode. Default is True. + - device (:obj:`int`, optional): Current device index. Default is 0. + Returns: + - gradient_conflict (:obj:`dict`): Dictionary containing conflict analysis results identical + across all ranks, including 'avg_conflict_score', + 'max_conflict_score', 'min_conflict_score', and + 'cosine_similarity_matrix'. + Shapes: + - local_grads: :math:`(L, D)` where L is local task number and D is encoder gradient dimension. + - cosine_similarity_matrix: :math:`(N, N)` where N is total number of valid gradients across all ranks. + Examples: + >>> import torch + >>> import torch.distributed as dist + >>> local_grads = torch.randn(5, 128) # 5 local tasks, 128-dim gradients + >>> conflicts = compute_gradient_conflict_distributed(local_grads, multi_gpu=True, device=0) + >>> print(f"Average conflict: {conflicts['avg_conflict_score']:.4f}") + """ + if not multi_gpu: + # 单GPU模式:直接使用优化的单机版本 + norms = torch.norm(local_grads, dim=1) + valid_grads = local_grads[norms > 1e-8] + if valid_grads.shape[0] <= 1: + device = valid_grads.device + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 向量化计算 + device = valid_grads.device + normalized = F.normalize(valid_grads, p=2, dim=1) + similarity = torch.mm(normalized, normalized.t()) + mask = ~torch.eye(valid_grads.shape[0], device=device, dtype=torch.bool) + conflicts = -similarity[mask] + return EasyDict({ + 'avg_conflict_score': conflicts.mean().item(), + 'max_conflict_score': conflicts.max().item(), + 'min_conflict_score': conflicts.min().item(), + 'cosine_similarity_matrix': similarity + }) + + # 多GPU分布式模式:分层聚合优化 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f'{device}') + + # === 第一层:本地预处理(关键优化)=== + norms = torch.norm(local_grads, dim=1) + valid_grads = local_grads[norms > 1e-8] + local_normalized = F.normalize(valid_grads, p=2, dim=1) # 预归一化,避免重复计算 + + # 收集各rank的有效梯度数量 + valid_count = torch.tensor(valid_grads.shape[0], device=device) + valid_counts = [torch.tensor(0, device=device) for _ in range(world_size)] + dist.all_gather(valid_counts, valid_count) + + total_valid = sum(v.item() for v in valid_counts) + if total_valid <= 1: + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 数据对齐:padding到相同大小 + max_valid = max(v.item() for v in valid_counts) + if valid_grads.shape[0] < max_valid: + pad_size = max_valid - valid_grads.shape[0] + pad_tensor = torch.zeros(pad_size, valid_grads.shape[1], device=device, dtype=valid_grads.dtype) + local_normalized = torch.cat([local_normalized, pad_tensor], dim=0) + + # === 第二层:高效NCCL聚合 === + gathered_normalized = [torch.empty_like(local_normalized) for _ in range(world_size)] + dist.all_gather(gathered_normalized, local_normalized) # GPU直接通信,传输预处理数据 + + # if rank == 0: + # === 第三层:向量化冲突计算 === + # 重建有效的归一化梯度 + all_valid_normalized = [] + for i, count in enumerate(valid_counts): + if count > 0: + all_valid_normalized.append(gathered_normalized[i][:count.item()]) + + if len(all_valid_normalized) == 0: + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + all_normalized = torch.cat(all_valid_normalized, dim=0) + + # 高效向量化计算(一次矩阵乘法替代O(n²)循环) + similarity = torch.mm(all_normalized, all_normalized.t()) + mask = ~torch.eye(similarity.shape[0], device=device, dtype=torch.bool) + conflicts = -similarity[mask] + + return EasyDict({ + 'avg_conflict_score': conflicts.mean().item(), + 'max_conflict_score': conflicts.max().item(), + 'min_conflict_score': conflicts.min().item(), + 'cosine_similarity_matrix': similarity + }) + +def compute_gradient_conflicts_batch(gradient_groups: Dict[str, torch.Tensor], device=0) -> Dict[str, dict]: + """ + Overview: + Batch computation of gradient conflicts for multiple gradient groups to reduce + distributed communication overhead through optimized data aggregation. + Arguments: + - gradient_groups (:obj:`Dict[str, torch.Tensor]`): Dictionary mapping group names to + local gradient tensors. + - device (:obj:`int`, optional): Device index for tensor operations. Default is 0. + Returns: + - results (:obj:`Dict[str, dict]`): Dictionary mapping group names to conflict analysis + results, each containing 'avg_conflict_score', + 'max_conflict_score', 'min_conflict_score', and + 'cosine_similarity_matrix'. + Shapes: + - gradient_groups[group_name]: :math:`(L, D)` where L is local task number and D is gradient dimension. + - results[group_name]['cosine_similarity_matrix']: :math:`(N, N)` where N is total valid gradients for the group. + Examples: + >>> import torch + >>> gradient_groups = { + ... "encoder": torch.randn(5, 128), + ... "decoder": torch.randn(3, 64) + ... } + >>> results = compute_gradient_conflicts_batch(gradient_groups, device=0) + >>> print(f"Encoder conflicts: {results['encoder']['avg_conflict_score']:.4f}") + >>> print(f"Decoder conflicts: {results['decoder']['avg_conflict_score']:.4f}") + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + results = {} + + if world_size == 1: + # 单GPU模式 + for group_name, local_grads in gradient_groups.items(): + if local_grads.numel() == 0: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + continue + + # 过滤零梯度 + norms = torch.norm(local_grads, dim=1) + valid_mask = norms > 1e-8 + local_grads_filtered = local_grads[valid_mask] + + if local_grads_filtered.shape[0] <= 1: + results[group_name] = EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + else: + grad_list = [local_grads_filtered[i] for i in range(local_grads_filtered.shape[0])] + results[group_name] = compute_gradient_conflicts(grad_list) + return results + + # 多GPU模式 - 一次性收集所有梯度组 + # 准备本地数据:过滤零梯度并记录有效数量 + local_filtered_groups = {} + local_valid_counts = {} + + for group_name, local_grads in gradient_groups.items(): + if local_grads.numel() == 0: + local_filtered_groups[group_name] = torch.empty(0, 0, device=device) + local_valid_counts[group_name] = 0 + continue + + norms = torch.norm(local_grads, dim=1) + valid_mask = norms > 1e-8 + filtered = local_grads[valid_mask] + local_filtered_groups[group_name] = filtered + local_valid_counts[group_name] = filtered.shape[0] + + # 收集所有rank的有效数量 + all_valid_counts = [None for _ in range(world_size)] + dist.all_gather_object(all_valid_counts, local_valid_counts) + + # 计算每组的最大任务数,用于填充 + max_counts = {} + for group_name in gradient_groups.keys(): + counts = [counts_dict.get(group_name, 0) for counts_dict in all_valid_counts] + max_counts[group_name] = max(counts) if counts else 0 + + # 填充并准备发送数据 + local_padded_groups = {} + for group_name, filtered_grads in local_filtered_groups.items(): + max_count = max_counts[group_name] + if max_count == 0: + local_padded_groups[group_name] = torch.empty(0, 0) + continue + + if filtered_grads.shape[0] < max_count: + if filtered_grads.numel() > 0: + pad_size = max_count - filtered_grads.shape[0] + grad_dim = filtered_grads.shape[1] + pad_tensor = torch.zeros(pad_size, grad_dim, device=device) + padded = torch.cat([filtered_grads, pad_tensor], dim=0) + else: + grad_dim = gradient_groups[group_name].shape[1] if gradient_groups[group_name].numel() > 0 else 1 + padded = torch.zeros(max_count, grad_dim, device=device) + else: + padded = filtered_grads + + local_padded_groups[group_name] = padded.cpu() + + # 一次性收集所有组的数据 + all_gradient_groups = [None for _ in range(world_size)] + dist.all_gather_object(all_gradient_groups, local_padded_groups) + + if rank == 0: + # 处理每个梯度组 + for group_name in gradient_groups.keys(): + # 收集该组的所有有效梯度 + valid_grad_list = [] + for rank_idx, rank_data in enumerate(all_gradient_groups): + if group_name in rank_data: + valid_count = all_valid_counts[rank_idx].get(group_name, 0) + if valid_count > 0: + tensor_valid = rank_data[group_name][:valid_count, :].to(device) + valid_grad_list.append(tensor_valid) + + if len(valid_grad_list) == 0: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + else: + all_grads = torch.cat(valid_grad_list, dim=0) + if all_grads.shape[0] <= 1: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + else: + grad_list = [all_grads[i] for i in range(all_grads.shape[0])] + results[group_name] = compute_gradient_conflicts(grad_list) + else: + results = None + + # 广播结果到所有rank + results_list = [results] + dist.broadcast_object_list(results_list, src=0) + return results_list[0] + + +if __name__ == "__main__": + example_usage() diff --git a/zoo/atari/config/README.md b/zoo/atari/config/README.md new file mode 100644 index 000000000..b11efeaa0 --- /dev/null +++ b/zoo/atari/config/README.md @@ -0,0 +1,92 @@ +The core of this version update revolves around the Mixture of Experts (MoE) architecture in multi-task reinforcement learning, introducing a powerful suite of tools for analysis and validation. Based on recent experimental research (see "MoE Experimental Analysis Summary"), we have developed features to monitor gradient conflicts and expert specialization in real-time, aiming to provide a deeper understanding of MoE's mechanisms and support its optimization. + +### 1. New Core Feature: Gradient Conflict Analysis System ++ Feature Introduction: + +An advanced, distributed-training-compatible gradient conflict analysis system has been introduced. This system can compute and visualize gradient conflicts between different model components in real-time, including the encoder, MoE layers, and shared experts. + ++ Experimental Relevance (Experiments 1 & 3): + +This feature directly stems from the experimental findings that MoE architectures effectively mitigate gradient conflicts, with most conflicts concentrated in the shared expert. This tool allows developers to quantify this effect, monitor training stability, and provide a data-driven basis for future routing and load-balancing strategies. + ++ **Technical Implementation:** + - **Conflict Calculation Logic:** Multi-level gradient conflict calculation and logging are integrated into the policy module at `lzero/policy/unizero_multitask.py`. + - **Distributed Calculation & Visualization:** High-efficiency functions for distributed gradient computation and heatmap generation are implemented in the utility library at `lzero/policy/utils.py`. + +### 2. New Core Feature: Expert Selection and Specialization Tracking ++ Feature Introduction: + +A new module for in-depth tracking of MoE expert selection behavior has been added. This module uses multi-granularity sliding windows (from an immediate 100 steps to a long-term 100,000 steps) to track the usage frequency of experts for each task, thereby quantifying the expert specialization process. + ++ Experimental Relevance (Experiment 2): + +This feature is designed to validate the conclusion from Experiment 2: as training progresses, experts gradually "specialize" for specific tasks (evidenced by a decrease in expert selection entropy). It provides key insights into how tasks are automatically partitioned among different experts. + ++ **Technical Implementation:** + - **Core Statistics Module:** Task-aware routing, a multi-window statistics collector, and the `get_expert_selection_stats` data retrieval interface are implemented in `lzero/model/unizero_world_models/moe.py`. + +### 3. Architecture Refactoring and Experimental Support ++ **Core Architecture Enhancements:** + - **Task ID Propagation:** The `lzero/model/unizero_world_models/transformer.py` and `world_model_multitask.py` have been refactored to support the propagation of the `task_id` throughout the entire forward pass. + - **Gradient Hooks:** Flexible gradient extraction hooks have been added in `world_model_multitask.py` to provide the underlying data for the analysis systems mentioned above. ++ **Comprehensive Experimental Configurations:** + - **Dedicated Configurations:** A new set of MoE-specific configuration files, such as `atari_unizero_multitask_segment_ddp_config_moe.py`, has been added to the `zoo/atari/config/` directory to facilitate comparative experiments. ++ **Performance and Debugging:** + - **Performance Profiling:** The `LineProfiler` tool has been integrated into `lzero/policy/unizero_multitask.py`. + - **Entry Points & Utilities:** Corresponding modifications have been made in `lzero/entry/train_unizero_multitask_segment_ddp.py` and `lzero/entry/utils.py` to support the new features and configurations. + +# SExperimental Analysis for Mixture-of-Experts (MoE) +This document summarizes the experimental setup and key findings from the analysis of Mixture-of-Experts (MoE) architectures in multitask reinforcement learning. The goal is to understand the mechanisms behind MoE's strong performance. + +### Experiment 1: Analyzing Gradient Conflicts in MoE-based Transformers +**Experimental Setup:** + ++ **Task Domain:** Atari-8. ++ **Architectures Compared:** + 1. **Naive Transformer:** A backbone with four standard Transformer blocks. + 2. **MoE-based Transformer:** A backbone of four Transformer blocks where each MLP layer is replaced by an MoE layer (consisting of one shared expert and eight non-shared experts). ++ **Measurement:** Gradient conflict between tasks is quantified using the maximum negative cosine similarity. + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900605706-2f47ce39-1eb5-471c-b2aa-9fe98cd6769c.png) + ++ **Analysis Points:** Gradient conflicts were measured at three key locations: + 1. The input right before the MoE layer. + 2. The output of the encoder. + 3. The parameters within the MoE layer itself (shared expert, non-shared experts, and the entire layer). + +**Main Conclusion (Observation 1):** + +The primary finding is that the MoE-based Transformer demonstrates significantly fewer gradient conflicts at the MoE layer and its input compared to the standard Transformer with MLP layers. This suggests that the MoE architecture helps mitigate gradient conflicts not just within its own layer but also in other connected components. Conflict levels at the encoder output were comparable for both models, likely because the encoder learns general representations that inherently have fewer conflicts. + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900622719-5b0f776e-8aff-4425-8087-19696ac514a3.png) + +### Experiment 2: Investigating MoE Gating Mechanisms +**Experimental Setup:** + ++ **Objective:** To determine if MoE experts effectively differentiate and specialize when dealing with non-stationary data from agent-environment interactions in RL. ++ **Metrics:** + 1. **Expert Selection Entropy:** Measures the uncertainty in expert choice for a given task. Lower entropy indicates higher specialization. + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900647827-9bdf07f5-bfea-4ae2-b728-6a053ae3c7da.png) + + 2. **Wasserstein Distance:** Measures the similarity between the expert selection distributions of different tasks. ++ **Procedure:** Data on expert choices was collected over time windows of different sizes (_immediate_ = 100 steps, _short_ = 1,000 steps) to form probability distributions for analysis. + +**Main Conclusion (Observation 2):** + +The key observation from this experiment is that as training progresses, the entropy of the expert selection distribution for tasks gradually decreases. This indicates that the selection of experts becomes more certain and concentrated on a smaller subset over time, demonstrating a clear pattern of expert specialization and differentiation in the multitask setting. + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900661959-e19e904f-f1e3-4832-aa06-2ecf60d6e2b5.png) + +### Experiment 3: Analyzing Gradient Conflicts Between Shared and Non-Shared Experts +**Experimental Setup:** + ++ **Objective:** To further analyze the source of gradient dynamics within the MoE architecture by comparing conflicts between shared and non-shared experts. ++ **Method:** The MoE-based Transformer was used to measure and compare the gradient conflicts experienced by the shared expert versus the eight individual non-shared experts. + +**Main Conclusion (Observation 3):** + +The results show that the shared expert bears a significantly higher level of gradient conflict compared to any of the non-shared, task-specific experts. In fact, most of the gradient conflicts within the entire MoE layer are concentrated on this shared component, while individual experts experience almost no conflict. This is attributed to the gating mechanism, which routes different tasks to different non-shared experts, leading to consistent gradient updates for each. In contrast, the shared expert must handle all tasks simultaneously, causing conflicting updates. Therefore, the introduction of non-shared experts is a key factor in reducing the overall gradient conflict of the MoE layer. + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900675792-d0ee1bd7-5fba-4ee5-ad6d-c0d719c51823.png) + diff --git a/zoo/atari/config/READNE.zh.md b/zoo/atari/config/READNE.zh.md new file mode 100644 index 000000000..19102c95e --- /dev/null +++ b/zoo/atari/config/READNE.zh.md @@ -0,0 +1,84 @@ +本次版本更新的核心是围绕多任务强化学习中的混合专家模型(MoE)架构,引入了一套强大的分析与验证工具。基于最新的实验研究(参考《MoE实验分析总结》),我们开发了用于实时监控梯度冲突和专家特化过程的功能,旨在深入理解 MoE 的工作机制并为其优化提供数据支持。 + +### 1. 新增核心功能:梯度冲突分析系统 ++ **功能简介:** 引入了一个先进的、支持分布式训练的梯度冲突分析系统。该系统能够实时计算并可视化模型不同组件间的梯度冲突,包括编码器、MoE 层、共享专家等。 ++ **实验关联 (实验一 & 三):** 此功能直接源于实验发现——MoE 架构能有效缓解梯度冲突,且大部分冲突集中在共享专家上。通过此工具,开发者可以量化这一效应,监控训练稳定性,并为后续的路由和负载均衡策略提供数据依据。 ++ **技术实现:** + - **冲突计算逻辑:** 在策略模块 `lzero/policy/unizero_multitask.py` 中集成了多层级的梯度冲突计算与日志记录。 + - **分布式计算与可视化:** 在工具库 `lzero/policy/utils.py` 中实现了高效的分布式梯度计算和热力图生成函数。 + +### 2. 新增核心功能:专家选择与特化追踪 ++ **功能简介:** 新增了对 MoE 专家选择行为的深度追踪模块。该模块采用多粒度滑动窗口(从即时的100步到长期的100,000步)来统计每个任务对专家的使用频率,从而量化专家的特化过程。 ++ **实验关联 (实验二):** 该功能旨在验证实验二的结论,即随着训练进行,专家会逐渐为特定任务而“特化”(表现为专家选择熵的降低)。它为理解任务如何被自动划分给不同专家提供了关键洞察。 ++ **技术实现:** + - **核心统计模块:** 在 `lzero/model/unizero_world_models/moe.py` 中实现了任务感知的路由、多窗口统计收集器以及数据获取接口 `get_expert_selection_stats`。 + +### 3. 架构重构与实验支持 ++ **核心架构增强:** + - **任务ID传递:** 在 `lzero/model/unizero_world_models/transformer.py` 和 `world_model_multitask.py` 中进行了重构,以支持将任务ID (`task_id`) 贯穿整个前向传播过程。 + - **梯度钩子:** 在 `world_model_multitask.py` 中增加了灵活的梯度提取钩子,为上述分析系统提供底层数据。 ++ **完善的实验配置:** + - **专用配置:** 在 `zoo/atari/config/` 目录下新增了多套 MoE 专用配置文件,如 `atari_unizero_multitask_segment_ddp_config_moe.py`,便于进行对比实验。 ++ **性能与调试:** + - **性能分析:** 在 `lzero/policy/unizero_multitask.py` 中集成了性能分析工具 (`LineProfiler`)。 + - **入口与工具:** 在 `lzero/entry/train_unizero_multitask_segment_ddp.py` 和 `lzero/entry/utils.py` 中进行了相应修改,以支持新功能和配置。 + +# 混合专家模型 (MoE) 实验分析总结 +本文档总结了在多任务强化学习中对混合专家(MoE)架构进行的实验设置和主要发现,旨在理解 MoE 模型表现出色的背后机制。 + +### 实验一:分析基于 MoE 的 Transformer 中的梯度冲突 +**实验设置:** + ++ **任务领域:** Atari-8 ++ **对比架构:** + 1. **朴素 Transformer:** 使用四个标准 Transformer 模块作为骨干网络。 + 2. **基于 MoE 的 Transformer:** 骨干网络同样为四个 Transformer 模块,但每个模块中的 MLP 层被替换为 MoE 层(包含一个共享专家和八个非共享专家)。 ++ **测量指标:** 使用最大负余弦相似度来量化任务间的梯度冲突。 + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900605706-2f47ce39-1eb5-471c-b2aa-9fe98cd6769c.png) + ++ **分析点:** 在三个关键位置测量了梯度冲突: + 1. MoE 层的输入端。 + 2. 编码器的输出端。 + 3. MoE 层内部的参数(包括共享专家、非共享专家以及整个层)。 + +**主要结论 (观察 1):** + +主要发现是,与使用标准 MLP 层的 Transformer 相比,基于 MoE 的 Transformer 在 MoE 层及其输入端的梯度冲突显著减少。这表明 MoE 架构不仅有助于缓解其自身层内的梯度冲突,还能减轻其他相连组件的冲突。两个模型在编码器输出端的冲突水平相当,这可能是因为编码器学习的是通用表示,其本身固有冲突较少。 + +_图表代码:_ + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900622719-5b0f776e-8aff-4425-8087-19696ac514a3.png?x-oss-process=image%2Fformat%2Cwebp) + +### 实验二:探究 MoE 的门控机制 +**实验设置:** + ++ **目标:** 确定在处理来自强化学习中智能体与环境交互的非平稳数据时,MoE 专家是否能有效地区分和特化。 ++ **评估指标:** + 1. **专家选择熵:** 衡量特定任务选择专家的不确定性。熵值越低,表示专业化程度越高。 + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900647827-9bdf07f5-bfea-4ae2-b728-6a053ae3c7da.png) + + 2. **Wasserstein 距离:** 衡量不同任务的专家选择分布之间的相似性。 ++ **流程:** 在不同大小的时间窗口(_即时_ = 100 步, _短期_ = 1,000 步)内收集专家选择数据,以构建用于分析的概率分布。 + +**主要结论 (观察 2):** + +该实验的关键观察是,随着训练的进行,任务的专家选择分布熵逐渐降低。这表明专家的选择随着时间的推移变得更加确定,并集中在一个较小的子集上,从而在多任务环境中展示出清晰的专家特化和分化模式。 + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900661959-e19e904f-f1e3-4832-aa06-2ecf60d6e2b5.png) + +### 实验三:分析共享专家与非共享专家之间的梯度冲突 +**实验设置:** + ++ **目标:** 通过比较共享专家与非共享专家之间的冲突,进一步分析 MoE 架构内部梯度动态的来源。 ++ **方法:** 使用基于 MoE 的 Transformer 来测量和比较共享专家与八个独立的非共享专家所经历的梯度冲突。 + +**主要结论 (观察 3):** + +结果显示,与任何非共享的、任务特定的专家相比,共享专家承受的梯度冲突程度要高得多。事实上,整个 MoE 层内的大部分梯度冲突都集中在这个共享组件上,而单个专家几乎没有冲突。这归因于门控机制将不同任务路由到不同的非共享专家,从而为每个专家带来一致的梯度更新。相比之下,共享专家必须同时处理所有任务,导致更新冲突。因此,引入非共享专家是减少 MoE 层整体梯度冲突的关键因素。 + +![](https://cdn.nlark.com/yuque/0/2025/png/22947362/1758900675792-d0ee1bd7-5fba-4ee5-ad6d-c0d719c51823.png) + + + diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py index bdc5e4f7a..d355df1e5 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -55,7 +55,7 @@ def compute_batch_config( return batch_sizes, grad_acc_steps def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + num_simulations, eval_num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers): return EasyDict(dict( @@ -143,7 +143,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # num_layers=12, # todo num_heads=24, - embed_dim=768, + embed_dim=768, #768 obs_type='image', env_num=8, task_num=len(env_id_list), @@ -192,6 +192,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu cos_lr_scheduler=False, num_segments=num_segments, num_simulations=num_simulations, + eval_num_simulations=eval_num_simulations, reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, replay_buffer_size=int(5e5), @@ -204,9 +205,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu reanalyze_partition=reanalyze_partition, ), )) - def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, - num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + num_simulations, eval_num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers): configs = [] @@ -247,7 +247,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod for task_id, env_id in enumerate(env_id_list): config = create_config( env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, - reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + eval_num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers ) config.policy.task_id = task_id @@ -333,7 +333,7 @@ def create_env_manager(): torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py """ - +# /fs-computility/niuyazhe/tangjia/code/LightZero-dev-multitask-balance-clean/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py from lzero.entry import train_unizero_multitask_segment_ddp from ding.utils import DDPContext import os @@ -346,7 +346,8 @@ def create_env_manager(): num_segments = 8 n_episode = 8 evaluator_env_num = 3 - num_simulations = 50 + num_simulations = 25 # collect时使用的模拟次数 + eval_num_simulations = 50 # eval时使用的模拟次数(可以设为更高获得更好评估质量) max_env_step = int(4e5) reanalyze_ratio = 0.0 @@ -379,7 +380,8 @@ def create_env_manager(): elif num_layers == 8: effective_batch_size = 512 # nlayer8 需要设置replay_ratio=0.5对应的upc=80 # effective_batch_size = 256 # moco nlayer8 需要设置replay_ratio=0.5对应的upc=80 - + elif num_layers == 1: + effective_batch_size = 256 elif len(env_id_list) == 26: # effective_batch_size = 832 # cnn-encoder # effective_batch_size = 1024 # base-vit-encoder transformer-nlayer4 or cnn-encoder @@ -427,7 +429,7 @@ def create_env_manager(): # for seed in [1]: for seed in [0]: configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, - num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + num_simulations, eval_num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers) diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py index cddaae311..a069aac59 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py @@ -1,6 +1,11 @@ from easydict import EasyDict import math +import sys +import os +PROJECT_ROOT = os.path.abspath("/fs-computility/niuyazhe/tangjia/github/LightZero") # 或者直接写死路径 +sys.path.insert(0, PROJECT_ROOT) +# /fs-computility/niuyazhe/tangjia/github/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py def compute_batch_config(env_id_list, effective_batch_size): n = len(env_id_list) @@ -64,8 +69,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu policy=dict( multi_gpu=True, # Very important for ddp only_use_moco_stats=False, - # use_moco=False, # ==============TODO============== - use_moco=True, # ==============TODO: moco============== + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO: moco============== learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), grad_correct_params=dict( MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, @@ -129,7 +134,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # num_layers=12, # todo num_heads=24, - embed_dim=768, + embed_dim=768,#768 obs_type='image', env_num=8, task_num=len(env_id_list), @@ -142,9 +147,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_experts_in_moe_head=4, moe_in_transformer=False, - multiplication_moe_in_transformer=False, # ==============TODO:orig============== - # multiplication_moe_in_transformer=True, # =======TODO: moe8======= - n_shared_experts=1, + # multiplication_moe_in_transformer=False, # ==============TODO:orig============== + multiplication_moe_in_transformer=True, # =======TODO: moe8======= + n_shared_experts=1, # 共享expert 数量 num_experts_per_tok=1, num_experts_of_moe_in_transformer=8, @@ -197,7 +202,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod num_segments, total_batch_size, num_layers): configs = [] # ===== only for debug ===== - exp_name_prefix = f'data_unizero_atari_mt_20250522_debug/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moe8_moco-v1_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'debug_log/atari_{len(env_id_list)}games_orig_simnorm-kl_vit_moe8_moco-v1_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' # ========= TODO: global BENCHMARK_NAME ========= @@ -292,7 +297,7 @@ def create_env_manager(): num_games = 8 # 26 # 8 - num_layers = 4 # ==============TODO============== + num_layers = 1 # ==============TODO============== action_space_size = 18 collector_env_num = 8 num_segments = 8 @@ -324,7 +329,8 @@ def create_env_manager(): effective_batch_size = 1024 # nlayer4 需要设置replay_ratio=0.25对应的upc=40 elif num_layers == 8: effective_batch_size = 512 # nlayer8 需要设置replay_ratio=0.5对应的upc=80 - + elif num_layers == 1: + effective_batch_size = 32 elif len(env_id_list) == 26: # effective_batch_size = 832 # cnn-encoder # effective_batch_size = 1024 # base-vit-encoder transformer-nlayer4 or cnn-encoder @@ -337,7 +343,7 @@ def create_env_manager(): batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) total_batch_size = effective_batch_size # 当前无效 - + num_unroll_steps = 10 # infer_context_length = 4 infer_context_length = 5 # ==============TODO============== @@ -350,7 +356,9 @@ def create_env_manager(): # ======== TODO: only for debug ======== env_id_list = [ - 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + # 'SeaquestNoFrameskip-v4' ] num_layers = 1 # ==============TODO============== collector_env_num = 2 @@ -363,21 +371,28 @@ def create_env_manager(): infer_context_length = 2 batch_sizes = [2 for _ in range(len(env_id_list))] total_batch_size = 2*len(env_id_list) + + # ===========button from tangjia=========== + import torch.distributed as dist - for seed in [0,1]: + for seed in [100]: configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers) - + + + with DDPContext(): + + # print(train_unizero_multitask_segment_ddp.__file__) train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name= "atari" ) # ======== TODO: only for debug ======== # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks - # 手动销毁进程组 + # 手动销毁进程组 /fs-computility/niuyazhe/tangjia/github/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py if dist.is_initialized(): dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py deleted file mode 100644 index badcd9585..000000000 --- a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py +++ /dev/null @@ -1,236 +0,0 @@ -from easydict import EasyDict - -def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): - return EasyDict(dict( - env=dict( - stop_value=int(1e6), - env_id=env_id, - observation_shape=(3, 64, 64), - gray_scale=False, - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False), - full_action_space=True, - collect_max_episode_steps=int(5e3), - eval_max_episode_steps=int(5e3), - # ===== only for debug ===== - # collect_max_episode_steps=int(20), - # eval_max_episode_steps=int(20), - ), - policy=dict( - multi_gpu=True, - only_use_moco_stats=False, - use_moco=False, # ==============TODO============== - # use_moco=True, # ==============TODO============== - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), - grad_correct_params=dict( # Gradient correction parameters - MoCo_beta=0.5, - MoCo_beta_sigma=0.5, - MoCo_gamma=0.1, - MoCo_gamma_sigma=0.5, - MoCo_rho=0, - calpha=0.5, - rescale=1, - ), - task_num=len(env_id_list), - task_id=0, - model=dict( - observation_shape=(3, 64, 64), - action_space_size=action_space_size, - norm_type=norm_type, - num_res_blocks=2, - num_channels=256, - world_model_cfg=dict( - final_norm_option_in_obs_head='LayerNorm', - final_norm_option_in_encoder='LayerNorm', - predict_latent_loss_type='mse', # TODO: for latent state layer_norm - - # final_norm_option_in_obs_head='SimNorm', - # final_norm_option_in_encoder='SimNorm', - # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm - - share_head=False, # TODO - analysis_dormant_ratio_weight_rank=False, # TODO - dormant_threshold=0.025, - - continuous_action_space=False, - - task_embed_option=None, # ==============TODO: none ============== - use_task_embed=False, # ==============TODO============== - - # task_embed_option='concat_task_embed', # ==============TODO: none ============== - # use_task_embed=True, # ==============TODO============== - # task_embed_dim=96, - # task_embed_dim=128, - - use_shared_projection=False, - - max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, - context_length=2 * infer_context_length, - device='cuda', - action_space_size=action_space_size, - num_layers=8, - num_heads=24, - embed_dim=768, - obs_type='image', - env_num=8, - task_num=len(env_id_list), - use_normal_head=True, - use_softmoe_head=False, - use_moe_head=False, - num_experts_in_moe_head=4, - moe_in_transformer=False, - multiplication_moe_in_transformer=False, - num_experts_of_moe_in_transformer=4, - - # LoRA 参数(启用LoRA) - lora_r=0, - # lora_r=8, - lora_alpha=32, - lora_dropout=0.1, - # 默认目标模块:attn和feed_forward - lora_target_modules=["attn", "feed_forward"], - # 调整finetune_components - ), - ), - use_task_exploitation_weight=False, # TODO - task_complexity_weight=False, # TODO - total_batch_size=total_batch_size, - allocated_batch_sizes=False, - train_start_after_envsteps=int(0), - use_priority=False, - print_task_priority_logs=False, - cuda=True, - model_path=None, - num_unroll_steps=num_unroll_steps, - game_segment_length=20, - update_per_collect=80, - replay_ratio=0.25, - batch_size=batch_size, - optim_type='AdamW', - cos_lr_scheduler=True, - num_segments=num_segments, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - n_episode=n_episode, - replay_buffer_size=int(5e5), - eval_freq=int(2e4), - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - buffer_reanalyze_freq=buffer_reanalyze_freq, - reanalyze_batch_size=reanalyze_batch_size, - reanalyze_partition=reanalyze_partition, - ), - )) - -def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): - configs = [] - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-encoder/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' - - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' - # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' - - - - for task_id, env_id in enumerate(env_id_list): - config = create_config( - env_id, - action_space_size, - collector_env_num, - evaluator_env_num, - n_episode, - num_simulations, - reanalyze_ratio, - batch_size, - num_unroll_steps, - infer_context_length, - norm_type, - buffer_reanalyze_freq, - reanalyze_batch_size, - reanalyze_partition, - num_segments, - total_batch_size - ) - config.policy.task_id = task_id - config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" - configs.append([task_id, [config, create_env_manager()]]) - return configs - -def create_env_manager(): - return EasyDict(dict( - env=dict( - type='atari_lightzero', - import_names=['zoo.atari.envs.atari_lightzero_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='unizero_multitask', - import_names=['lzero.policy.unizero_multitask'], - ), - )) - -if __name__ == "__main__": - """ - Overview: - This script should be executed with GPUs. - Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=1 --master_port=29507 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py - torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py - """ - - from lzero.entry import train_unizero_multitask_segment_ddp - from ding.utils import DDPContext - from easydict import EasyDict - - # env_id_list = ['PongNoFrameskip-v4'] # Debug setup - env_id_list = ['AmidarNoFrameskip-v4'] # Debug setup - - action_space_size = 18 - - # NCCL environment setup - import os - os.environ["NCCL_TIMEOUT"] = "3600000000" - - # for seed in [0, 1, 2]: - for seed in [0]: - collector_env_num = 8 - num_segments = 8 - n_episode = 8 - evaluator_env_num = 3 - num_simulations = 50 - max_env_step = int(4e5) - - reanalyze_ratio = 0.0 - total_batch_size = 512 - batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] - - num_unroll_steps = 10 - infer_context_length = 4 - norm_type = 'LN' - # buffer_reanalyze_freq = 1 / 50 - buffer_reanalyze_freq = 1 / 10000000 - reanalyze_batch_size = 160 - reanalyze_partition = 0.75 - - # ======== TODO: only for debug ======== - # collector_env_num = 2 - # num_segments = 2 - # n_episode = 2 - # evaluator_env_num = 2 - # num_simulations = 1 - # reanalyze_batch_size = 2 - # batch_size = [4, 4, 4, 4, 4, 4, 4, 4] - - configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size) - - # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' - # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_atari_mt_20250217/atari_8games_notaskembed_bs64_brf0.02_seed0_dev-uz-mz-mt-cont/Pong_seed0_250218_124624/ckpt/ckpt_best.pth.tar' - - pretrained_model_path = '/fs-computility/ai-shen/puyuan/code/LightZero/data_lz/data_unizero_atari_mt_20250307/atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar' - with DDPContext(): - train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file