diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index cdb9568a9..e0017ce57 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -12,6 +12,8 @@ from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval +from .train_unizero_multitask_ddp import train_unizero_multitask_ddp +from .train_unizero_multitask import train_unizero_multitask from .utils import * from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py new file mode 100644 index 000000000..1f491a212 --- /dev/null +++ b/lzero/entry/train_unizero_multitask.py @@ -0,0 +1,531 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any +import concurrent.futures +import torch +import numpy as np +import torch.nn.functional as F +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy, Policy +from ding.utils import set_pkg_seed, EasyTimer +from ding.worker import BaseLearner + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector + +# Set timeout (seconds) +TIMEOUT = 12000 +timer = EasyTimer() + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely execute the evaluation task to avoid timeout (non-DDP version). + Args: + evaluator (Evaluator): The evaluator instance. + learner (BaseLearner): The learner instance. + collector (Collector): The data collector instance. + Returns: + Tuple[Optional[bool], Optional[float]]: If evaluation succeeds, returns the stop flag and reward; + otherwise returns (None, None). + """ + try: + print(f"========= Evaluation Started =========") + # Reset stop_event to ensure it is unset before each evaluation + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the evaluation task + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # Timeout occurred, set stop_event + evaluator.stop_event.set() + print(f"Evaluation operation timed out after {TIMEOUT} seconds.") + return None, None + + print(f"====== Evaluation Finished ======") + return stop, reward + except Exception as e: + print(f"An error occurred during evaluation: {e}") + return None, None + +def allocate_batch_size_local( + cfgs: List[Dict[str, Any]], + game_buffers, + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Allocate batch_size inversely proportional to the number of collected episodes + for different tasks (non-DDP version). + Args: + cfgs (List[Dict[str, Any]]): Configuration list for each task. + game_buffers (List[Any]): Replay buffer instances for each task (use Any to avoid specific type dependency). 
+        alpha (float, optional): Hyperparameter controlling the degree of inverse proportionality. Default is 1.0.
+        clip_scale (int, optional): Dynamic adjustment scale factor. Default is 1.
+    Returns:
+        List[int]: The allocated batch_size list.
+    """
+    # Extract the number of collected episodes for each task (assuming buffer has this attribute)
+    buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers]
+    print(f'Collected episodes for all local tasks: {buffer_num_of_collected_episodes}')
+
+    # Compute the inverse weights for each task
+    inv_episodes = np.array([1.0 / (episodes + 1) for episodes in buffer_num_of_collected_episodes])
+    inv_sum = np.sum(inv_episodes)
+
+    # Compute the total batch_size (taken from the first task's configuration)
+    # Assume total_batch_size refers to the total batch size required by the current process
+    total_batch_size = cfgs[0].policy.total_batch_size
+
+    # Dynamic adjustment: minimum and maximum batch_size range
+    num_local_tasks = len(cfgs)
+    avg_batch_size = total_batch_size / max(num_local_tasks, 1)  # Avoid division by zero
+    min_batch_size = avg_batch_size / clip_scale
+    max_batch_size = avg_batch_size * clip_scale
+
+    # Apply the alpha exponent to smooth the inverse-proportional task weights
+    task_weights = (inv_episodes / inv_sum) ** alpha
+    batch_sizes = total_batch_size * task_weights
+
+    # Clip batch_size within [min_batch_size, max_batch_size]
+    batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size)
+
+    # Ensure batch_size is an integer
+    batch_sizes = [int(size) for size in batch_sizes]
+
+    return batch_sizes
+
+
+# State and helper required by compute_task_weights below: running extremes for the
+# "run-max-min" option and the symlog transform used by the default option="symlog".
+from lzero.entry.utils import symlog
+
+GLOBAL_MAX = -float('inf')
+GLOBAL_MIN = float('inf')
+
+
+def compute_task_weights(
+    task_rewards: dict,
+    option: str = "symlog",
+    epsilon: float = 1e-6,
+    temperature: float = 1.0,
+    use_softmax: bool = False,
+    reverse: bool = False,
+    clip_min: float = 1e-2,
+    clip_max: float = 1.0,
+) -> dict:
+    global GLOBAL_MAX, GLOBAL_MIN
+
+    if not task_rewards:
+        return {}
+
+    task_ids = list(task_rewards.keys())
+    rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
+
+    if option == "symlog":
+        scaled_rewards = symlog(rewards_tensor)
+    elif option == "max-min":
+        max_reward = rewards_tensor.max().item()
+        min_reward = rewards_tensor.min().item()
+        scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon)
+    elif option == "run-max-min":
+        GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item())
+        GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item())
+        scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon)
+    elif option == "rank":
+        sorted_indices = torch.argsort(rewards_tensor)
+        scaled_rewards = torch.empty_like(rewards_tensor)
+        rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32)
+        scaled_rewards[sorted_indices] = rank_values
+    elif option == "none":
+        scaled_rewards = rewards_tensor
+    else:
+        raise ValueError(f"Unsupported option: {option}")
+
+    if not reverse:
+        raw_weights = scaled_rewards
+    else:
+        scaled_rewards = torch.clamp(scaled_rewards, min=epsilon)
+        raw_weights = 1.0 / scaled_rewards
+
+    if use_softmax:
+        beta = 1.0 / max(temperature, epsilon)
+        logits = -beta * raw_weights
+        softmax_weights = F.softmax(logits, dim=0).numpy()
+        weights = dict(zip(task_ids, softmax_weights))
+    else:
+        scaled_weights = raw_weights ** (1 / max(temperature, epsilon))
+        total_weight = scaled_weights.sum()
+        normalized_weights = scaled_weights / max(total_weight, epsilon)  # Avoid division by zero
+        weights = dict(zip(task_ids, normalized_weights.numpy()))
+
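+    # Finally, clip each weight into [clip_min, clip_max] so that no task is starved of training
+    # signal or allowed to dominate (with the defaults, e.g. a weight of 0.002 is raised to 0.01).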
for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + +def train_unizero_multitask( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + Entry point for UniZero multi-task training (non-DDP version). + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): Path to the pre-trained model. + - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations. + - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): The converged policy. + """ + # Initialize temperature scheduler (unchanged) + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' + ) + + # Handle all tasks in a single process + tasks = input_cfg_list + total_tasks = len(tasks) + print(f"Handling all {total_tasks} tasks in a single process.") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # Ensure at least one task is provided + if not tasks: + logging.error("No task configurations provided. Training cannot proceed.") + return None + + # Use the first task's configuration to create the shared policy and learner + task_id_first, [cfg_first, create_cfg_first] = tasks[0] + + assert create_cfg_first.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], "train_unizero_multitask entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'" + + + GameBuffer = None + if create_cfg_first.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GB + GameBuffer = GB + elif create_cfg_first.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as SGB + GameBuffer = SGB + else: + raise NotImplementedError(f"Policy type {create_cfg_first.policy.type} not fully supported for GameBuffer import.") + + cfg_first.policy.device = 'cuda' if cfg_first.policy.cuda and torch.cuda.is_available() else 'cpu' + logging.info(f'Using device: {cfg_first.policy.device}') + + # Compile the main config (only for creating policy and learner) + # Note: we compile once here, but later re-compile per-task configs + compiled_cfg_first = compile_config(cfg_first, seed=seed, env=None, auto=True, create_cfg=create_cfg_first, save_cfg=True) + + # Create shared policy + policy = create_policy(compiled_cfg_first.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path is not None: + logging.info(f'Loading pretrained model: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg_first.policy.device)) + logging.info(f'Finished loading model: {model_path}') + + log_dir = os.path.join('./{}/log/'.format(compiled_cfg_first.exp_name), 'serial') + tb_logger = SummaryWriter(log_dir) + + # Create shared learner + learner = BaseLearner(compiled_cfg_first.policy.learn.learner, 
policy.learn_mode, tb_logger, exp_name=compiled_cfg_first.exp_name) + + # Process each task + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks): + # Set random seed per task + current_seed = seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + # Compile per-task config + cfg = compile_config(cfg, seed=current_seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Get policy config + policy_config = cfg.policy + policy_config.task_id = task_id # explicitly set task_id + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(current_seed) + evaluator_env.seed(current_seed, dynamic_seed=False) + set_pkg_seed(current_seed, use_cuda=cfg.policy.cuda) + + # Create buffer, collector, and evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + + reanalyze_batch_size = compiled_cfg_first.policy.reanalyze_batch_size + update_per_collect = compiled_cfg_first.policy.update_per_collect + task_complexity_weight = compiled_cfg_first.policy.task_complexity_weight + use_task_exploitation_weight = compiled_cfg_first.policy.use_task_exploitation_weight + task_exploitation_weight = None + + task_rewards = {} + while True: + # Dynamically allocate batch_size + if compiled_cfg_first.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes_list = allocate_batch_size_local(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + # Convert list to {task_id: batch_size} + allocated_batch_sizes_dict = {cfg.policy.task_id: size for cfg, size in zip(cfgs, allocated_batch_sizes_list)} + print("Allocated batch_sizes: ", allocated_batch_sizes_dict) + policy._cfg.batch_size = allocated_batch_sizes_dict + for i, cfg in enumerate(cfgs): + cfg.policy.batch_size = allocated_batch_sizes_dict + + # Iterate over tasks for data collection and evaluation + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + current_task_id = cfg.policy.task_id + + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, current_task_id) + + policy_config = cfg.policy + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + 
policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + update_per_collect = policy_config.update_per_collect + if update_per_collect is None: + update_per_collect = 40 + + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): # only for debug + print('=' * 20) + print(f'Evaluating task_id: {current_task_id}...') + # Reset evaluator policy state + evaluator._policy.reset(reset_init_data=True, task_id=current_task_id) + + # Perform safe evaluation (non-DDP version) + stop, reward = safe_eval(evaluator, learner, collector) + if stop is None or reward is None: + print(f"Evaluation failed or timed out, task_id: {current_task_id}, continuing training...") + task_rewards[current_task_id] = float('inf') + else: + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"Evaluation reward for task {current_task_id}: {eval_mean_reward}") + task_rewards[current_task_id] = eval_mean_reward + except Exception as e: + print(f"Error extracting reward for task {current_task_id}: {e}") + task_rewards[current_task_id] = float('inf') + + print('=' * 20) + print(f'Starting data collection for task_id: {current_task_id}...') + print(f'cfg.policy.task_id={current_task_id}') + + # Reset collector policy state + collector._policy.reset(reset_init_data=True, task_id=current_task_id) + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if policy_config.buffer_reanalyze_freq >= 1: + if update_per_collect is None or update_per_collect == 0: + logging.warning("update_per_collect undefined or zero, cannot compute reanalyze_interval") + reanalyze_interval = float('inf') + + else: + reanalyze_interval = update_per_collect // policy_config.buffer_reanalyze_freq + else: + reanalyze_interval = float('inf') + if train_epoch > 0 and policy_config.buffer_reanalyze_freq > 0 and \ + train_epoch % int(1 / policy_config.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}') + + logging.info(f'Finished data collection for task {current_task_id}') + + not_enough_data = any( + game_buffers[idx].get_num_of_transitions() < policy._cfg.batch_size[cfg.policy.task_id] + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)) + ) + + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + + # Compute task weights + if task_complexity_weight: + task_weights = compute_task_weights(task_rewards, temperature=current_temperature_task_weight) + else: + task_weights = None + + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_this_epoch = 0 + + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + current_task_id = cfg.policy.task_id + current_batch_size = policy._cfg.batch_size[current_task_id] + + if current_batch_size == 0: + logging.warning(f"Task {current_task_id} batch_size is 0, skipping sampling.") + continue + + if replay_buffer.get_num_of_transitions() >= current_batch_size: + policy_config = 
cfg.policy + if policy_config.buffer_reanalyze_freq >= 1: + if update_per_collect is not None and update_per_collect > 0: + reanalyze_interval = update_per_collect // policy_config.buffer_reanalyze_freq + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}') + + train_data = replay_buffer.sample(current_batch_size, policy) + train_data.append(current_task_id) + train_data_multi_task.append(train_data) + envstep_this_epoch += collector.envstep + else: + logging.warning( + f'Not enough data for task {current_task_id}: ' + f'batch_size: {current_batch_size}, buffer: {replay_buffer.get_num_of_transitions()}' + ) + + if train_data_multi_task: + learn_kwargs = {'task_weights': task_weights} + log_vars = learner.train(train_data_multi_task, envstep_this_epoch, policy_kwargs=learn_kwargs) + + # --- Compute and update task_exploitation_weight --- + if i == 0 and use_task_exploitation_weight: + local_obs_loss_task = {} + for cfg in cfgs: + task_id = cfg.policy.task_id + loss_key = f'noreduce_obs_loss_task{task_id}' + if log_vars and loss_key in log_vars[0]: + local_obs_loss_task[task_id] = log_vars[0][loss_key] + + if local_obs_loss_task: + task_exploitation_weight = compute_task_weights( + local_obs_loss_task, + option="rank", + temperature=1, + reverse=True + ) + print(f"Locally computed task_exploitation_weight (by task_id): {task_exploitation_weight}") + + else: + logging.warning("Unable to compute local task_exploitation_weight, obs_loss is empty or invalid.") + task_exploitation_weight = None + + learn_kwargs['task_exploitation_weight'] = task_exploitation_weight + + if compiled_cfg_first.policy.use_priority: + if log_vars: + for batch_idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + task_id = cfg.policy.task_id + priority_key = f'value_priority_task{task_id}' + if priority_key in log_vars[0]: + if batch_idx < len(train_data_multi_task): + try: + replay_buffer.update_priority( + train_data_multi_task[batch_idx], + log_vars[0][priority_key] + ) + current_priorities = log_vars[0][priority_key] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + alpha = 0.1 + running_mean_key = f'running_mean_priority_task{task_id}' + if running_mean_key not in value_priority_tasks: + value_priority_tasks[running_mean_key] = mean_priority + else: + value_priority_tasks[running_mean_key] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[running_mean_key] + ) + running_mean_priority = value_priority_tasks[running_mean_key] + if policy_config.print_task_priority_logs: + print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, " + f"Running mean priority: {running_mean_priority:.8f}, " + f"Std: {std_priority:.8f}") + except Exception as e: + logging.error(f"Error updating priority for task {task_id}: {e}") + else: + logging.warning(f"Cannot update priority for task {task_id}, missing data in train_data_multi_task.") + else: + logging.warning(f"Priority key '{priority_key}' not found for task {task_id} in log_vars[0]") + else: + logging.warning("log_vars is empty, cannot update priorities.") + train_epoch += 1 + # Check termination conditions + local_max_envstep = max(collector.envstep for collector in collectors) if 
collectors else 0 + max_envstep_reached = local_max_envstep >= max_env_step + max_train_iter_reached = learner.train_iter >= max_train_iter + + if max_envstep_reached or max_train_iter_reached: + logging.info(f'Termination condition reached: env_step ({local_max_envstep}/{max_env_step}) or train_iter ({learner.train_iter}/{max_train_iter})') + break + + if hasattr(policy, 'recompute_pos_emb_diff_and_clear_cache'): + policy.recompute_pos_emb_diff_and_clear_cache() + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_ddp.py b/lzero/entry/train_unizero_multitask_ddp.py new file mode 100644 index 000000000..8ef90632f --- /dev/null +++ b/lzero/entry/train_unizero_multitask_ddp.py @@ -0,0 +1,618 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List +import concurrent.futures +import torch +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler, symlog, inv_symlog +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector + +# Set timeout (seconds) +TIMEOUT = 12000 +timer = EasyTimer() + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely execute the evaluation task to avoid timeout. + Args: + evaluator (Evaluator): The evaluator instance. + learner (BaseLearner): The learner instance. + collector (Collector): The data collector instance. + rank (int): Rank of the current process. + world_size (int): Total number of processes. + Returns: + Tuple[Optional[bool], Optional[float]]: If evaluation succeeds, returns the stop flag and reward; + otherwise returns (None, None). + """ + try: + print(f"========= Evaluation Started Rank {rank}/{world_size} ==========") + # Reset stop_event to ensure it is cleared before each evaluation + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # Timeout occurred, set stop_event + evaluator.stop_event.set() + print(f"Evaluation timed out on Rank {rank}/{world_size}, exceeded {TIMEOUT} seconds.") + return None, None + + print(f"====== Evaluation Finished Rank {rank}/{world_size} ======") + return stop, reward + except Exception as e: + print(f"Error occurred during evaluation on Rank {rank}/{world_size}: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers, + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Allocate batch_size inversely proportional to the number of collected episodes + for different tasks, and dynamically adjust the batch_size range to improve + training stability and efficiency. + Args: + cfgs (List[dict]): Configuration list for each task. 
+ game_buffers (List[GameBuffer]): Replay buffer instances for each task. + alpha (float, optional): Hyperparameter controlling the degree of inverse proportionality. Default is 1.0. + clip_scale (int, optional): Dynamic adjustment scale factor. Default is 1. + Returns: + List[int]: The allocated batch_size list. + """ + # Extract the number of collected episodes for each task (assuming buffer has this attribute) + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Gather collected episodes from all ranks + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # Flatten into a single list + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'Collected episodes for all tasks: {all_task_num_of_collected_episodes}') + + # Compute inverse weights for each task + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # Compute total batch_size + total_batch_size = cfgs[0].policy.total_batch_size + + # Dynamic adjustment: min and max batch_size range + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # Dynamically adjust alpha to smooth batch_size variation + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # Clip batch_size within [min_batch_size, max_batch_size] + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +# Global max and min (for "run-max-min") +GLOBAL_MAX = -float('inf') +GLOBAL_MIN = float('inf') + +def compute_task_weights( + task_rewards: dict, + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> dict: + """ + Improved task weight computation function. + Supports multiple normalization methods, Softmax, proportional/inverse weighting, + and adds clipping functionality for weight ranges. + Args: + task_rewards (dict): Dictionary of task rewards or losses, + with task_id as key and the value as the reward/loss. + option (str): Normalization method. Options are "symlog", "max-min", "run-max-min", "rank", "none". + epsilon (float): Small constant to avoid division by zero. + temperature (float): Temperature parameter controlling weight distribution. + use_softmax (bool): Whether to use Softmax for weight allocation. + reverse (bool): If True, weights are inversely proportional to values; + if False, weights are directly proportional. + clip_min (float): Minimum value for clipping weights. + clip_max (float): Maximum value for clipping weights. + Returns: + dict: Normalized weights for each task, with task_id as key and the normalized weight as value. 
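+    Example:
+        >>> # Hypothetical task ids and values, shown only to illustrate the weighting direction:
+        >>> weights = compute_task_weights({0: 2.0, 1: 8.0}, option="rank", reverse=True)
+        >>> # -> approximately {0: 0.67, 1: 0.33}; with reverse=True the smaller value receives the larger weight.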
+ """ + global GLOBAL_MAX, GLOBAL_MIN + if not task_rewards: + return {} + + # Step 1: Construct tensor from task_rewards values + task_ids = list(task_rewards.keys()) + rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32) + + if option == "symlog": + scaled_rewards = symlog(rewards_tensor) + elif option == "max-min": + max_reward = rewards_tensor.max().item() + min_reward = rewards_tensor.min().item() + scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon) + elif option == "run-max-min": + GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item()) + scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + elif option == "rank": + # Rank normalization + # Rank is based on sorted order; 1 = smallest, higher rank = larger value + sorted_indices = torch.argsort(rewards_tensor) + scaled_rewards = torch.empty_like(rewards_tensor) + rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32) + scaled_rewards[sorted_indices] = rank_values + elif option == "none": + scaled_rewards = rewards_tensor + else: + raise ValueError(f"Unsupported option: {option}") + + # Step 2: Compute proportional or inverse weights + if not reverse: + raw_weights = scaled_rewards + else: + scaled_rewards = torch.clamp(scaled_rewards, min=epsilon) + raw_weights = 1.0 / scaled_rewards + + # Step 3: Apply Softmax or direct normalization + if use_softmax: + # Softmax weighting + beta = 1.0 / max(temperature, epsilon) # avoid division by zero + logits = -beta * raw_weights + softmax_weights = F.softmax(logits, dim=0).numpy() + weights = dict(zip(task_ids, softmax_weights)) + else: + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / total_weight + weights = dict(zip(task_ids, normalized_weights.numpy())) + + # Step 4: Clip weights within [clip_min, clip_max] + for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + +def train_unizero_multitask_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + Entry point for UniZero training. The goal is to improve the planning ability + of reinforcement learning agents by addressing the limitations of MuZero-like + algorithms in environments that require capturing long-term dependencies. + For more details, refer to https://arxiv.org/abs/2406.10667. + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): Path to the pretrained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training. + - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): The converged policy. 
+ """ + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' + ) + + rank = get_rank() + world_size = get_world_size() + + # Task partitioning + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # Ensure at least one task is assigned + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: no tasks assigned, continuing execution.") + # Initialize empty lists to avoid errors in later code + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # Use the first task’s config to create a shared policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'Configured device: {cfg.policy.device}') + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create shared policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path is not None: + logging.info(f'Loading pretrained model: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Finished loading pretrained model: {model_path}') + + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create shared learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + policy_config = cfg.policy + + # Handle each task assigned to this rank + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for 
c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # Create game buffer, collector, and evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_complexity_weight = cfg.policy.task_complexity_weight + use_task_exploitation_weight = cfg.policy.use_task_exploitation_weight + task_exploitation_weight = None + + # Create task reward dictionary + task_rewards = {} # {task_id: reward} + + while True: + # Dynamically adjust batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("Allocated batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # Perform data collection and evaluation for each task on this rank + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...') + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # Perform safe evaluation + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + if stop is None or reward is None: + print(f"Rank {rank} encountered an issue during evaluation, continuing training...") + task_rewards[cfg.policy.task_id] = float('inf') # Assign max difficulty if evaluation fails + else: + try: + eval_mean_reward = reward.get('eval_episode_return_mean', 
float('inf')) + print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}") + task_rewards[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + print(f"Error extracting evaluation reward: {e}") + task_rewards[cfg.policy.task_id] = float('inf') # Assign max reward if error occurs + + + print('=' * 20) + print(f'Starting data collection for Rank {rank}, task_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + # Reset policy state before each collection (important for multi-task setups) + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time cost: {timer.value}') + + logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}') + + # Check if there is enough data for training + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # Get current temperature + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + + # Compute task weights + try: + dist.barrier() + if task_complexity_weight: + all_task_rewards = [None for _ in range(world_size)] + dist.all_gather_object(all_task_rewards, task_rewards) + merged_task_rewards = {} + for rewards in all_task_rewards: + if rewards: + merged_task_rewards.update(rewards) + task_weights = compute_task_weights(merged_task_rewards, temperature=current_temperature_task_weight) + dist.broadcast_object_list([task_weights], src=0) + print(f"Rank {rank}, global task weights (by task_id): {task_weights}") + else: + task_weights = None + except Exception as e: + logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}') + break + + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time cost: {timer.value}') + + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Not enough data in replay buffer to 
sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + learn_kwargs = {'task_weights':task_exploitation_weight} + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # Compute task_exploitation_weight if needed + if i == 0: + try: + dist.barrier() + if use_task_exploitation_weight: + all_obs_loss = [None for _ in range(world_size)] + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + temperature=1, + ) + dist.broadcast_object_list([task_exploitation_weight], src=0) + print(f"Rank {rank}, task_exploitation_weight (by task_id): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: Failed to compute global obs_loss task weights, obs_loss data is empty.") + task_exploitation_weight = None + else: + task_exploitation_weight = None + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}') + raise e + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # smoothing factor + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + if cfg.policy.print_task_priority_logs: + print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, " + f"Running mean priority: {running_mean_priority:.8f}, " + f"Std: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Synchronize all ranks after training + try: + dist.barrier() + logging.info(f'Rank {rank}: passed training synchronization barrier') + except Exception as e: + logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}') + break + + # Check termination conditions + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = 
[torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: termination condition reached') + dist.barrier() + break + except Exception as e: + logging.error(f'Rank {rank}: termination check failed, error: {e}') + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index b51eb7f11..bfdf706cb 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -362,3 +362,18 @@ def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWr # Reset the time records in the buffer. buffer.reset_runtime_metrics() + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog normalization to reduce the scale differences of target values. + symlog(x) = sign(x) * log(|x| + 1) + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Inverse operation of symlog, used to recover the original values. + inv_symlog(x) = sign(x) * (exp(|x|) - 1) + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 38c1935ea..1e5eaf4b3 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -52,10 +52,13 @@ def __init__(self, cfg: dict): if hasattr(self._cfg, 'task_id'): self.task_id = self._cfg.task_id print(f"Task ID is set to {self.task_id}.") - try: - self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] - except Exception as e: + + if isinstance(self._cfg.model.action_space_size, list): + self.action_space_size = self._cfg.model.action_space_size[self.task_id] + elif isinstance(self._cfg.model.action_space_size, int): self.action_space_size = self._cfg.model.action_space_size + else: + raise ValueError(" action_space_size should be int or list") else: self.task_id = None print("No task_id found in configuration. Task ID is set to None.") @@ -90,6 +93,7 @@ def sample( ) # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( policy_non_re_context, self.action_space_size @@ -135,14 +139,14 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: obs_list, action_list, mask_list = [], [], [] timestep_list = [] bootstrap_action_list = [] - - # prepare the inputs of a batch + for i in range(batch_size): game = game_segment_list[i] pos_in_game_segment = pos_in_game_segment_list[i] actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() + timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid @@ -158,9 +162,17 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] # mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # prepare the inputs of a batch + if isinstance(game.action_space_size, list): + action_size = game.action_space_size[self.task_id] + elif isinstance(game.action_space_size, int): + action_size = game.action_space_size + else: + raise ValueError(" action_space_size should be int or list") + # pad random action actions_tmp += [ - np.random.randint(0, game.action_space_size) + np.random.randint(0, action_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] # TODO: check the effect @@ -185,7 +197,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: self._cfg.num_unroll_steps+self._cfg.td_steps].tolist() # pad random action bootstrap_action_tmp += [ - np.random.randint(0, game.action_space_size) + np.random.randint(0, action_size) for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp)) ] bootstrap_action_list.append(bootstrap_action_tmp) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index 2c45b328b..714fe2ab9 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -69,7 +69,10 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea # image obs input, e.g. atari environments self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) else: - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + self.zero_obs_shape = config.model.observation_shape + else: + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/model/common.py b/lzero/model/common.py index 5ac305e52..50e0ce815 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -369,8 +369,7 @@ def __init__(self, model_path: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8, - norm_type: str = "simnorm", - # norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training? + final_norm_option_in_encoder: str = "simnorm", tokenizer=None): """ Overview: @@ -391,12 +390,12 @@ def __init__(self, # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup. if get_rank() == 0: - self.model = AutoModel.from_pretrained(model_path) + self.pretrained_model = AutoModel.from_pretrained(model_path) if get_world_size() > 1: # Wait for rank 0 to finish loading the model. torch.distributed.barrier() if get_rank() != 0: - self.model = AutoModel.from_pretrained(model_path) + self.pretrained_model = AutoModel.from_pretrained(model_path) if tokenizer is None: # Only rank 0 downloads the tokenizer, and then other processes load it from cache. @@ -411,15 +410,15 @@ def __init__(self, # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings). self.embedding_size = embedding_size - self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size) + self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) # Select the normalization method based on the norm_type parameter. 
- if norm_type.lower() == "simnorm": + if final_norm_option_in_encoder.lower() == "simnorm": self.norm = SimNorm(simnorm_dim=group_size) - elif norm_type.lower() == "layernorm": + elif final_norm_option_in_encoder.lower() == "layernorm": self.norm = nn.LayerNorm(embedding_size) else: - raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. " + raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. " f"Choose 'simnorm' or 'layernorm'.") def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: @@ -442,12 +441,12 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: if no_grad: with torch.no_grad(): x = x.long() # Ensure the input tensor is of type long. - outputs = self.model(x, attention_mask=attention_mask) + outputs = self.pretrained_model(x, attention_mask=attention_mask) # Get the hidden state from the last layer and select the output corresponding to the [CLS] token. cls_embedding = outputs.last_hidden_state[:, 0, :] else: x = x.long() - outputs = self.model(x, attention_mask=attention_mask) + outputs = self.pretrained_model(x, attention_mask=attention_mask) cls_embedding = outputs.last_hidden_state[:, 0, :] # Apply linear projection to obtain the desired output dimension. diff --git a/lzero/model/common_bkp20250521.py b/lzero/model/common_bkp20250521.py deleted file mode 100644 index 3e4edf8a2..000000000 --- a/lzero/model/common_bkp20250521.py +++ /dev/null @@ -1,1369 +0,0 @@ -""" -Overview: - In this Python file, we provide a collection of reusable model templates designed to streamline the development - process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and - customize their custom algorithms, ensuring efficient and effective development. - BTW, users can refer to the unittest of these model templates to learn how to use them. -""" -import math -from dataclasses import dataclass -from typing import Callable, List, Optional -from typing import Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init -from ding.torch_utils import MLP, ResBlock -from ding.torch_utils.network.normalization import build_normalization -from ding.utils import SequenceType -from ditk import logging -from ding.utils import set_pkg_seed, get_rank, get_world_size -import torch - -def MLP_V2( - in_channels: int, - hidden_channels: List[int], - out_channels: int, - layer_fn: Callable = None, - activation: Optional[nn.Module] = None, - norm_type: Optional[str] = None, - use_dropout: bool = False, - dropout_probability: float = 0.5, - output_activation: bool = True, - output_norm: bool = True, - last_linear_layer_init_zero: bool = False, -): - """ - Overview: - Create a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully - connected block with optional activation, normalization, and dropout. The final layer is configurable - to include or exclude activation, normalization, and dropout based on user preferences. - - Arguments: - - in_channels (:obj:`int`): Number of input channels (dimensionality of the input tensor). - - hidden_channels (:obj:`List[int]`): A list specifying the number of channels for each hidden layer. - For example, [512, 256, 128] means the MLP will have three hidden layers with 512, 256, and 128 units, respectively. - - out_channels (:obj:`int`): Number of output channels (dimensionality of the output tensor). 
- - layer_fn (:obj:`Callable`, optional): Layer function to construct layers (default is `nn.Linear`). - - activation (:obj:`nn.Module`, optional): Activation function to use after each layer - (e.g., `nn.ReLU`, `nn.Sigmoid`). Default is None (no activation). - - norm_type (:obj:`str`, optional): Type of normalization to apply after each layer. - If None, no normalization is applied. Supported values depend on the implementation of `build_normalization`. - - use_dropout (:obj:`bool`, optional): Whether to apply dropout after each layer. Default is False. - - dropout_probability (:obj:`float`, optional): The probability of setting elements to zero in dropout. Default is 0.5. - - output_activation (:obj:`bool`, optional): Whether to apply activation to the output layer. Default is True. - - output_norm (:obj:`bool`, optional): Whether to apply normalization to the output layer. Default is True. - - last_linear_layer_init_zero (:obj:`bool`, optional): Whether to initialize the weights and biases of the - last linear layer to zeros. This is commonly used in reinforcement learning for stable initial outputs. - - Returns: - - block (:obj:`nn.Sequential`): A PyTorch `nn.Sequential` object containing the layers of the MLP. - - Notes: - - The final layer's normalization, activation, and dropout are controlled by `output_activation`, - `output_norm`, and `use_dropout`. - - If `last_linear_layer_init_zero` is True, the weights and biases of the last linear layer are initialized to 0. - """ - assert len(hidden_channels) > 0, "The hidden_channels list must contain at least one element." - if layer_fn is None: - layer_fn = nn.Linear - - # Initialize the MLP block - block = [] - channels = [in_channels] + hidden_channels + [out_channels] - - # Build all layers except the final layer - for i, (in_channels, out_channels) in enumerate(zip(channels[:-2], channels[1:-1])): - block.append(layer_fn(in_channels, out_channels)) - if norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Build the final layer - in_channels = channels[-2] - out_channels = channels[-1] - block.append(layer_fn(in_channels, out_channels)) - - # Add optional normalization and activation for the final layer - if output_norm and norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if output_activation and activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Initialize the weights and biases of the last linear layer to zero if specified - if last_linear_layer_init_zero: - for layer in reversed(block): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - return nn.Sequential(*block) - -# use dataclass to make the output of network more convenient to use -@dataclass -class MZRNNNetworkOutput: - # output format of the MuZeroRNN model - value: torch.Tensor - value_prefix: torch.Tensor - policy_logits: torch.Tensor - latent_state: torch.Tensor - predict_next_latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] - - -@dataclass -class EZNetworkOutput: - # output format of the EfficientZero model - value: torch.Tensor - value_prefix: torch.Tensor - policy_logits: torch.Tensor - latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] - - -@dataclass -class MZNetworkOutput: - # output format of 
the MuZero model - value: torch.Tensor - reward: torch.Tensor - policy_logits: torch.Tensor - latent_state: torch.Tensor - - -class SimNorm(nn.Module): - - def __init__(self, simnorm_dim: int) -> None: - """ - Overview: - Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. - Arguments: - - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. - """ - super().__init__() - self.dim = simnorm_dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Overview: - Forward pass of the SimNorm layer. - Arguments: - - x (:obj:`torch.Tensor`): The input tensor to normalize. - Returns: - - x (:obj:`torch.Tensor`): The normalized tensor. - """ - shp = x.shape - # Ensure that there is at least one simplex to normalize across. - if shp[1] != 0: - x = x.view(*shp[:-1], -1, self.dim) - x = F.softmax(x, dim=-1) - return x.view(*shp) - else: - return x - - def __repr__(self) -> str: - """ - Overview: - String representation of the SimNorm layer. - Returns: - - output (:obj:`str`): The string representation. - """ - return f"SimNorm(dim={self.dim})" - - -def AvgL1Norm(x, eps=1e-8): - """ - Overview: - Normalize the input tensor by the L1 norm. - Arguments: - - x (:obj:`torch.Tensor`): The input tensor to normalize. - - eps (:obj:`float`): The epsilon value to prevent division by zero. - Returns: - - :obj:`torch.Tensor`: The normalized tensor. - """ - return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) - - -class FeatureAndGradientHook: - - def __init__(self): - """ - Overview: - Class to capture features and gradients at SimNorm. - """ - self.features_before = [] - self.features_after = [] - self.grads_before = [] - self.grads_after = [] - - def setup_hooks(self, model): - # Hooks to capture features and gradients at SimNorm - self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) - self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) - - def forward_hook(self, module, input, output): - with torch.no_grad(): - self.features_before.append(input[0]) - self.features_after.append(output) - - def backward_hook(self, module, grad_input, grad_output): - with torch.no_grad(): - self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) - self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) - - def analyze(self): - # Calculate L2 norms of features - l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) - l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) - - # Calculate norms of gradients - grad_norm_before = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) - grad_norm_after = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) - - # Clear stored data and delete tensors to free memory - self.clear_data() - - # Optionally clear CUDA cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after - - def clear_data(self): - del self.features_before[:] - del self.features_after[:] - del self.grads_before[:] - del self.grads_after[:] - - def remove_hooks(self): - self.forward_handler.remove() - self.backward_handler.remove() - - -class DownSample(nn.Module): - - def __init__(self, observation_shape: SequenceType, out_channels: int, - activation: nn.Module = 
nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - num_resblocks: int = 1, - ) -> None: - """ - Overview: - Define downSample convolution network. Encode the observation into hidden state. - This network is often used in video games like Atari. In board games like go and chess, - we don't need this module. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] - for video games like atari, RGB 3 channel times stack 4 frames. - - out_channels (:obj:`int`): The output channels of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. - - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. - """ - super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - assert num_resblocks == 1, "num_resblocks must be 1 in DownSample" - - self.observation_shape = observation_shape - self.conv1 = nn.Conv2d( - observation_shape[0], - out_channels // 2, - kernel_size=3, - stride=2, - padding=1, - bias=False, # disable bias for better convergence - ) - if norm_type == 'BN': - self.norm1 = nn.BatchNorm2d(out_channels // 2) - elif norm_type == 'LN': - self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], - eps=1e-5) - - self.resblocks1 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels // 2, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False - ) for _ in range(num_resblocks) - ] - ) - self.downsample_block = ResBlock( - in_channels=out_channels // 2, - out_channels=out_channels, - activation=activation, - norm_type=norm_type, - res_type='downsample', - bias=False - ) - self.resblocks2 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) - self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.resblocks3 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) - self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.activation = activation - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. - """ - x = self.conv1(x) - x = self.norm1(x) - x = self.activation(x) - - for block in self.resblocks1: - x = block(x) - x = self.downsample_block(x) - for block in self.resblocks2: - x = block(x) - x = self.pooling1(x) - for block in self.resblocks3: - x = block(x) - - # 64, 84, 96 are the most common observation shapes in Atari games. - if self.observation_shape[1] == 64: - output = x - elif self.observation_shape[1] == 84: - x = self.pooling2(x) - output = x - elif self.observation_shape[1] == 96: - x = self.pooling2(x) - output = x - else: - raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. 
" - f"You should transform the observation shape to 64 or 96 in the env.") - - return output - - -class HFLanguageRepresentationNetwork(nn.Module): - def __init__(self, - model_path: str = 'google-bert/bert-base-uncased', - embedding_size: int = 768, - group_size: int = 8, - norm_type: str = "simnorm", - # norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training? - tokenizer=None): - """ - Overview: - This class defines a language representation network that utilizes a pretrained Hugging Face model. - The network outputs embeddings with the specified dimension and can optionally use SimNorm or LayerNorm - for normalization at the final stage to ensure training stability. - Arguments: - - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'. - - embedding_size (int): The dimension of the output embeddings. Default is 768. - - group_size (int): The group size for SimNorm when using normalization. - - norm_type (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm". - - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model. - """ - super().__init__() - - from transformers import AutoModel, AutoTokenizer - logging.info(f"Loading model from: {model_path}") - - # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup. - if get_rank() == 0: - self.model = AutoModel.from_pretrained(model_path) - if get_world_size() > 1: - # Wait for rank 0 to finish loading the model. - torch.distributed.barrier() - if get_rank() != 0: - self.model = AutoModel.from_pretrained(model_path) - - if tokenizer is None: - # Only rank 0 downloads the tokenizer, and then other processes load it from cache. - if get_rank() == 0: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - if get_world_size() > 1: - torch.distributed.barrier() - if get_rank() != 0: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - else: - self.tokenizer = tokenizer - - # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings). - self.embedding_size = embedding_size - self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size) - - # Select the normalization method based on the norm_type parameter. - if norm_type.lower() == "simnorm": - self.norm = SimNorm(simnorm_dim=group_size) - elif norm_type.lower() == "layernorm": - self.norm = nn.LayerNorm(embedding_size) - else: - raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. " - f"Choose 'simnorm' or 'layernorm'.") - - def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: - """ - Forward Propagation: - Compute the language representation based on the input token sequence. - The [CLS] token’s representation is extracted from the output of the pretrained model, - then passed through a linear projection and final normalization layer (SimNorm or LayerNorm). - - Arguments: - - x (torch.Tensor): Input token sequence of shape [batch_size, seq_len]. - - no_grad (bool): Whether to run in no-gradient mode for memory efficiency. Default is True. - Returns: - - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size]. - """ - # Construct the attention mask to exclude padding tokens. 
- attention_mask = x != self.tokenizer.pad_token_id - - # Use no_grad context if specified to disable gradient computation. - if no_grad: - with torch.no_grad(): - x = x.long() # Ensure the input tensor is of type long. - outputs = self.model(x, attention_mask=attention_mask) - # Get the hidden state from the last layer and select the output corresponding to the [CLS] token. - cls_embedding = outputs.last_hidden_state[:, 0, :] - else: - x = x.long() - outputs = self.model(x, attention_mask=attention_mask) - cls_embedding = outputs.last_hidden_state[:, 0, :] - - # Apply linear projection to obtain the desired output dimension. - cls_embedding = self.embed_proj_head(cls_embedding) - # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability. - cls_embedding = self.norm(cls_embedding) - - return cls_embedding - -from torch.nn.utils import weight_norm - -# AdaptiveFeatureScaler:在对 1D 向量进行 scaling 时,加入 clamp 限制,避免 runaway -class AdaptiveFeatureScaler(nn.Module): - def __init__(self, init_scale=0.1, max_scale=1.0): - super().__init__() - self.scale = nn.Parameter(torch.tensor(init_scale)) - self.max_scale = max_scale - - def forward(self, x): - # 限制 scale 参数的最大值,避免数值爆炸 - clamped_scale = torch.clamp(self.scale, 0.0, self.max_scale) - return x * clamped_scale / math.sqrt(x.size(1)) - -# 假设 SimNorm, ResBlock, DownSample 在其他地方已经定义 -# 下面仅给出 RepresentationNetworkUniZero 的实现 - -class RepresentationNetworkUniZero(nn.Module): - def __init__( - self, - observation_shape: tuple = (3, 64, 64), - num_res_blocks: int = 1, - num_channels: int = 64, - downsample: bool = True, - activation: nn.Module = nn.GELU(approximate='tanh'), - norm_type: str = 'BN', - embedding_dim: int = 256, - group_size: int = 8, - final_norm_option_in_encoder: str = 'SimNorm', - use_adaptive_scale: bool = False - ) -> None: - """ - Representation network used in UniZero. 
- 对于 channel 数较大的场景,可使用全局平均池化来降低全连接层的输入维度,提高训练稳定性。 - """ - super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" - # 打印日志信息(可选) - print(f"Using norm type: {norm_type}") - print(f"Using activation type: {activation}") - - self.use_global_pooling = False - - self.observation_shape = observation_shape - self.downsample = downsample - - if self.downsample: - # DownSample 对象的实现需自行定义 - self.downsample_net = DownSample( - observation_shape, - num_channels, - activation=activation, - norm_type=norm_type, - num_resblocks=1, - ) - else: - self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) - if norm_type == 'BN': - self.norm = nn.BatchNorm2d(num_channels) - elif norm_type == 'LN': - # 当不进行 downsample 时,观察图尺寸不变 - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - - # 构建 residual block 层 - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.activation = activation - self.embedding_dim = embedding_dim - - # 根据观察图尺寸确定空间维度 - if self.observation_shape[1] == 64: - spatial_size = 8 - elif self.observation_shape[1] in [84, 96]: - spatial_size = 6 - else: - spatial_size = self.observation_shape[1] # 默认采用输入H - - if self.observation_shape[1] == 64: - last_linear_in_dim = num_channels * 8 * 8 - elif self.observation_shape[1] in [84, 96]: - last_linear_in_dim = num_channels * 6 * 6 - else: - # 默认采用完整 flatten 的维度 - last_linear_in_dim = num_channels * self.observation_shape[1] * self.observation_shape[2] - - self.last_linear = nn.Linear(last_linear_in_dim, self.embedding_dim, bias=False) - - - # 根据是否使用全局平均池化决定 last_linear 前的输入维度以及 norm 的形状 - if self.use_global_pooling: - linear_in_dim = num_channels # 全局池化后形状: (B, num_channels, 1, 1) - self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) - # 对 1D 向量使用 LayerNorm - self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) - else: - linear_in_dim = num_channels * spatial_size * spatial_size - if use_adaptive_scale: - # 若通过 flatten 后进行 adaptive scaling,对 1D 向量归一化 - self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) - else: - # 保留空间信息时,在 (C, H, W) 上归一化 - self.norm_before_last_linear = nn.LayerNorm([num_channels, spatial_size, spatial_size], eps=1e-5) - - self.last_linear = nn.Linear(linear_in_dim, self.embedding_dim, bias=False) - - self.use_adaptive_scale = use_adaptive_scale - if self.use_adaptive_scale: - self.adaptive_scaler = AdaptiveFeatureScaler(init_scale=0.1, max_scale=1.0) - - # 最后归一化层,根据 final_norm_option_in_encoder 进行选择 - if final_norm_option_in_encoder == 'LayerNorm': - self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) - elif final_norm_option_in_encoder == 'SimNorm': - self.final_norm = SimNorm(simnorm_dim=group_size) - else: - raise ValueError(f"Unsupported final_norm_option_in_encoder: {final_norm_option_in_encoder}") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: (B, C_in, H, W) - Returns: - x: (B, embedding_dim) - """ - if self.downsample: - x = self.downsample_net(x) - else: - x = self.conv(x) - x = self.norm(x) - x = self.activation(x) - - # 依次通过多个 residual block - for block in self.resblocks: - x = block(x) - - # 分支1:使用全局平均池化 - if self.use_global_pooling: - x = self.global_pool(x) # 输出 shape: (B, num_channels, 1, 1) - x = x.view(x.size(0), -1) # 展平为 (B, num_channels) - x = 
self.norm_before_last_linear(x) # 对 1D 向量做归一化 - else: - # 分支2:不使用全局池化 - if self.use_adaptive_scale: - # 若启用 adaptive scaling:先展平再做 fan-in 缩放 - x = x.view(x.size(0), -1) # (B, num_channels * spatial_size^2) - x = self.adaptive_scaler(x) - x = self.norm_before_last_linear(x) # 归一化 1D 向量 - else: - # 保持完整空间信息:在 (B, C, H, W) 上归一化后,再展平 - x = self.norm_before_last_linear(x) - x = x.view(x.size(0), -1) - - # 最后一层全连接映射与归一化 - x = self.last_linear(x) - x = self.final_norm(x) - return x - - -class RepresentationNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType = (4, 96, 96), - num_res_blocks: int = 1, - num_channels: int = 64, - downsample: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: str = 'BN', - embedding_dim: int = 256, - group_size: int = 8, - use_sim_norm: bool = False, - ) -> None: - """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. - Currently, the network only supports obs images with both a width and height of 96. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] - for video games like atari, 1 gray channel times stack 4 frames. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - num_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - embedding_dim (:obj:`int`): The dimension of the output hidden state. - - group_size (:obj:`int`): The size of group in the SimNorm layer. - - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. - """ - super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.downsample = downsample - if self.downsample: - self.downsample_net = DownSample( - observation_shape, - num_channels, - activation=activation, - norm_type=norm_type, - ) - else: - self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) - - if norm_type == 'BN': - self.norm = nn.BatchNorm2d(num_channels) - elif norm_type == 'LN': - if downsample: - self.norm = nn.LayerNorm( - [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - else: - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.activation = activation - - self.use_sim_norm = use_sim_norm - - if self.use_sim_norm: - self.embedding_dim = embedding_dim - self.sim_norm = SimNorm(simnorm_dim=group_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. 
- """ - if self.downsample: - x = self.downsample_net(x) - else: - x = self.conv(x) - x = self.norm(x) - x = self.activation(x) - - for block in self.resblocks: - x = block(x) - - if self.use_sim_norm: - # NOTE: very important. - # for atari 64,8,8 = 4096 -> 768 - x = self.sim_norm(x) - - return x - - -class RepresentationNetworkMLP(nn.Module): - - def __init__( - self, - observation_shape: int, - hidden_channels: int = 64, - layer_num: int = 2, - activation: nn.Module = nn.GELU(approximate='tanh'), - norm_type: Optional[str] = 'BN', - group_size: int = 8, - ) -> torch.Tensor: - """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ - with Multi-Layer Perceptron (MLP). - Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - """ - super().__init__() - self.fc_representation = MLP( - in_channels=observation_shape, - hidden_channels=hidden_channels, - out_channels=hidden_channels, - layer_num=layer_num, - activation=activation, - norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=True, - ) - self.sim_norm = SimNorm(simnorm_dim=group_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. - """ - x = self.fc_representation(x) - # TODO - x = self.sim_norm(x) - return x - - -class LatentDecoder(nn.Module): - - def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): - """ - Overview: - Decoder network used in UniZero. Decode the latent state into 2D image obs. - Arguments: - - embedding_dim (:obj:`int`): The dimension of the latent state. - - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - num_channels (:obj:`int`): The channel of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). 
- """ - super().__init__() - self.embedding_dim = embedding_dim - self.output_shape = output_shape # (C, H, W) - self.num_channels = num_channels - self.activation = activation - - # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 - # We will reverse the process of the representation network - self.initial_size = ( - num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder - self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) - - # Upsampling blocks - self.conv_blocks = nn.ModuleList([ - # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) - nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), - self.activation, - nn.BatchNorm2d(num_channels // 2), - # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) - nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, - output_padding=1), - self.activation, - nn.BatchNorm2d(num_channels // 4), - # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) - nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, - output_padding=1), - ]) - # TODO: last layer use sigmoid? - - def forward(self, embeddings: torch.Tensor) -> torch.Tensor: - # Map embeddings back to the image space - x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) - x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) - - # Apply conv blocks - for block in self.conv_blocks: - x = block(x) # Upsample progressively - - # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) - return x - - -class LatentEncoderForMemoryEnv(nn.Module): - - def __init__( - self, - image_shape=(3, 5, 5), - embedding_size=100, - channels=[16, 32, 64], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], - activation: nn.Module = nn.GELU(approximate='tanh'), - normalize_pixel=False, - group_size: int = 8, - **kwargs, - ): - """ - Overview: - Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. - Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ - Use the inplace operation to speed up. - - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. 
- - group_size (:obj:`int`): The dimension for simplicial normalization - """ - super(LatentEncoderForMemoryEnv, self).__init__() - self.shape = image_shape - self.channels = [image_shape[0]] + list(channels) - - layers = [] - for i in range(len(self.channels) - 1): - layers.append( - nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2 # keep the same size of feature map - ) - ) - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) - - layers.append(nn.AdaptiveAvgPool2d(1)) - - self.cnn = nn.Sequential(*layers) - self.linear = nn.Sequential( - nn.Linear(self.channels[-1], embedding_size, bias=False), - ) - init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') - - self.normalize_pixel = normalize_pixel - self.sim_norm = SimNorm(simnorm_dim=group_size) - - def forward(self, image): - if self.normalize_pixel: - image = image / 255.0 - x = self.cnn(image.float()) # (B, C, 1, 1) - x = torch.flatten(x, start_dim=1) # (B, C) - x = self.linear(x) # (B, embedding_size) - x = self.sim_norm(x) - return x - - -class LatentDecoderForMemoryEnv(nn.Module): - - def __init__( - self, - image_shape=(3, 5, 5), - embedding_size=256, - channels=[64, 32, 16], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], - activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), - **kwargs, - ): - """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. - Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ - Use the inplace operation to speed up. - """ - super(LatentDecoderForMemoryEnv, self).__init__() - self.shape = image_shape - self.channels = list(channels) + [image_shape[0]] - - self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) - - layers = [] - for i in range(len(self.channels) - 1): - layers.append( - nn.ConvTranspose2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 - ) - ) - if i < len(self.channels) - 2: - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) - else: - layers.append(nn.Sigmoid()) - - self.deconv = nn.Sequential(*layers) - - def forward(self, embedding): - x = self.linear(embedding) - x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) - x = self.deconv(x) # (B, C, H, W) - return x - - -class VectorDecoderForMemoryEnv(nn.Module): - - def __init__( - self, - embedding_dim: int, - output_shape: SequenceType, - hidden_channels: int = 64, - layer_num: int = 2, - activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO - norm_type: Optional[str] = 'BN', - ) -> torch.Tensor: - """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. - Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. 
- - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - """ - super().__init__() - self.fc_representation = MLP( - in_channels=embedding_dim, - hidden_channels=hidden_channels, - out_channels=output_shape, - layer_num=layer_num, - activation=activation, - norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=True, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. - """ - x = self.fc_representation(x) - return x - - -class PredictionNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType, - action_space_size: int, - num_res_blocks: int, - num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, - last_linear_layer_init_zero: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - ) -> None: - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. 
- - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - """ - super(PredictionNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - - self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) - self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - - if observation_shape[1] == 96: - latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16) - elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, *latent_shape], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - - self.activation = activation - - self.fc_value = MLP_V2( - in_channels=self.flatten_input_size_for_value_head, - hidden_channels=value_head_hidden_channels, - out_channels=output_support_size, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy = MLP_V2( - in_channels=self.flatten_input_size_for_policy_head, - hidden_channels=policy_head_hidden_channels, - out_channels=action_space_size, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
- """ - for res_block in self.resblocks: - latent_state = res_block(latent_state) - - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) - - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) - - value = value.reshape(-1, self.flatten_input_size_for_value_head) - policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - - value = self.fc_value(value) - policy = self.fc_policy(policy) - return policy, value - - -class PredictionNetworkMLP(nn.Module): - - def __init__( - self, - action_space_size, - num_channels, - common_layer_num: int = 2, - value_head_hidden_channels: SequenceType = [32], - policy_head_hidden_channels: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - ): - """ - Overview: - The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), - which is used to predict value and policy by the given latent state. - Arguments: - - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ - space, it is the number of discrete actions. - - num_channels (:obj:`int`): The channels of latent states. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - """ - super().__init__() - self.num_channels = num_channels - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - # ******* value and policy head ****** - self.fc_value_head = MLP_V2( - in_channels=self.num_channels, - hidden_channels=value_head_hidden_channels, - out_channels=output_support_size, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy_head = MLP_V2( - in_channels=self.num_channels, - hidden_channels=policy_head_hidden_channels, - out_channels=action_space_size, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor): - """ - Overview: - Forward computation of the prediction network. 
- Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). - """ - x_prediction_common = self.fc_prediction_common(latent_state) - - value = self.fc_value_head(x_prediction_common) - policy = self.fc_policy_head(x_prediction_common) - return policy, value - - -class PredictionHiddenNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType, - action_space_size: int, - num_res_blocks: int, - num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, - last_linear_layer_init_zero: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - gru_hidden_size: int = 512, - ) -> None: - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. - - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
- """ - super(PredictionHiddenNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.observation_shape = observation_shape - self.gru_hidden_size = gru_hidden_size - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - - self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) - self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), - math.ceil(observation_shape[-1] / 16)], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - - self.activation = activation - - self.fc_value = MLP( - in_channels=self.flatten_input_size_for_value_head + self.gru_hidden_size, - hidden_channels=value_head_hidden_channels[0], - out_channels=output_support_size, - layer_num=len(value_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy = MLP( - in_channels=self.flatten_input_size_for_policy_head + self.gru_hidden_size, - hidden_channels=policy_head_hidden_channels[0], - out_channels=action_space_size, - layer_num=len(policy_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
- """ - for res_block in self.resblocks: - latent_state = res_block(latent_state) - - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) - - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) - - latent_state_value = value.reshape(-1, self.flatten_input_size_for_value_head) - latent_state_policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - - # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) - latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) - latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) - - value = self.fc_value(latent_history_value) - policy = self.fc_policy(latent_history_policy) - return policy, value \ No newline at end of file diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 0e050502d..e9610a69c 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -6,7 +6,8 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT @@ -18,12 +19,11 @@ # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @MODEL_REGISTRY.register('UniZeroMTModel') class UniZeroMTModel(nn.Module): - #@profile def __init__( self, observation_shape: SequenceType = (4, 64, 64), - action_space_size: int = 6, + action_space_size: Union[int, list] = 0, num_res_blocks: int = 1, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh'), @@ -45,7 +45,7 @@ def __init__( - and heads, which generate the logits for observations, rewards, policy, and value. Arguments: - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - action_space_size: (:obj:`[int, list]`): Action space size. For discrete or fixed action spaces, this is usually an integer. For multi-task environments where the action spaces are different, this is a list. - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. - num_channels (:obj:`int`): The channels of hidden states in representation network. 
- activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ @@ -75,7 +75,6 @@ def __init__( self.action_space_size = action_space_size # for multi-task - self.action_space_size = 18 self.task_num = task_num self.activation = activation @@ -239,7 +238,172 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') print('==' * 20) + elif world_model_cfg.obs_type == 'text': + self.representation_network = nn.ModuleList() + for task_id in range(1): # TODO: one share encoder + self.representation_network.append( + HFLanguageRepresentationNetwork( + model_path=kwargs['encoder_url'], + embedding_size=world_model_cfg.embed_dim, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + ) + ) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + self._log_model_parameters(world_model_cfg.obs_type) + + def _log_model_parameters(self, obs_type: str) -> None: + """ + Overview: + Logs detailed parameter counts for all model components with a comprehensive breakdown. + Includes encoder, transformer, prediction heads, and other components. + Arguments: + - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory'). + """ + from ding.utils import get_rank + + # Only print from rank 0 to avoid duplicate logs in DDP + if get_rank() != 0: + return + + print('=' * 80) + print('MODEL PARAMETER STATISTICS'.center(80)) + print('=' * 80) + + # --- Total Model Parameters --- + total_params = sum(p.numel() for p in self.parameters()) + total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') + print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') + + # --- World Model Components --- + print(f'\n{"-" * 80}') + print(f'{"WORLD MODEL BREAKDOWN":<40}') + print(f'{"-" * 80}') + + wm_params = sum(p.numel() for p in self.world_model.parameters()) + wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) + print(f'{"World Model Total":<40} {wm_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') + + # --- Encoder --- + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + print(f'\n{"1. 
ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') + + # --- Transformer Backbone --- + transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) + transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) + print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') + + # --- Prediction Heads (Detailed Breakdown) --- + print(f'\n{"3. PREDICTION HEADS":<40}') + + # Access head_dict from world_model + if hasattr(self.world_model, 'head_dict'): + head_dict = self.world_model.head_dict + + # Calculate total heads parameters + total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) + total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) + print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') + + # Breakdown by head type + head_names_map = { + 'head_policy_multi_task': 'Policy Head', + 'head_value_multi_task': 'Value Head', + 'head_rewards_multi_task': 'Reward Head', + 'head_observations_multi_task': 'Next Latent (Obs) Head' + } + + print(f'\n{" Breakdown by Head Type:":<40}') + for head_key, head_name in head_names_map.items(): + if head_key in head_dict: + head_module = head_dict[head_key] + head_params = sum(p.numel() for p in head_module.parameters()) + head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) + + # Count number of task-specific heads (for ModuleList) + if isinstance(head_module, nn.ModuleList): + num_heads = len(head_module) + params_per_head = head_params // num_heads if num_heads > 0 else 0 + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') + else: + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ Shared across tasks":<38}') + + # --- Positional & Task Embeddings --- + print(f'\n{"4. 
EMBEDDINGS":<40}') + + if hasattr(self.world_model, 'pos_emb'): + pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) + pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) + print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') + if pos_emb_trainable == 0: + print(f'{" └─ (Frozen)":<40}') + + if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: + task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) + task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) + print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') + + if hasattr(self.world_model, 'act_embedding_table'): + act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) + act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) + print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') + + # --- Decoder (if applicable) --- + if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: + print(f'\n{"5. DECODER":<40}') + decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) + decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) + print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') + + if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'): + lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') + + # Calculate world model params excluding decoder and LPIPS + params_without_decoder = wm_params - decoder_params - lpips_params + print(f'\n{" World Model (exc. 
Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') + + # --- Summary Table --- + print(f'\n{"=" * 80}') + print(f'{"SUMMARY":<40}') + print(f'{"=" * 80}') + print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') + print(f'{"-" * 80}') + + components = [ + ("Encoder", encoder_params, encoder_trainable), + ("Transformer", transformer_params, transformer_trainable), + ] + + if hasattr(self.world_model, 'head_dict'): + components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) + + for name, total, trainable in components: + pct = 100 * total / total_params if total_params > 0 else 0 + print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') + print(f'{"=" * 80}') + print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') + print(f'{"=" * 80}\n') + #@profile def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: """ diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index 8ee8115ee..c6e340e5b 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -8,11 +8,6 @@ from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear -# _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") - -# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/moe.py -# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer_layers.py#L149 -# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 class MultiplicationFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -24,6 +19,14 @@ def __init__(self, config): self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + if config.moe_use_lora: + self.w1 = _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False), config, "feed_forward") + self.w2 = _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False), config, "feed_forward") + self.w3 = _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False), config, "feed_forward") + else: + self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore @@ -33,34 +36,37 @@ class MoeArgs(Serializable): num_experts: int num_experts_per_tok: int - class MoELayer(nn.Module): """ - Mixture-of-Experts (MoE) 层的实现,参考了如下的设计: - - - 根据输入 x 的形状先展平为二维张量([batch_size, dim]) - - 使用门控网络(gate)为每个 token 计算各专家的 logits,并选出前 k 个专家(k = num_experts_per_tok) - - 对于选中的每个专家,对应 token 调用该专家的前向传播,将专家计算结果乘以门控权重后累积 - - 可选支持共享专家分支 shared_expert 对所有 token 做统一处理 - - 最后恢复输入的原始形状返回 - + Mixture-of-Experts (MoE) layer. + Design: + - Flatten input to 2D [N, dim] where N = batch_size * seq_len. + - A gating module produces logits over experts for each token. 
+ - Select top-k experts per token (k = num_experts_per_tok), softmax the + selected logits to get normalized weights, and combine expert outputs + weighted by those gate weights. + - Optionally add a shared expert branch applied to all tokens. + - Finally, restore the original shape. Attributes: - dim (int): 输入特征的维度 - num_experts (int): 专家数量 - num_experts_per_tok (int): 每个 token 激活的专家个数 - gate (nn.Module): 门控模块,用于生成专家路由 logits - experts (nn.ModuleList): 专家模块列表 - shared_expert (nn.Module or None): 用于所有 token 的共享专家分支(如果配置了 n_shared_experts) + dim (int): Input feature dimension. + num_experts (int): Number of experts. + num_experts_per_tok (int): Top-k experts activated per token. + gate (nn.Module): Gating module that outputs logits of shape [N, num_experts]. + experts (nn.ModuleList): List of expert modules. + shared_expert (nn.Module or None): Optional shared expert used for all tokens + when `config.n_shared_experts > 0`. """ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1): super().__init__() self.dim = config.embed_dim self.num_experts = len(experts) + self.dim = config.embed_dim + self.num_experts = len(experts) self.num_experts_per_tok = num_experts_per_tok self.gate = gate self.experts = nn.ModuleList(experts) - # 如果配置中指定了共享专家数量,则构建共享专家分支 + # Optional shared expert branch if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: self.shared_expert = nn.Sequential( nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), @@ -71,41 +77,36 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert self.shared_expert = None def forward(self, x: torch.Tensor) -> torch.Tensor: - # 保存原始形状后将 x reshape 为二维张量: [batch_size * seq_len, dim] + # Save shape and flatten to [N, dim] original_shape = x.size() x = x.view(-1, self.dim) - # 计算门控 logits,shape 为 [N, num_experts],N 为 token 数量 + # Gate logits: [N, num_experts] gate_logits = self.gate(x) - # 选取每个 token 得分最高的 k 个专家 + # Top-k experts per token weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) - # 对选中的 logits 做 softmax,获得归一化权重 weights = F.softmax(weights, dim=1).to(x.dtype) - - # 初始化存放专家计算输出的张量 + # Accumulate expert outputs expert_output = torch.zeros_like(x) - - # 遍历所有专家,对被该专家选择的 token 分支进行计算 + # For each expert, gather the tokens routed to it for expert_id in range(self.num_experts): - # 通过 where 找到 indices 中等于当前 expert_id 的 token 索引 batch_idx, expert_tok_idx = torch.where(indices == expert_id) if batch_idx.numel() == 0: continue - token_subset = x[batch_idx] # 选中的 token,形状 [num_tokens, dim] - # 调用当前专家模块计算输出 + token_subset = x[batch_idx] # [num_tokens_routed, dim] output_expert = self.experts[expert_id](token_subset) - # 获取对应 token 的权重,注意 weights 的形状为 [N, num_experts_per_tok] + # Get the corresponding token weights; note that `weights` has shape [N, num_experts_per_tok] token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) expert_output[batch_idx] += output_expert * token_weights - # 如果使用了共享专家分支,则加上其输出 + # If a shared expert branch is configured, add its output if self.shared_expert is not None: shared_output = self.shared_expert(x) output = expert_output + shared_output else: output = expert_output - # 恢复原始形状后返回结果 + # Restore the original shape and return the result return output.view(original_shape) class MoELayerOptimized(nn.Module): diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index 1e87efb17..ffe60b5f2 100644 --- 
a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -61,6 +61,7 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). + - task_id (int, optional): Task ID for multitask settings. Defaults to None. Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). @@ -81,14 +82,21 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten if len(shape) == 2: # Case when input is 2D (B, E) # obs_embeddings = self.encoder[task_id](x) - obs_embeddings = self.encoder(x, task_id) # TODO: - + if self.obs_type == 'text': + try: + obs_embeddings = self.encoder[0](x) # 目前共用一个encoder + except: + obs_embeddings = self.encoder(x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') + elif len(shape) == 3: # Case when input is 3D (B, T, E) x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - # obs_embeddings = self.encoder[task_id](x) - obs_embeddings = self.encoder(x,task_id) # TODO: + if self.obs_type == 'text': + try: + obs_embeddings = self.encoder[0](x) # 目前共用一个encoder + except: + obs_embeddings = self.encoder(x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 3edf4f1c9..a55c240e2 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -219,7 +219,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ############################################## # 修改 _maybe_wrap_linear 辅助函数 ############################################## - def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Module: """ 辅助函数:当满足以下条件时,将传入的 nn.Linear 层替换为 @@ -545,23 +544,11 @@ def __init__(self, config: TransformerConfig) -> None: gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), num_experts_per_tok=config.num_experts_per_tok, ) - + # If a shared expert branch is configured, add its output print("="*20) print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') print("="*20) elif config.multiplication_moe_in_transformer: - # TODO: deepseek-v3 - # from .moe import MoeConfig,MoELayer - # moe_cfg = MoeConfig( - # embed_dim=config.embed_dim, - # num_experts_total=config.num_experts_of_moe_in_transformer, - # num_experts_per_tok=1, - # ) - # self.feed_forward = MoELayer(moe_cfg) - # print("=" * 20) - # print(f"Use MoE feed_forward, num_experts={moe_cfg.num_experts_total}") - # print("=" * 20) - from .moe import MoELayer, MultiplicationFeedForward # Create multiple FeedForward instances for multiplication-based MoE self.experts = nn.ModuleList([ diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index ecb583504..62d426183 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -3,11 +3,11 @@ from typing import Any, Tuple from typing import Optional from typing import Union, Dict - import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from .kv_caching import KeysValues from lzero.model.common import SimNorm from lzero.model.unizero_world_models.world_model import WorldModel from lzero.model.utils import cal_dormant_ratio, 
compute_average_weight_magnitude, cal_effective_rank @@ -19,7 +19,7 @@ logging.getLogger().setLevel(logging.DEBUG) from ding.utils import get_rank -import torch.distributed as dist +import torch.distributed as dist from sklearn.manifold import TSNE import os import numpy as np @@ -29,8 +29,10 @@ import torch import math +# TODO xjy: 继承容易出问题 +# class WorldModelMT(WorldModel): -class WorldModelMT(WorldModel): +class WorldModelMT(nn.Module): """ Overview: The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), @@ -55,7 +57,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: - "concat_task_embed": Concatenates task embeddings with observation embeddings. - "register_task_embed": Uses task embeddings as additional input tokens. """ - super().__init__(config, tokenizer) + super().__init__() self.tokenizer = tokenizer self.config = config @@ -74,6 +76,12 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.share_head = config.share_head # 新增参数 + # Task embedding setup + self.use_task_embed = config.use_task_embed + self.task_embed_option = self.config.task_embed_option + self.task_num = config.task_num + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 if self.config.device == 'cpu': self.device = torch.device('cpu') @@ -81,10 +89,14 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Move all modules to the specified device print(f"self.device: {self.device}") - # Position embedding - self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) - print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self._initialize_config_parameters() + self._initialize_patterns() + + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + self.hidden_size = config.embed_dim // config.num_heads + + # Position embedding if self.task_embed_option == "register_task_embed": # 由于 "register_task_embed"设定下的位置编码没有矫正 # 使用 nn.Embedding,但初始化为全零并禁止学习 @@ -92,15 +104,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: nn.init.constant_(self.pos_emb.weight, 0.0) # 初始化全零 self.pos_emb.weight.requires_grad = False # 禁止更新 - # Task embedding setup - self.use_task_embed = config.use_task_embed - self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings - self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 - self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 - - self.precompute_pos_emb_diff_kv() - - self.sim_norm = SimNorm(simnorm_dim=self.group_size) if self.task_embed_option == "concat_task_embed": # TODO:目前在 "concat_task_embed"下面,self.pos_emb需要设置为固定的0 self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TODO: TDMPC2:max_norm=1性能更好 @@ -118,8 +121,12 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.obs_act_embed_dim = config.embed_dim self.register_token_num = 0 - self.transformer = Transformer(self.config, self.task_emb) + self.to(self.device) + + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.precompute_pos_emb_diff_kv() self.analysis_dormant_ratio_interval 
= self.config.get('analysis_dormant_ratio_interval', 100) # 每 100 次调用做一次分析 self._analysis_step_counter = 0 @@ -136,26 +143,53 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # 遍历 env_id_list,提取短名称 for env_id in self.config.env_id_list: # 提取 'NoFrameskip-v4' 之前的部分作为短名称 - short_name = env_id.replace('NoFrameskip-v4', '') + short_name = env_id.replace('.z5', '') self.env_short_names[env_id] = short_name - # 映射环境 ID 到简写名称 - # self.env_short_names = { - # 'PongNoFrameskip-v4': 'Pong', - # 'MsPacmanNoFrameskip-v4': 'MsPacman', - # 'SeaquestNoFrameskip-v4': 'Seaquest', - # 'BoxingNoFrameskip-v4': 'Boxing', - # 'AlienNoFrameskip-v4': 'Alien', - # 'ChopperCommandNoFrameskip-v4': 'Chopper', - # 'HeroNoFrameskip-v4': 'Hero', - # 'RoadRunnerNoFrameskip-v4': 'RoadRunner' - # } - # 颜色映射,确保每个任务有固定的颜色 self.num_tasks = len(self.env_id_list) # 生成足够多的颜色 self.colors = self._generate_colors(len(self.env_id_list)) - + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + if isinstance(config.action_space_size, list): + assert self.task_num == len(config.action_space_size) + self.act_embedding_table = nn.ModuleList( + nn.Sequential( + nn.Linear(action_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size) + ) for action_size in config.action_space_size + ) + elif isinstance(config.action_space_size, int): + self.act_embedding_table = nn.Sequential( + nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size)) + else: + raise ValueError(f"Unsupported action space size type: {type(config.action_space_size)}") + + else: + # for discrete action space + if isinstance(config.action_space_size, list): + assert self.task_num == len(config.action_space_size) + self.act_embedding_table = nn.ModuleList( + nn.Embedding(action_size, config.embed_dim, device=self.device) + for action_size in config.action_space_size + ) + print(f"self.act_embedding_table: {self.act_embedding_table}") + + elif isinstance(config.action_space_size, int): + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + else: + raise ValueError(f"Unsupported action space size type: {type(config.action_space_size)}") + + + print(f'='*20) + print(f"self.obs_act_embed_dim:{self.obs_act_embed_dim}") + print(f'='*20) + self.head_policy_multi_task = nn.ModuleList() self.head_value_multi_task = nn.ModuleList() self.head_rewards_multi_task = nn.ModuleList() @@ -166,42 +200,6 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.use_moe_head = config.use_moe_head self.use_softmoe_head = config.use_softmoe_head - - self.to(self.device) - - # Initialize configuration parameters - self._initialize_config_parameters() - - # Initialize patterns for block masks - self._initialize_patterns() - - self.hidden_size = config.embed_dim // config.num_heads - - - # Initialize action embedding table - if self.continuous_action_space: - # TODO: check the effect of SimNorm - # self.act_embedding_table = nn.Sequential( - # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), - # SimNorm(simnorm_dim=self.group_size)) - # print(f'config.action_space_size_list:{config.action_space_size_list}') - self.act_embedding_table = nn.ModuleList([ - nn.Sequential( - nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, 
device=self.device, bias=False), - SimNorm(simnorm_dim=self.group_size) - ) - for task_id in range(self.task_num) - ]) - else: - # for discrete action space - self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) - print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - - print(f'='*20) - print(f"self.obs_act_embed_dim:{self.obs_act_embed_dim}") - print(f'='*20) - - # if self.num_experts_in_moe_head == -1: assert self.num_experts_in_moe_head > 0 if self.use_normal_head: @@ -211,13 +209,20 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: print('We use normal head') # TODO: Normal Head for task_id in range(self.task_num): + if isinstance(self.action_space_size, list): + action_size = self.action_space_size[task_id] + elif isinstance(self.action_space_size, int): + action_size = self.action_space_size + else: + raise ValueError('Unsupported action space size type: {}'.format(type(self.action_space_size))) + if self.continuous_action_space: # TODO self.sigma_type = self.config.sigma_type self.bound_type = self.config.bound_type - self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) # TODO + self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, action_size) # TODO else: - self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_policy = self._create_head(self.value_policy_tokens_pattern, action_size) if not self.share_head or task_id == 0: self.head_policy_multi_task.append(self.head_policy) @@ -278,8 +283,22 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: print("="*20) print(f"self.head_dict:{self.head_dict}") + + if isinstance(self.tokenizer.encoder, torch.nn.ModuleList): + skip_modules = set() + for m in self.tokenizer.encoder: + skip_modules.update(m.pretrained_model.modules()) + else: + skip_modules = set(self.tokenizer.encoder.pretrained_model.modules()) + + def custom_init(module): + if module in skip_modules: + return + init_weights(module, norm_type=self.config.norm_type) + # Apply weight initialization, the order is important - self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self.apply(custom_init) + # self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) self._initialize_last_layer() # Cache structures @@ -420,6 +439,7 @@ def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer block_mask=block_mask, head_module=nn.Sequential(*modules) ) + def get_moe(self, name): """Get or create a MoE instance""" from .moe import MoELayer, MultiplicationFeedForward @@ -710,13 +730,6 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu if self.task_embed_option == "add_task_embed": obs_embeddings = obs_embeddings + self.task_embeddings elif self.task_embed_option == "concat_task_embed": - - # print(f'=='*20) - # print(f"is_init_infer:{is_init_infer}") - # print(f'obs_embeddings.shape:{obs_embeddings.shape}') - # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') - # print(f'=='*20) - # if is_init_infer: # # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了 # # Expand task embeddings to match the sequence shape @@ -753,10 +766,15 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu if len(act_tokens.shape) == 3: 
act_tokens = act_tokens.squeeze(1) num_steps = act_tokens.size(1) - if self.task_num >= 1 and self.continuous_action_space: - act_embeddings = self.act_embedding_table[task_id](act_tokens) + + if isinstance(self.act_embedding_table, nn.ModuleList): + if task_id >= len(self.act_embedding_table): + act_embeddings = self.act_embedding_table[0](act_tokens) + else: + act_embeddings = self.act_embedding_table[task_id](act_tokens) else: act_embeddings = self.act_embedding_table(act_tokens) + if self.task_embed_option == "add_task_embed": # TODO: do not add task_embeddings to action tokens; doing so introduces ambiguity and interferes with learning @@ -782,7 +800,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu if self.continuous_action_space: sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) else: - sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) # Pass sequences through transformer @@ -938,7 +956,15 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta -1) num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) - act_embeddings = self.act_embedding_table(act_tokens) + + if isinstance(self.act_embedding_table, nn.ModuleList): + if task_id >= len(self.act_embedding_table): + act_embeddings = self.act_embedding_table[0](act_tokens) + else: + act_embeddings = self.act_embedding_table[task_id](act_tokens) + else: + act_embeddings = self.act_embedding_table(act_tokens) + B, L, K, E = obs_embeddings.size() if self.task_embed_option == "concat_task_embed": @@ -972,39 +998,6 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, ta return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps - - #@profile - # def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): - # """ - # Process combined observation embeddings and action tokens. - - # Arguments: - # - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. - # - prev_steps (:obj:`torch.Tensor`): Previous steps. - # Returns: - # - torch.Tensor: Combined observation and action embeddings with position information added.
- # """ - # obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # if len(obs_embeddings.shape) == 3: - # obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, - # -1) - - # num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) - # # act_embeddings = self.act_embedding_table[task_id](act_tokens) - # act_embeddings = self.act_embedding_table(act_tokens) - - # B, L, K, E = obs_embeddings.size() - # obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) - - # for i in range(L): - # # obs = obs_embeddings[:, i, :, :] - # obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings - # act = act_embeddings[:, i, 0, :].unsqueeze(1) - # obs_act = torch.cat([obs, act], dim=1) - # obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps - #@profile def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=0): """ @@ -1218,12 +1211,6 @@ def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) - # if self.reanalyze_phase: - # # TODO - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=False, task_id=task_id) - # else: - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=True, task_id=task_id) - # select the last timestep for each sample last_steps_value = outputs_wm.logits_value[:, -1:, :] outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) @@ -1284,23 +1271,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, if not self.continuous_action_space: token = action.reshape(-1, 1) else: - token = action.reshape(-1, self.config.action_space_size_list[task_id]) - - # ======= Print statistics for debugging ============= - # min_size = min(self.keys_values_wm_size_list) - # if min_size >= self.config.max_tokens - 5: - # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) - # if min_size >= self.config.max_tokens - 7: - # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) - # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: - # self.hit_freq = self.hit_count / self.total_query_count - # print('total_query_count:', self.total_query_count) - # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count - # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) - # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) - # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count - # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) - # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + token = action.reshape(-1, self.config.action_space_size[task_id]) # Trim and pad kv_cache self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) @@ -1616,16 +1587,7 @@ def 
retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id ) - # if self.reanalyze_phase: - # self.forward( - # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, - # past_keys_values=self.keys_values_wm_single_env, is_init_infer=False, task_id=task_id - # ) - # else: - # self.forward( - # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, - # past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id - # ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) self.keys_values_wm_size_list.append(1) @@ -1826,7 +1788,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar if self.do_analysis: # Calculate dormant ratio of the encoder shape = batch['observations'].shape # (..., C, H, W) - inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + inputs = batch['observations'].contiguous().view(-1, shape[-1]) # (b, s, h) -> (b * s, h) + # inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) if self.continuous_action_space: encoder_index = task_id else: @@ -1852,10 +1815,12 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # representation 层在 model.named_modules() 的名称为 "representation" # print(f"self.tokenizer.encoder:{self.tokenizer.encoder}") - e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear") + # e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear") + e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="embed_proj_head") # print("Effective Rank of encoder_last_linear:", e_rank_last_linear) try: - e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm") + # e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm") + e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="norm") except Exception as e: e_rank_sim_norm = torch.tensor(0.) 
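Note on the analysis hunk above: `cal_effective_rank` now targets the encoder submodules named "embed_proj_head" and "norm" instead of "last_linear" / "final_norm". The snippet below is only a minimal illustration of the underlying technique (capture a named submodule's output with a forward hook and compute the Roy–Vetterli effective rank, i.e. the exponential of the entropy of the normalized singular values); it is not the LightZero implementation, and the helper name `effective_rank_of_layer` as well as the example layer name are assumptions for illustration.

import torch
import torch.nn as nn


def effective_rank_of_layer(model: nn.Module, inputs: torch.Tensor, layer_name: str) -> float:
    # Capture the output of the named submodule (e.g. "embed_proj_head") via a forward hook.
    feats = {}

    def hook(_module, _inputs, output):
        # Flatten all leading dimensions so the activations form an (N, D) matrix.
        feats["h"] = output.detach().reshape(-1, output.shape[-1]).float()

    handle = dict(model.named_modules())[layer_name].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        handle.remove()

    # Effective rank = exp(entropy of the normalized singular-value distribution).
    s = torch.linalg.svdvals(feats["h"])
    p = s / s.sum().clamp_min(1e-12)
    entropy = -(p * (p + 1e-12).log()).sum()
    return float(entropy.exp())

# Hypothetical usage: effective_rank_of_layer(encoder, obs_batch, "embed_proj_head")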
@@ -1936,6 +1901,14 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dtype=batch['observations'].dtype) perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + # # Reconstruct observations from latent state representations + # reconstructed_logits = self.tokenizer.decode_to_language_logits(obs_embeddings, batch['observations']) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.lm_reconstruction_loss(batch['observations'], reconstructed_logits, ignore_index=0) + latent_recon_loss = self.latent_recon_loss # Action tokens if self.continuous_action_space: @@ -1971,19 +1944,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # e_rank_last_linear = None # e_rank_sim_norm = None - # ========== for visualization ========== - # Uncomment the lines below for visualization - # predict_policy = outputs.logits_policy - # predict_policy = F.softmax(outputs.logits_policy, dim=-1) - # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) - # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) - # import pdb; pdb.set_trace() - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') - - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') - # ========== for visualization ========== - # For training stability, use target_tokenizer to compute the true next latent state representations with torch.no_grad(): target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) @@ -2031,15 +1991,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) - # ========== for debugging ========== - # assert not torch.isnan(logits_reshaped).any(), "logits_reshaped contains NaN values" - # assert not torch.isnan(labels_reshaped).any(), "labels_reshaped contains NaN values" - # print('loss_obs:', loss_obs.mean()) - # for name, param in self.tokenizer.encoder.named_parameters(): - # print('name, param.mean(), param.std():', name, param.mean(), param.std()) - # logits_grad = torch.autograd.grad(loss_obs.mean(), logits_observations, retain_graph=True)[0] - # print(f"logits_grad (min, max, mean): {logits_grad.min()}, {logits_grad.max()}, {logits_grad.mean()}") - # Apply mask to loss_obs mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) loss_obs = (loss_obs * mask_padding_expanded) @@ -2047,7 +1998,8 
@@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Compute labels for policy and value labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], batch['target_policy'], - batch['mask_padding']) + batch['mask_padding'], + task_id=task_id) # Compute losses for rewards, policy, and value loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') @@ -2242,7 +2194,7 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask_padding: torch.BoolTensor, task_id: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. """ mask_fill = torch.logical_not(mask_padding) @@ -2253,11 +2205,18 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta # Fill the masked areas of value mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) labels_value = target_value.masked_fill(mask_fill_value, -100) + + if isinstance(self.action_space_size, list): + action_size = self.action_space_size[task_id] + elif isinstance(self.action_space_size, int): + action_size = self.action_space_size + else: + raise ValueError('Unsupported action space size type: {}'.format(type(self.action_space_size))) if self.continuous_action_space: return None, labels_value.reshape(-1, self.support_size) else: - return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + return labels_policy.reshape(-1, action_size), labels_value.reshape(-1, self.support_size) #@profile def clear_caches(self): @@ -2273,3 +2232,113 @@ def clear_caches(self): def __repr__(self) -> str: return "transformer-based latent world_model of UniZero" + + def custom_copy_kv_cache_to_shared_init_envs(self, src_kv: KeysValues, env_id) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for a specific environment in the init_infer stage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + - env_id (:obj:`int`): The identifier of the environment for which the cache is being copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] is None: + self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_init_infer[env_id][self.shared_pool_index_init_envs[env_id]] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index_init_envs[env_id] + self.shared_pool_index_init_envs[env_id] = (self.shared_pool_index_init_envs[env_id] + 1) % self.shared_pool_size_init + + return index + + def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for world model usage. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. + """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_wm[self.shared_pool_index_wm] is None: + # import ipdb; ipdb.set_trace() + self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + # try: + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + # except Exception as e: + # import ipdb; ipdb.set_trace() + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + self.shared_pool_index_wm = (self.shared_pool_index_wm + 1) % self.shared_pool_size_wm + + return dst_kv + + def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: + """ + Overview: + Efficiently copies the contents of a KeysValues object to the shared pool for recurrent inference. + Arguments: + - src_kv (:obj:`KeysValues`): The source KeysValues object from which data is copied. + Returns: + - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. 
+ """ + src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape + + if self.shared_pool_recur_infer[self.shared_pool_index] is None: + self.shared_pool_recur_infer[self.shared_pool_index] = KeysValues( + src_kv_shape[0], # Number of elements (n) + src_kv_shape[1], # Number of attention heads (num_heads) + src_kv_shape[2], # Maximum number of tokens (max_tokens) + src_kv_shape[3] * src_kv_shape[1], # Embedding dimension (embed_dim) + len(src_kv), # Number of layers (num_layers) + src_kv._keys_values[0]._k_cache._cache.device, # Device where the cache is stored + ) + + dst_kv = self.shared_pool_recur_infer[self.shared_pool_index] + + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): + # Copy the key and value caches using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + index = self.shared_pool_index + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + + return index + diff --git a/lzero/model/utils.py b/lzero/model/utils.py index c849aedca..77be52646 100644 --- a/lzero/model/utils.py +++ b/lzero/model/utils.py @@ -199,7 +199,10 @@ def cal_dormant_ratio( if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"): # 如果传入的是self.tokenizer.encoder - parts["model"] = model + if hasattr(model, "pretrained_model"): + parts["pretrained_model"] = model.pretrained_model + else: + parts["model"] = model # 定义要捕获的目标模块类型 TODO: 增加更多模块 target_modules = (nn.Conv2d, nn.Linear) diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py index 3ba7b3fbb..40227184a 100644 --- a/lzero/policy/sampled_unizero_multitask.py +++ b/lzero/policy/sampled_unizero_multitask.py @@ -32,7 +32,7 @@ from ding.utils import set_pkg_seed, get_rank, get_world_size import sys -sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/') + from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect # from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 13ba63eb2..4585d3e24 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -17,9 +17,7 @@ from .utils import configure_optimizers_nanogpt import sys -sys.path.append('/cpfs04/user/puyuan/code/LibMTL') -# sys.path.append('/fs-computility/niuyazhe/puyuan/code/LibMTL') - +# TODO need to install the LibMTL package from the following link: https://github.com/puyuan1996/LibMTL from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect # from LibMTL.weighting.moco_generic import GenericMoCo, MoCoCfg # from LibMTL.weighting.moco_fast import FastMoCo, MoCoCfg @@ -68,7 +66,6 @@ def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): return task_loss_dict - class WrappedModel: def __init__(self, world_model): self.world_model = world_model @@ -428,7 +425,7 @@ def _init_learn(self) -> None: if self._cfg.cos_lr_scheduler: self.lr_scheduler = CosineAnnealingLR( - self._optimizer_world_model, T_max=int(2e5), eta_min=0, last_epoch=-1 + self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 ) # TODO elif self._cfg.piecewise_decay_lr_scheduler: # Example step scheduler, adjust milestones and gamma as needed @@ -645,7 +642,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], 
task_weights=None, ignore_gr # Convert to categorical distributions target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) target_value_categorical = phi_transform(self.value_support, transformed_target_value) - + # Prepare batch for a transformer-based world model batch_for_gpt = {} if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: @@ -772,7 +769,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr e_rank_last_linear_multi_task.append(e_rank_last_linear) e_rank_sim_norm_multi_task.append(e_rank_sim_norm) - # Core learn model update step self._optimizer_world_model.zero_grad() @@ -1081,7 +1077,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: if num_tasks is not None: for var in task_specific_vars: for task_idx in range(num_tasks): - # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") + print(f"learner policy Rank, self.task_id: {self.task_id+task_idx}") monitored_vars.append(f'{var}_task{self.task_id+task_idx}') else: # If num_tasks is not provided, we assume there's only one task and keep the original variable names diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 0299abf8f..15bc2da51 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,3 +1,4 @@ +import logging import time from collections import deque, namedtuple from typing import Optional, Any, List @@ -58,6 +59,7 @@ def __init__( - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ self.task_id = task_id + self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -65,6 +67,9 @@ def __init__( self._end_flag = False self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self._world_size = get_world_size() if self._rank == 0: if tb_logger is not None: @@ -82,7 +87,9 @@ def __init__( self._logger, _ = build_logger( path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) - self._tb_logger = None + # =========== TODO: for unizero_multitask ddp_v2 ======== + self._tb_logger = tb_logger + self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -122,7 +129,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._logger.debug( 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -256,13 +263,10 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] # NOTE: for unizero - beg_index = 0 - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_action_lst = game_segments[i].action_segment[beg_index:end_index] - + pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + # NOTE: for unizero - pad_child_visits_lst = game_segments[i].child_visit_segment[ - :self.policy_config.num_unroll_steps + self.policy_config.td_steps] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + 
self.policy_config.td_steps] # EfficientZero original repo bug: # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] @@ -310,7 +314,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm # put the game segment into the pool self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments + # reset last game_segments and last game_priorities for the next collection last_game_segments[i] = None last_game_priorities[i] = None @@ -378,7 +382,8 @@ def collect(self, GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] # stacked observation windows in reset stage for init game_segments @@ -388,6 +393,7 @@ def collect(self, [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], maxlen=self.policy_config.model.frame_stack_num ) + game_segments[env_id].reset(observation_window_stack[env_id]) dones = np.array([False for _ in range(env_nums)]) @@ -448,14 +454,13 @@ def collect(self, # ============================================================== # Key policy forward step # ============================================================== - # print(f'ready_env_id:{ready_env_id}') - # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) if self.task_id is None: # single task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) else: - # multi-task setting - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) + # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} @@ -580,9 +585,9 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'unizero_multitask', 'sampled_unizero_multitask']: - # TODO: only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: + # ============ only for UniZero now ============ + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) total_transitions += 1 @@ -623,7 +628,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -704,7 +710,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + 
config=self.policy_config, + task_id=self.task_id ) observation_window_stack[env_id] = deque( [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], @@ -727,7 +734,8 @@ def collect(self, visit_entropies_lst[env_id] = 0 # Env reset is done by env_manager automatically - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. ================ + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) ready_env_id.remove(env_id) @@ -744,16 +752,12 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) - - # reduce data when enables DDP - if self._world_size > 1: - # Before allreduce - self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') - # After allreduce - self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + # TODO: for multitask new ddp pipeline + # In the multi-task setting, allreduce is only needed when multiple processes handle the same task; a single process handling one or more tasks does not need it + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') self._total_envstep_count += collected_step self._total_episode_count += collected_episode @@ -770,8 +774,9 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): Current training iteration number for logging context. """ - if self._rank != 0: - return + # TODO: for multitask new ddp pipeline, since each process handles different tasks, each process needs to output its own logs + # if self._rank != 0: + # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) @@ -804,6 +809,7 @@ def _output_log(self, train_iter: int) -> None: if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) self._episode_info.clear() + print(f'collector output_log: rank {self._rank}, self.task_id: {self.task_id}') self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) for k, v in info.items(): if k in ['each_reward']: diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 6ef81f4d5..2012bc54a 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -59,6 +59,7 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode.
""" self.task_id = task_id @@ -72,7 +73,6 @@ def __init__( print(f'rank {self._rank}, self.task_id: {self.task_id}') - self._world_size = get_world_size() if self._rank == 0: if tb_logger is not None: @@ -128,7 +128,6 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy - self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) self._logger.debug( 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) @@ -367,11 +366,10 @@ def collect(self, collected_episode = 0 collected_step = 0 env_nums = self._env_num - + retry_waiting_time = 0.05 # initializations init_obs = self._env.ready_obs - retry_waiting_time = 0.05 while len(init_obs.keys()) != self._env_num: # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to # len(self._env.ready_obs), especially in tictactoe env. @@ -481,7 +479,7 @@ def collect(self, # ============================================================== # Key policy forward step # ============================================================== - # print(f'ready_env_id:{ready_env_id}') + if self.task_id is None: # single task setting policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) @@ -754,7 +752,7 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) - # TODO: for atari multitask new ddp pipeline + # TODO: for multitask new ddp pipeline # reduce data when enables DDP # if self._world_size > 1: # collected_step = allreduce_data(collected_step, 'sum') @@ -776,7 +774,7 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): Current training iteration number for logging context. 
""" - # TODO: for atari multitask new ddp pipeline + # TODO: for multitask new ddp pipeline # if self._rank != 0: # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: @@ -819,8 +817,7 @@ def _output_log(self, train_iter: int) -> None: if self.task_id is None: self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) else: - self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, - train_iter) + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) if k in ['total_envstep_count']: continue if self.task_id is None: diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py new file mode 100644 index 000000000..32d76fcb1 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -0,0 +1,204 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, norm_type, update_per_collect): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + use_moco=False, + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + norm_type=norm_type, + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, # TODO + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", # TODO: Modify as needed. 
+ env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + num_experts_in_moe_head=4, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + ), + ), + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + num_simulations=num_simulations, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, norm_type): + configs = [] + # ===== only for debug ===== + exp_name_prefix = f'data_lz/data_unizero_jericho_mt_20250501/jericho_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_steps_list = [v[1] for _, v in env_configurations.items()] + + for task_id, env_id in enumerate(env_id_list): + max_action_num, max_steps = env_configurations.get(env_id, (10, 50)) # Default values if env_id not found + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_num, action_space_size=action_space_size_list, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + norm_type=norm_type, update_per_collect=update_per_collect + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask + 
import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), + } + env_id_list = ['detective.z5', 'omniquest.z5', 'acorncourt.z5', 'zork1.z5'] + + # Model name or path - configurable according to the predefined model paths or names + model_name: str = 'BAAI/bge-base-en-v1.5' + replay_ratio = 0.1 + norm_type = 'BN' + + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(64*4) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + num_layers=2 + num_unroll_steps = 10 + infer_context_length = 4 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + for seed in [0]: + configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, + collector_env_num=collector_env_num, n_episode=n_episode, + evaluator_env_num=evaluator_env_num, num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, + seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, + total_batch_size=total_batch_size, num_layers=num_layers, + model_name=model_name, replay_ratio=replay_ratio, + norm_type=norm_type) + + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py new file mode 100644 index 000000000..960f3f4d1 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -0,0 +1,219 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, norm_type, update_per_collect): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + norm_type=norm_type, + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + 
final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, # TODO + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000 + ), + ), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + num_simulations=num_simulations, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, norm_type): + configs = [] + # ===== only for debug ===== + exp_name_prefix = f'data_scalezero/jericho_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_steps_list = [v[1] for _, v in env_configurations.items()] + + for task_id, env_id in enumerate(env_id_list): + max_action_num, max_steps = env_configurations.get(env_id, (10, 50)) + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_num, action_space_size=action_space_size_list, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + norm_type=norm_type, 
update_per_collect=update_per_collect
+        )
+        config.policy.task_id = task_id
+        config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}"
+        configs.append([task_id, [config, create_env_manager()]])
+    return configs
+
+def create_env_manager():
+    return EasyDict(dict(
+        env=dict(
+            type='jericho',
+            import_names=['zoo.jericho.envs.jericho_env'],
+        ),
+        env_manager=dict(type='base'),
+        policy=dict(
+            type='unizero_multitask',
+            import_names=['lzero.policy.unizero_multitask'],
+        ),
+    ))
+
+if __name__ == "__main__":
+    """
+    Overview:
+        This script should be executed with multiple GPUs.
+        Run the following command to launch the script:
+            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py
+    """
+
+    from lzero.entry import train_unizero_multitask_ddp
+    from ding.utils import DDPContext
+    import os
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    env_configurations = {
+        'detective.z5': (12, 100),
+        'omniquest.z5': (25, 100),
+        'acorncourt.z5': (45, 50),
+        'zork1.z5': (55, 500),
+    }
+    env_id_list = ['detective.z5', 'omniquest.z5', 'acorncourt.z5', 'zork1.z5']
+
+    # Model name or path of the pretrained text encoder / tokenizer.
+    model_name: str = 'BAAI/bge-base-en-v1.5'
+    replay_ratio = 0.1
+    norm_type = 'LN'
+
+    collector_env_num = 4
+    n_episode = 4
+    evaluator_env_num = 2
+    num_simulations = 50
+    max_env_step = int(5e5)
+    reanalyze_ratio = 0.0
+
+    total_batch_size = int(64 * 4)
+    batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))]
+
+    num_layers = 2
+    num_unroll_steps = 10
+    infer_context_length = 4
+    buffer_reanalyze_freq = 1 / 1000000
+    reanalyze_batch_size = 160
+    reanalyze_partition = 0.75
+
+    for seed in [0]:
+        configs = generate_configs(env_id_list=env_id_list, env_configurations=env_configurations,
+                                   collector_env_num=collector_env_num, n_episode=n_episode,
+                                   evaluator_env_num=evaluator_env_num, num_simulations=num_simulations,
+                                   reanalyze_ratio=reanalyze_ratio, batch_size=batch_size,
+                                   num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length,
+                                   seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq,
+                                   reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition,
+                                   total_batch_size=total_batch_size, num_layers=num_layers,
+                                   model_name=model_name, replay_ratio=replay_ratio,
+                                   norm_type=norm_type)
+
+        with DDPContext():
+            train_unizero_multitask_ddp(configs, seed=seed, max_env_step=max_env_step)
diff --git a/zoo/jericho/configs/jericho_unizero_multitask_segment_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_segment_ddp_config.py
new file mode 100644
index 000000000..504fc694a
--- /dev/null
+++ b/zoo/jericho/configs/jericho_unizero_multitask_segment_ddp_config.py
@@ -0,0 +1,220 @@
+from easydict import EasyDict
+
+def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode,
+                  num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length,
+                  buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size,
+                  num_layers, model_name, replay_ratio, norm_type, update_per_collect, num_segments):
+    return EasyDict(dict(
+        env=dict(
+            stop_value=int(1e6),
+            env_id=env_id,
+            observation_shape=512,
+            max_steps=max_steps,
+            max_action_num=max_action_num,
+            tokenizer_path=model_name,
+            max_seq_len=512,
+            game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}",
+            for_unizero=True,
+            collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + norm_type=norm_type, + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + share_head=False, # TODO + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", # TODO: Modify as needed. + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora? 
+                    lora_r=0,
+                    lora_alpha=1,
+                    lora_dropout=0.0,
+                ),
+            ),
+            use_task_exploitation_weight=False,  # TODO
+            task_complexity_weight=False,  # TODO
+            total_batch_size=total_batch_size,
+            allocated_batch_sizes=False,
+            use_priority=False,
+            print_task_priority_logs=False,
+            cuda=True,
+            model_path=None,
+            num_unroll_steps=num_unroll_steps,
+            update_per_collect=update_per_collect,
+            action_type="varied_action_space",
+            replay_ratio=replay_ratio,
+            batch_size=batch_size,
+            num_segments=num_segments,
+            reanalyze_ratio=reanalyze_ratio,
+            learning_rate=0.0001,
+            game_segment_length=50,
+            cos_lr_scheduler=False,
+            fixed_temperature_value=0.25,
+            manual_temperature_decay=False,
+            num_simulations=num_simulations,
+            n_episode=n_episode,
+            train_start_after_envsteps=int(0),  # TODO: ===== only for debug =====
+            replay_buffer_size=int(5e5),
+            eval_freq=int(5e3),
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            buffer_reanalyze_freq=buffer_reanalyze_freq,
+            reanalyze_batch_size=reanalyze_batch_size,
+            reanalyze_partition=reanalyze_partition,
+        ),
+    ))
+
+
+def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num,
+                     num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length,
+                     seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
+                     total_batch_size, num_layers, model_name, replay_ratio, norm_type, num_segments):
+    configs = []
+    # ===== only for debug =====
+    exp_name_prefix = f'data_lz/data_unizero_jericho_mt_20250512/jericho_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/'
+
+    action_space_size_list = [v[0] for _, v in env_configurations.items()]
+    max_steps_list = [v[1] for _, v in env_configurations.items()]
+
+    for task_id, env_id in enumerate(env_id_list):
+        max_action_num, max_steps = env_configurations.get(env_id, (10, 50))
+        update_per_collect = 40  # Fixed number of training updates after each data collection
+
+        config = create_config(
+            env_id=env_id, max_steps=max_steps, max_action_num=max_action_num, action_space_size=action_space_size_list,
+            collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode,
+            num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, batch_size=batch_size,
+            num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq,
+            reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size,
+            num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio,
+            norm_type=norm_type, update_per_collect=update_per_collect, num_segments=num_segments
+        )
+        config.policy.task_id = task_id
+        config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}"
+        configs.append([task_id, [config, create_env_manager()]])
+    return configs
+
+def create_env_manager():
+    return EasyDict(dict(
+        env=dict(
+            type='jericho',
+            import_names=['zoo.jericho.envs.jericho_env'],
+        ),
+        env_manager=dict(type='base'),
+        # env_manager=dict(type='subprocess'),  # The subprocess env manager is not supported by the Jericho environment.
+        policy=dict(
+            type='unizero_multitask',
+            import_names=['lzero.policy.unizero_multitask'],
+        ),
+    ))
+
+if __name__ == "__main__":
+    """
+    Overview:
+        This script should be executed with multiple GPUs.
+        Run the following command to launch the script:
+            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_multitask_segment_ddp_config.py
+    """
+
+    from lzero.entry import train_unizero_multitask_segment_ddp
+    from ding.utils import DDPContext
+    import os
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    env_configurations = {
+        'detective.z5': (12, 100),
+        'omniquest.z5': (25, 100),
+        'acorncourt.z5': (45, 50),
+        'zork1.z5': (55, 500),
+    }
+    env_id_list = ['detective.z5', 'omniquest.z5', 'acorncourt.z5', 'zork1.z5']
+
+    # Model name or path of the pretrained text encoder / tokenizer.
+    model_name: str = 'BAAI/bge-base-en-v1.5'
+    replay_ratio = 0.1
+    norm_type = 'BN'
+
+    collector_env_num = 4
+    n_episode = 4
+    num_segments = 4
+    evaluator_env_num = 2
+    num_simulations = 50
+    max_env_step = int(5e5)
+    reanalyze_ratio = 0.0
+
+    total_batch_size = int(128 * 4)
+    batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))]
+
+    num_layers = 2
+    num_unroll_steps = 10
+    infer_context_length = 4
+    buffer_reanalyze_freq = 1 / 1000000
+    reanalyze_batch_size = 160
+    reanalyze_partition = 0.75
+
+    for seed in [0]:
+        configs = generate_configs(env_id_list=env_id_list, env_configurations=env_configurations,
+                                   collector_env_num=collector_env_num, n_episode=n_episode,
+                                   evaluator_env_num=evaluator_env_num, num_simulations=num_simulations,
+                                   reanalyze_ratio=reanalyze_ratio, batch_size=batch_size,
+                                   num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length,
+                                   seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq,
+                                   reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition,
+                                   total_batch_size=total_batch_size, num_layers=num_layers,
+                                   model_name=model_name, replay_ratio=replay_ratio,
+                                   norm_type=norm_type, num_segments=num_segments)
+
+        with DDPContext():
+            train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step)
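
Note on the shared config pattern: all three new Jericho multitask configs derive their per-task settings the same way. total_batch_size is split evenly across env_id_list in the __main__ block, and each game's (max_action_num, max_steps) pair is looked up from env_configurations with a (10, 50) fallback inside generate_configs. The snippet below is a minimal standalone sketch of that derivation, not part of the patch; the variable names mirror the config scripts, and only the print formatting is mine.

# Illustrative sketch of the per-task parameter derivation used by the Jericho
# multitask configs above (not part of the patch).
env_configurations = {
    'detective.z5': (12, 100),   # (max_action_num, max_steps)
    'omniquest.z5': (25, 100),
    'acorncourt.z5': (45, 50),
    'zork1.z5': (55, 500),
}
env_id_list = list(env_configurations.keys())

total_batch_size = int(64 * 4)
# Even split across tasks, as in the __main__ blocks; assumes the total is
# divisible by the number of tasks (it is for the values used in the configs).
batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))]
assert sum(batch_size) == total_batch_size

for task_id, env_id in enumerate(env_id_list):
    # The (10, 50) fallback matches the default used in generate_configs.
    max_action_num, max_steps = env_configurations.get(env_id, (10, 50))
    print(f"task {task_id}: {env_id} -> max_action_num={max_action_num}, "
          f"max_steps={max_steps}, batch_size={batch_size[task_id]}")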