1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -10,4 +10,5 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_segment_with_reward_model import train_unizero_segment_with_reward_model
from .utils import *
256 changes: 256 additions & 0 deletions lzero/entry/train_unizero_segment_with_reward_model.py
@@ -0,0 +1,256 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

def train_unizero_segment_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
The train entry for UniZero (with the muzero_segment_collector, the buffer-reanalyze trick, and an RND-based intrinsic reward model), building on the approach proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations of MuZero-style algorithms,
particularly in environments that require capturing long-term dependencies. More details can be found at https://arxiv.org/abs/2406.10667.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model; an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "The train_unizero_segment_with_reward_model entry currently only supports the following algorithms: 'unizero', 'sampled_unizero'"
assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use RND reward model"

# Import the correct GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander
tb_logger = None
if get_rank() == 0:
tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, use_tensorboard=cfg.policy.use_tensorboard)

# MCTS+RL algorithms related core code
policy_config = cfg.policy
replay_buffer = GameBuffer(policy_config)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=policy_config)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)



# ==============================================================
# Added: initialize the RND reward model.
# RNDRewardModel needs the representation network from the policy model (used as the
# predictor) and a target representation network (used as the fixed target).
# For UniZero, the tokenizer plays the role of the representation network.
# ==============================================================
reward_model = RNDRewardModel(
config=cfg.reward_model,
device=policy.collect_mode.get_attribute('device'),
tb_logger=tb_logger,
exp_name=cfg.exp_name,
representation_network=policy._learn_model.representation_network,
target_representation_network=policy._target_model_for_intrinsic_reward.representation_network,
use_momentum_representation_network=cfg.policy.use_momentum_representation_network,
bp_update_sync=cfg.policy.bp_update_sync,
multi_gpu=cfg.policy.multi_gpu,
)
# Learner's before_run hook
learner.call_hook('before_run')

if cfg.policy.use_wandb and get_rank() == 0:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Collect random data before training
if cfg.policy.random_collect_data:
random_data = random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
try:
reward_model.warmup_with_random_segments(random_data)
except Exception as e:
logging.exception(f"Failed to warm up RND normalization using random data: {e}")
raise
batch_size = policy._cfg.batch_size

buffer_reanalyze_count = 0
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size

if cfg.policy.multi_gpu:
# Get current world size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log buffer memory usage
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

# Set temperature for visit count distributions
collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

# Configure epsilon for epsilon-greedy exploration
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect new data
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

# Periodically reanalyze buffer
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
# Reanalyze the buffer every <1/buffer_reanalyze_freq> train epochs
if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')

# Train the policy if sufficient data is available
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
if not data_sufficient:
logging.warning(
f'The data in replay_buffer is insufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continuing to collect data ...'
)
continue

for i in range(update_per_collect):
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')

train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

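# Intrinsic-reward flow (as read from the calls below): reward_model.estimate() augments
# the sampled batch with RND intrinsic rewards, the learner trains on the augmented batch,
# and train_with_policy_batch() then updates the RND reward model on the same policy batch.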
train_data_augmented = reward_model.estimate(train_data)
train_data_augmented.append(learner.train_iter)

log_vars = learner.train(train_data_augmented, collector.envstep)
reward_model.train_with_policy_batch(train_data)
logging.info(f'[{i}/{update_per_collect}]: learner and reward_model finished one training step.')

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

train_epoch += 1
policy.recompute_pos_emb_diff_and_clear_cache()

# Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')
if cfg.policy.use_wandb:
wandb.finish()
return policy
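For orientation, here is a minimal launch sketch for the new entry; the config module and variable names are assumptions modeled on existing UniZero segment configs and are not files added in this PR:

# Hypothetical launch script; the config import below is an assumption for illustration only.
from zoo.some_env.config.some_unizero_segment_rnd_config import main_config, create_config  # hypothetical
from lzero.entry import train_unizero_segment_with_reward_model

if __name__ == "__main__":
    # The entry asserts cfg.policy.use_rnd_model is True, reads cfg.reward_model for
    # RNDRewardModel, and uses cfg.policy.random_collect_data to trigger the RND warm-up.
    train_unizero_segment_with_reward_model(
        input_cfg=(main_config, create_config),
        seed=0,
        max_env_step=int(1e6),
    )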
18 changes: 6 additions & 12 deletions lzero/entry/utils.py
@@ -139,7 +139,7 @@ def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], b
else:
raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}")

return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)
return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == -1 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)

def random_collect(
policy_cfg: 'EasyDict', # noqa
@@ -149,8 +149,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
) -> None: # noqa
assert policy_cfg.random_collect_episode_num > 0
) -> list: # noqa
assert policy_cfg.random_collect_data, "random_collect_data should be True."

random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
# set the policy to random policy
@@ -161,7 +161,7 @@ def random_collect(
collect_kwargs = {'temperature': 1, 'epsilon': 0.0}

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0,
new_data = collector.collect(train_iter=0,
policy_kwargs=collect_kwargs)

if postprocess_data_fn is not None:
@@ -174,6 +174,7 @@

# restore the policy
collector.reset_policy(policy.collect_mode)
return new_data


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
@@ -190,28 +191,21 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa
writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter)
writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter)
writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter)

game_segment_buffer = buffer.game_segment_buffer

# Calculate the amount of memory occupied by self.game_segment_buffer (in bytes).
buffer_memory_usage = asizeof(game_segment_buffer)

# Convert buffer_memory_usage to megabytes (MB).
buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024)

# Record the memory usage of self.game_segment_buffer to TensorBoard.
writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter)

# Get the amount of memory currently used by the process (in bytes).
process = psutil.Process(os.getpid())
process_memory_usage = process.memory_info().rss

# Convert process_memory_usage to megabytes (MB).
process_memory_usage_mb = process_memory_usage / (1024 * 1024)

# Record the memory usage of the process to TensorBoard.
writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter)


def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
"""
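Two behavioural notes on the utils.py changes above: random_collect() now returns the collected segments so that the new entry can warm up RND normalization, and initialize_pad_batch() now chooses its dtype from the pad token. Below is a minimal sketch of that dtype rule, under the assumption (my reading, not stated in the diff) that pad_token_id == -1 marks continuous-observation padding:

import torch

def initialize_pad_batch_sketch(shape, pad_token_id, device='cpu'):
    # Sketch of the rule introduced above: -1 signals a float placeholder
    # (e.g. continuous/latent observations); any other id is treated as a
    # discrete token and padded with dtype torch.long.
    if pad_token_id == -1:
        return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device)
    return torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device)

print(initialize_pad_batch_sketch((2, 4), pad_token_id=-1).dtype)  # torch.float32
print(initialize_pad_batch_sketch((2, 4), pad_token_id=0).dtype)   # torch.int64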
4 changes: 2 additions & 2 deletions lzero/mcts/buffer/game_buffer.py
@@ -159,14 +159,14 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
pos_in_game_segment = np.random.choice(max(len(game_segment.action_segment) - 1, 1), 1).item()
else:
# For environments with a fixed action space (e.g., Atari),
# we can safely sample from the entire game segment range.
if pos_in_game_segment >= self._cfg.game_segment_length:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
pos_in_game_segment = np.random.choice(max(len(game_segment.action_segment) - 1, 1), 1).item()

pos_in_game_segment_list.append(pos_in_game_segment)

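The max(..., 1) guard added above prevents np.random.choice(0), which fails when a game segment holds only a single action; a tiny standalone illustration of the old failure mode and the clamped behaviour (not code from this PR):

import numpy as np

action_segment_len = 1  # a segment with only one recorded action

# Old behaviour: np.random.choice(action_segment_len - 1, 1) samples from an empty
# range and raises a ValueError.

# Guarded behaviour, mirroring the change above: clamp the range to at least 1,
# so the only valid position (index 0) is returned.
pos_in_game_segment = np.random.choice(max(action_segment_len - 1, 1), 1).item()
print(pos_in_game_segment)  # 0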
5 changes: 3 additions & 2 deletions lzero/model/common.py
@@ -447,8 +447,9 @@ def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
batch_idx = torch.arange(B, device=last_hidden.device)

selected = last_hidden[batch_idx, positions] # [B, H]

latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
weight = self.embedding_head[0].weight
selected = selected.to(dtype=weight.dtype, device=weight.device)
latent = self.embedding_head(selected)
return latent

def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
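The encode() change above aligns the selected hidden states with the embedding head's parameter dtype and device before the projection; here is a self-contained sketch of the same pattern, with layer sizes that are assumptions rather than the model's real configuration:

import torch
import torch.nn as nn

# Stand-in for self.embedding_head; 768 -> 256 is an assumed projection size.
embedding_head = nn.Sequential(nn.Linear(768, 256)).to(dtype=torch.float64)

selected = torch.randn(4, 768)  # float32 hidden states selected per batch element

# Cast and move the input to match the head's first layer, mirroring the change above;
# this avoids dtype/device mismatches when the encoder and the head run in different
# precisions or on different devices.
weight = embedding_head[0].weight
selected = selected.to(dtype=weight.dtype, device=weight.device)
latent = embedding_head(selected)
print(latent.shape)  # torch.Size([4, 256])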