feature(xjy): add the rnd-related features #438
Open
xiongjyu wants to merge 12 commits into opendilab:main from xiongjyu:dev-rnd
Changes from all commits
Commits (12)
6a92678  Fix norm_type, kv_cache rewrite, _reset_collect/eval, init_weight, an… (xiongjyu)
1cf8688  feature(xjy): add the rnd-related features (xiongjyu)
e9314d1  add dynamic control weights + intrinsic reward-state mapping graph; s… (xiongjyu)
ac58169  add episode-level RND intrinsic reward evaluation (xiongjyu)
b7015d8  fix a bug on evaluation (xiongjyu)
0eb9792  modify some cfg (xiongjyu)
567f9d4  fix some config (xiongjyu)
4c4b98e  Share one optimizer for updating both rnd and world_model parameters. (xiongjyu)
5a7037c  fix eval bug (xiongjyu)
1fa399f  Optimized pre-RND random sampling and added 'update_proportion' to pr… (xiongjyu)
7af57b5  adapting rnd to the jericho environment (xiongjyu)
f9f86f1  add cache in the jericho (xiongjyu)
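For context, the commits above center on Random Network Distillation (RND), which scores novelty as the prediction error between a fixed, randomly initialized target network and a trainable predictor network trained on visited states. The following is a minimal, illustrative sketch of that idea; `RNDNetwork`, the layer sizes, and the placeholder data are assumptions for illustration, not the module or API actually added in this PR.

```python
# Minimal RND sketch (illustrative only; not the API introduced in this PR).
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNDNetwork(nn.Module):
    """A fixed, randomly initialized target network plus a trainable predictor of the same shape."""

    def __init__(self, obs_dim: int, embed_dim: int = 128):
        super().__init__()
        self.target = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, embed_dim))
        self.predictor = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, embed_dim))
        for p in self.target.parameters():
            p.requires_grad = False  # the target network is never trained

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Per-sample squared prediction error: large for novel states, small for familiar ones.
        with torch.no_grad():
            target_feat = self.target(obs)
        pred_feat = self.predictor(obs)
        return F.mse_loss(pred_feat, target_feat, reduction='none').mean(dim=-1)


rnd = RNDNetwork(obs_dim=16)
optimizer = torch.optim.Adam(rnd.predictor.parameters(), lr=3e-4)

obs_batch = torch.randn(32, 16)             # a batch of observations (placeholder data)
intrinsic_reward = rnd(obs_batch).detach()  # exploration bonus added on top of the extrinsic reward
loss = rnd(obs_batch).mean()                # training the predictor shrinks the bonus on visited states
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

According to the commit messages, the PR goes further than this sketch: it shares one optimizer between the RND and world-model parameters, adds an `update_proportion` control for the pre-RND random sampling, and evaluates episode-level RND intrinsic rewards; the sketch only illustrates the reward itself.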
New file (@@ -0,0 +1,235 @@): the training entry `train_unizero_segment_with_reward_model`.

```python
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect, calculate_update_per_collect, random_collect_for_rnd

timer = EasyTimer()


def train_unizero_segment_with_reward_model(
        input_cfg: Tuple[dict, dict],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[str] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
    """
    Overview:
        The train entry for UniZero (with muzero_segment_collector and the buffer-reanalyze trick), proposed in our paper
        UniZero: Generalized and Efficient Planning with Scalable Latent World Models. UniZero aims to enhance the planning
        capabilities of reinforcement learning agents by addressing the limitations of MuZero-style algorithms,
        particularly in environments that require capturing long-term dependencies. More details can be found in
        https://arxiv.org/abs/2406.10667.
    Arguments:
        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
            point to the ckpt file of the pretrained model. An absolute path is recommended.
            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    cfg, create_cfg = input_cfg

    # Ensure the specified policy type is supported.
    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], \
        "train_unizero entry now only supports the following algorithms: 'unizero', 'sampled_unizero'"
    assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use the RND reward model"

    # Import the correct GameBuffer class based on the policy type.
    game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
    GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
                         game_buffer_classes[create_cfg.policy.type])

    # Set device based on CUDA availability.
    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
    logging.info(f'cfg.policy.device: {cfg.policy.device}')

    # Compile the configuration.
    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

    # Create main components: env, policy.
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

    # Load the pretrained model if specified.
    if model_path is not None:
        logging.info(f'Loading model from {model_path} begin...')
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
        logging.info(f'Loading model from {model_path} end!')

    # Create worker components: learner, collector, evaluator, replay buffer.
    tb_logger = None
    if get_rank() == 0:
        tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    if cfg.policy.use_rnd_model:
        policy.rnd._init_log(tb_logger=tb_logger, _exp_name=cfg.exp_name)

    # MCTS+RL algorithms related core code.
    policy_config = cfg.policy
    replay_buffer = GameBuffer(policy_config)
    collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
                          policy_config=policy_config)
    evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                          stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)

    # Learner's before_run hook.
    learner.call_hook('before_run')

    if cfg.policy.use_wandb and get_rank() == 0:
        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

    # Collect random data before training.
    if cfg.policy.random_collect_episode_num > 0:
        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

    # Warm up the RND reward model with randomly collected segments.
    if cfg.policy.rnd_random_collect_episode_num > 0:
        random_collector_env_cfg = [collector_env_cfg[0] for _ in range(cfg.policy.rnd_random_collect_episode_num)]
        random_collect_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in random_collector_env_cfg])
        random_data = random_collect_for_rnd(env=random_collect_env)
        policy.rnd.warmup_with_random_segments(random_data)

    batch_size = policy._cfg.batch_size

    buffer_reanalyze_count = 0
    train_epoch = 0
    reanalyze_batch_size = cfg.policy.reanalyze_batch_size

    if cfg.policy.multi_gpu:
        # Get the current world size and rank.
        world_size = get_world_size()
        rank = get_rank()
    else:
        world_size = 1
        rank = 0

    while True:
        # Log buffer memory usage.
        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

        # Set the temperature for visit count distributions.
        collect_kwargs = {
            'temperature': visit_count_temperature(
                policy_config.manual_temperature_decay,
                policy_config.fixed_temperature_value,
                policy_config.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            ),
            'epsilon': 0.0  # Default epsilon value
        }

        # Configure epsilon for epsilon-greedy exploration.
        if policy_config.eps.eps_greedy_exploration_in_collect:
            epsilon_greedy_fn = get_epsilon_greedy_fn(
                start=policy_config.eps.start,
                end=policy_config.eps.end,
                decay=policy_config.eps.decay,
                type_=policy_config.eps.type
            )
            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

        # Evaluate policy performance.
        if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, reward_model=policy.rnd)
            if stop:
                break

        # Collect new data.
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

        # Determine the number of updates per collection.
        update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

        # Update the replay buffer.
        replay_buffer.push_game_segments(new_data)
        replay_buffer.remove_oldest_data_to_fit()

        # Periodically reanalyze the buffer.
        if cfg.policy.buffer_reanalyze_freq >= 1:
            # Reanalyze the buffer <buffer_reanalyze_freq> times in one train_epoch.
            reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
        else:
            # Reanalyze the buffer once every <1/buffer_reanalyze_freq> train_epochs.
            if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \
                    replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition):
                with timer:
                    # Each reanalyze process reanalyzes <reanalyze_batch_size> sequences
                    # (<cfg.policy.num_unroll_steps> transitions per sequence).
                    replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                buffer_reanalyze_count += 1
                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                logging.info(f'Buffer reanalyze time: {timer.value}')

        # Train the policy if sufficient data is available.
        if collector.envstep > cfg.policy.train_start_after_envsteps:
            if cfg.policy.sample_type == 'episode':
                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
            else:
                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
            if not data_sufficient:
                logging.warning(
                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
                )
                continue

            for i in range(update_per_collect):
                if cfg.policy.buffer_reanalyze_freq >= 1:
                    # Reanalyze the buffer <buffer_reanalyze_freq> times in one train_epoch.
                    if i % reanalyze_interval == 0 and \
                            replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition):
                        with timer:
                            # Each reanalyze process reanalyzes <reanalyze_batch_size> sequences
                            # (<cfg.policy.num_unroll_steps> transitions per sequence).
                            replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                        buffer_reanalyze_count += 1
                        logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                        logging.info(f'Buffer reanalyze time: {timer.value}')

                train_data = replay_buffer.sample(batch_size, policy)
                train_data.append(learner.train_iter)
                if cfg.policy.use_wandb:
                    policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

                log_vars = learner.train(train_data, collector.envstep)
                logging.info(f'[{i}/{update_per_collect}]: learner ended training step.')

                if cfg.policy.use_priority:
                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

        train_epoch += 1
        policy.recompute_pos_emb_diff_and_clear_cache()

        # Check stopping criteria.
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    learner.call_hook('after_run')
    if cfg.policy.use_wandb:
        wandb.finish()
    return policy
```
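For reference, the docstring above describes the entry's arguments; below is a hedged sketch of how such an entry is typically launched in LightZero. The module paths (`lzero.entry.train_unizero_segment_with_reward_model`, `zoo.jericho.configs.jericho_unizero_rnd_config`) and the config values are assumptions for illustration, not details confirmed by the diff.

```python
# Hypothetical launch script (module paths and config values are assumptions, not taken from this PR).
from lzero.entry.train_unizero_segment_with_reward_model import train_unizero_segment_with_reward_model
# A Jericho config pair in the style of other LightZero examples; this exact module name is illustrative.
from zoo.jericho.configs.jericho_unizero_rnd_config import main_config, create_config

if __name__ == "__main__":
    # The entry asserts that the RND reward model is enabled in the policy config.
    main_config.policy.use_rnd_model = True
    main_config.policy.rnd_random_collect_episode_num = 8  # illustrative warm-up episode count

    train_unizero_segment_with_reward_model(
        [main_config, create_config],
        seed=0,
        max_env_step=int(1e6),
    )
```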