Skip to content
Open
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_segment_with_reward_model import train_unizero_segment_with_reward_model
from .utils import *
235 changes: 235 additions & 0 deletions lzero/entry/train_unizero_segment_with_reward_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect, calculate_update_per_collect, random_collect_for_rnd

timer = EasyTimer()

def train_unizero_segment_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
The train entry for UniZero (with muzero_segment_collector and buffer reanalyze trick), proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use RND reward model"

# Import the correct GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander
tb_logger = None
if get_rank() == 0:
tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
if cfg.policy.use_rnd_model:
policy.rnd._init_log(tb_logger=tb_logger, _exp_name=cfg.exp_name)

# MCTS+RL algorithms related core code
policy_config = cfg.policy
replay_buffer = GameBuffer(policy_config)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=policy_config)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
# Learner's before_run hook
learner.call_hook('before_run')

if cfg.policy.use_wandb and get_rank() == 0:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Collect random data before training
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

if cfg.policy.rnd_random_collect_episode_num > 0:
random_collector_env_cfg = [collector_env_cfg[0] for _ in range(cfg.policy.rnd_random_collect_episode_num)]
random_collect_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in random_collector_env_cfg])
random_data = random_collect_for_rnd(env=random_collect_env)
policy.rnd.warmup_with_random_segments(random_data)

batch_size = policy._cfg.batch_size

buffer_reanalyze_count = 0
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size

if cfg.policy.multi_gpu:
# Get current world size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log buffer memory usage
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

# Set temperature for visit count distributions
collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

# Configure epsilon for epsilon-greedy exploration
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep, reward_model=policy.rnd)
if stop:
break
# Collect new data
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

# Periodically reanalyze buffer
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
# Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch
if train_epoch > 0 and train_epoch % int(1/cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')

# Train the policy if sufficient data is available
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
if not data_sufficient:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
)
continue

for i in range(update_per_collect):
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')

train_data = replay_buffer.sample(batch_size, policy)
train_data.append(learner.train_iter)
if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
log_vars = learner.train(train_data, collector.envstep)
logging.info(f'[{i}/{update_per_collect}]: learner ended training step.')

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

train_epoch += 1
policy.recompute_pos_emb_diff_and_clear_cache()

# Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')
if cfg.policy.use_wandb:
wandb.finish()
return policy
50 changes: 47 additions & 3 deletions lzero/entry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tensorboardX import SummaryWriter
import torch
import torch.distributed as dist
import numpy as np

def is_ddp_enabled():
"""
Expand Down Expand Up @@ -147,8 +148,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
) -> None: # noqa
assert policy_cfg.random_collect_episode_num > 0
) -> list: # noqa
assert policy_cfg.random_collect_data, "random_collect_data should be True."

random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
# set the policy to random policy
Expand All @@ -159,7 +160,7 @@ def random_collect(
collect_kwargs = {'temperature': 1, 'epsilon': 0.0}

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0,
new_data = collector.collect(train_iter=0,
policy_kwargs=collect_kwargs)

if postprocess_data_fn is not None:
Expand All @@ -172,6 +173,49 @@ def random_collect(

# restore the policy
collector.reset_policy(policy.collect_mode)
return new_data

def random_collect_for_rnd(env):
if env is not None:
env.launch()
else:
raise ValueError(f'env has error!')
init_obs = env.ready_obs
env_nums = len(init_obs)
activate_env_ids = list(range(env_nums))

action_mask_dict = {}
episode_obs_env = {env_id:[] for env_id in activate_env_ids}
collect_obs_list = []

for env_id in range(env_nums):
action_mask_dict[env_id] = init_obs[env_id]['action_mask']
episode_obs_env[env_id].append(init_obs[env_id]['observation'])

while True:
if len(activate_env_ids) == 0:
break
actions = {}

for env_id in activate_env_ids:
legal_actions = [i for i, x in enumerate(action_mask_dict[env_id]) if x == 1]
actions[env_id] = int(np.random.choice(legal_actions, 1))

timesteps = env.step(actions)
action_mask_dict.clear()

for env_id, episode_timestep in timesteps.items():
obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info

action_mask_dict[env_id] = obs['action_mask']
episode_obs_env[env_id].append(obs['observation'])

if done:
activate_env_ids.remove(env_id)
collect_obs_list.extend(episode_obs_env[env_id])
episode_obs_env[env_id].clear()

return collect_obs_list


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
Expand Down
31 changes: 31 additions & 0 deletions lzero/mcts/buffer/game_buffer_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,34 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
batch_target_values = np.asarray(batch_target_values)

return batch_rewards, batch_target_values

def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None:
"""
Overview:
Update the priority of training data.
Arguments:
- train_data (:obj:`List[np.ndarray]`): training data to be updated priority.
- batch_priorities (:obj:`np.ndarray`): priorities to update to.
NOTE:
train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
"""
# TODO: NOTE: -4 is batch_index_list
indices = train_data[0][-4]
metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities}
# only update the priorities for data still in replay buffer
for i in range(len(indices)):
# ==================== START OF FINAL FIX ====================

# FIX 1: Handle ValueError by using the first timestamp of the segment for comparison.
first_transition_time = metas['make_time'][i][0]

if first_transition_time > self.clear_time:
# FIX 2: Handle IndexError by converting the float index to an integer before use.
idx = int(indices[i])
prio = metas['batch_priorities'][i]

# Now, idx is a valid integer index.
self.game_pos_priorities[idx] = prio

# ===================== END OF FINAL FIX =====================
10 changes: 6 additions & 4 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,9 @@ def __init__(self,
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
else:
self.tokenizer = tokenizer

for param in self.pretrained_model.parameters():
param.requires_grad = False

# Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings).
self.embedding_size = embedding_size
Expand Down Expand Up @@ -641,11 +644,11 @@ def __init__(
self.embedding_dim = embedding_dim

if self.observation_shape[1] == 64:
self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False)

elif self.observation_shape[1] in [84, 96]:
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)

self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False)
self.final_norm_option_in_encoder = final_norm_option_in_encoder
if self.final_norm_option_in_encoder == 'LayerNorm':
self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
Expand Down Expand Up @@ -678,7 +681,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

x = x.view(-1, self.embedding_dim)

# NOTE: very important for training stability.
x = self.final_norm(x)

return x
Expand Down
Loading