Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .train_unizero_multitask_ddp import train_unizero_multitask_ddp
from .train_unizero_multitask import train_unizero_multitask
from .utils import *

from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp
531 changes: 531 additions & 0 deletions lzero/entry/train_unizero_multitask.py

Large diffs are not rendered by default.

618 changes: 618 additions & 0 deletions lzero/entry/train_unizero_multitask_ddp.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions lzero/entry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,18 @@ def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWr

# Reset the time records in the buffer.
buffer.reset_runtime_metrics()


def symlog(x: torch.Tensor) -> torch.Tensor:
"""
Symlog normalization to reduce the scale differences of target values.
symlog(x) = sign(x) * log(|x| + 1)
"""
return torch.sign(x) * torch.log(torch.abs(x) + 1)

def inv_symlog(x: torch.Tensor) -> torch.Tensor:
"""
Inverse operation of symlog, used to recover the original values.
inv_symlog(x) = sign(x) * (exp(|x|) - 1)
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
26 changes: 19 additions & 7 deletions lzero/mcts/buffer/game_buffer_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def __init__(self, cfg: dict):
if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
try:
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]
except Exception as e:

if isinstance(self._cfg.model.action_space_size, list):
self.action_space_size = self._cfg.model.action_space_size[self.task_id]
elif isinstance(self._cfg.model.action_space_size, int):
self.action_space_size = self._cfg.model.action_space_size
else:
raise ValueError(" action_space_size should be int or list")
else:
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")
Expand Down Expand Up @@ -90,6 +93,7 @@ def sample(
)

# target policy

batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
Expand Down Expand Up @@ -135,14 +139,14 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
obs_list, action_list, mask_list = [], [], []
timestep_list = []
bootstrap_action_list = []

# prepare the inputs of a batch

for i in range(batch_size):
game = game_segment_list[i]
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()

timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
# add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
Expand All @@ -158,9 +162,17 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))]
# mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]

# prepare the inputs of a batch
if isinstance(game.action_space_size, list):
action_size = game.action_space_size[self.task_id]
elif isinstance(game.action_space_size, int):
action_size = game.action_space_size
else:
raise ValueError(" action_space_size should be int or list")

# pad random action
actions_tmp += [
np.random.randint(0, game.action_space_size)
np.random.randint(0, action_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
# TODO: check the effect
Expand All @@ -185,7 +197,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
self._cfg.num_unroll_steps+self._cfg.td_steps].tolist()
# pad random action
bootstrap_action_tmp += [
np.random.randint(0, game.action_space_size)
np.random.randint(0, action_size)
for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp))
]
bootstrap_action_list.append(bootstrap_action_tmp)
Expand Down
5 changes: 4 additions & 1 deletion lzero/mcts/buffer/game_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
else:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
self.zero_obs_shape = config.model.observation_shape
else:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
self.action_segment = []
Expand Down
19 changes: 9 additions & 10 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,7 @@ def __init__(self,
model_path: str = 'google-bert/bert-base-uncased',
embedding_size: int = 768,
group_size: int = 8,
norm_type: str = "simnorm",
# norm_type: str = "layernorm", # TODO: Why does nan appear in the first step of training?
final_norm_option_in_encoder: str = "simnorm",
tokenizer=None):
"""
Overview:
Expand All @@ -391,12 +390,12 @@ def __init__(self,

# In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup.
if get_rank() == 0:
self.model = AutoModel.from_pretrained(model_path)
self.pretrained_model = AutoModel.from_pretrained(model_path)
if get_world_size() > 1:
# Wait for rank 0 to finish loading the model.
torch.distributed.barrier()
if get_rank() != 0:
self.model = AutoModel.from_pretrained(model_path)
self.pretrained_model = AutoModel.from_pretrained(model_path)

if tokenizer is None:
# Only rank 0 downloads the tokenizer, and then other processes load it from cache.
Expand All @@ -411,15 +410,15 @@ def __init__(self,

# Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings).
self.embedding_size = embedding_size
self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size)
self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)

# Select the normalization method based on the norm_type parameter.
if norm_type.lower() == "simnorm":
if final_norm_option_in_encoder.lower() == "simnorm":
self.norm = SimNorm(simnorm_dim=group_size)
elif norm_type.lower() == "layernorm":
elif final_norm_option_in_encoder.lower() == "layernorm":
self.norm = nn.LayerNorm(embedding_size)
else:
raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. "
raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
f"Choose 'simnorm' or 'layernorm'.")

def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
Expand All @@ -442,12 +441,12 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
if no_grad:
with torch.no_grad():
x = x.long() # Ensure the input tensor is of type long.
outputs = self.model(x, attention_mask=attention_mask)
outputs = self.pretrained_model(x, attention_mask=attention_mask)
# Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
cls_embedding = outputs.last_hidden_state[:, 0, :]
else:
x = x.long()
outputs = self.model(x, attention_mask=attention_mask)
outputs = self.pretrained_model(x, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :]

# Apply linear projection to obtain the desired output dimension.
Expand Down
Loading