From ed8019962253b6ab557a2cc0693c267562223376 Mon Sep 17 00:00:00 2001 From: "yangfan.bai" Date: Tue, 24 Sep 2024 14:43:50 +0800 Subject: [PATCH 1/3] opt:opt ltor masks --- megatron/core/datasets/gpt_dataset.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 115727de92..d757f60899 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -646,7 +646,9 @@ def _get_ltor_masks_and_position_ids( """ seq_length = data.numel() - if create_attention_mask: + if create_attention_mask and reset_attention_mask: + attention_mask = True + elif create_attention_mask: attention_mask = torch.tril( torch.ones((seq_length, seq_length), device=data.device) ).unsqueeze(0) @@ -673,15 +675,23 @@ def _get_ltor_masks_and_position_ids( # Loop through EOD indices: prev_index = 0 + mask_list = [] for j in range(eod_index.numel()): i = eod_index[j] # Mask attention loss. if reset_attention_mask and attention_mask is not None: - attention_mask[0, (i + 1) :, : (i + 1)] = 0 + small_mask = torch.tril(torch.ones((i + 1 - prev_index, i + 1 - prev_index), device=data.device, dtype=torch.int8)) + mask_list.append(small_mask) # Reset positions. if reset_position_ids: position_ids[(i + 1) :] -= i + 1 - prev_index prev_index = i + 1 + + if prev_index < seq_length: + small_mask = torch.tril(torch.ones((seq_length - prev_index, seq_length - prev_index), device=data.device, dtype=torch.int8)) + mask_list.append(small_mask) + + attention_mask = torch.block_diag(*mask_list).unsqueeze(0) if attention_mask is not None: # Convert attention mask to binary: From 073b8ed4de640cbe2cc243190ca6ef3567a6993b Mon Sep 17 00:00:00 2001 From: "yangfan.bai" Date: Fri, 11 Oct 2024 18:07:15 +0800 Subject: [PATCH 2/3] add packed param --- megatron/core/datasets/gpt_dataset.py | 18 ++++++++++++ megatron/training/utils.py | 40 +++++++++++++++++++++++++-- pretrain_gpt.py | 4 +-- 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index d757f60899..dadb866fa2 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -16,6 +16,7 @@ from megatron.core.datasets.utils import Split from megatron.core.datasets.utils_s3 import S3Config, is_s3_path from megatron.core.utils import log_single_rank +from megatron.core.packed_seq_params import PackedSeqParams logger = logging.getLogger(__name__) @@ -213,6 +214,22 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: if idx is None: loss_mask = torch.zeros_like(loss_mask) + # ais packed param + packed_seq_params = None + if self.config.reset_attention_mask and self.config.create_attention_mask \ + and self.config.reset_position_ids: + reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]).cuda() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen, _ = seqlens.max(dim=0, keepdim=True) + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format='thd', + ) + if self.config.create_attention_mask: return { "tokens": tokens, @@ -220,6 +237,7 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: "attention_mask": attention_mask, "loss_mask": loss_mask, "position_ids": position_ids, + 
"packed_seq_params": packed_seq_params, } else: return { diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 4c3223d0de..255f498e39 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -4,8 +4,10 @@ import os import sys from datetime import datetime +from dataclasses import dataclass import torch +import torch.nn.functional as F try: from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_l2norm @@ -37,6 +39,7 @@ from megatron.core import DistributedDataParallel as DDP from megatron.core import mpu from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate +from megatron.core.packed_seq_params import PackedSeqParams from megatron.legacy.model import Float16Module from megatron.legacy.model.module import param_is_not_shared @@ -306,6 +309,13 @@ def get_batch_on_this_tp_rank(data_iterator): def _broadcast(item): if item is not None: torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) + + def _broadcast_class(item): + if item is not None and item is isinstance(dataclass): + padded_tensors = F.pad(item.cu_seqlens_q, (0, args.seq_length - item.cu_seqlens_q.size(1))) + _broadcast(padded_tensors) + _broadcast(item.max_seqlen) + if mpu.get_tensor_model_parallel_rank() == 0: @@ -319,7 +329,8 @@ def _broadcast(item): 'labels': data["labels"].cuda(non_blocking = True), 'loss_mask': data["loss_mask"].cuda(non_blocking = True), 'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True), - 'position_ids': data["position_ids"].cuda(non_blocking = True) + 'position_ids': data["position_ids"].cuda(non_blocking = True), + 'packed_seq_params': data["packed_seq_params"] } if args.pipeline_model_parallel_size == 1: @@ -328,11 +339,13 @@ def _broadcast(item): _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_class(batch['packed_seq_params']) elif mpu.is_pipeline_first_stage(): _broadcast(batch['tokens']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_class(batch['packed_seq_params']) elif mpu.is_pipeline_last_stage(): _broadcast(batch['labels']) @@ -351,6 +364,10 @@ def _broadcast(item): else: attention_mask=None position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device()) + + if args.reset_position_ids and args.reset_attention_mask: + cu_seqlens=torch.empty((args.seq_length), dtype = torch.int32 , device = torch.cuda.current_device()) + max_seqlen=torch.empty((1), dtype = torch.int32 , device = torch.cuda.current_device()) if args.pipeline_model_parallel_size == 1: _broadcast(tokens) @@ -358,6 +375,9 @@ def _broadcast(item): _broadcast(loss_mask) _broadcast(attention_mask) _broadcast(position_ids) + if args.reset_position_ids and args.reset_attention_mask: + _broadcast(cu_seqlens) + _broadcast(max_seqlen) elif mpu.is_pipeline_first_stage(): labels=None @@ -366,21 +386,37 @@ def _broadcast(item): _broadcast(tokens) _broadcast(attention_mask) _broadcast(position_ids) + if args.reset_position_ids and args.reset_attention_mask: + _broadcast(cu_seqlens) + _broadcast(max_seqlen) elif mpu.is_pipeline_last_stage(): tokens=None position_ids=None + cu_seqlens=None + max_seqlen=None _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) + + packed_seq_params = None + if args.reset_position_ids and args.reset_attention_mask: + 
packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format='thd', + ) batch = { 'tokens': tokens, 'labels': labels, 'loss_mask': loss_mask, 'attention_mask': attention_mask, - 'position_ids': position_ids + 'position_ids': position_ids, + 'packed_seq_params': packed_seq_params } return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 3b7f8db012..c19aab8843 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -185,13 +185,13 @@ def forward_step(data_iterator, model: GPTModel): timers('batch-generator', log_level=2).start() global stimer with stimer(bdata=True): - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch( data_iterator) timers('batch-generator').stop() with stimer: output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) + labels=labels, packed_seq_params=packed_seq_params) return output_tensor, partial(loss_func, loss_mask) From f8a97caa5437bec2a3cebe6124b14cd14bd35ea0 Mon Sep 17 00:00:00 2001 From: "yangfan.bai" Date: Sat, 12 Oct 2024 15:19:38 +0800 Subject: [PATCH 3/3] fix packed param bugs --- megatron/core/datasets/gpt_dataset.py | 19 ++++++++++--------- megatron/training/utils.py | 18 +++++++++++++----- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index dadb866fa2..4aa895099b 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -215,20 +215,21 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]: loss_mask = torch.zeros_like(loss_mask) # ais packed param - packed_seq_params = None + packed_seq_params = torch.tensor([False], dtype=torch.bool) if self.config.reset_attention_mask and self.config.create_attention_mask \ and self.config.reset_position_ids: reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1 - cu_seqlens = torch.cat([torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]).cuda() + cu_seqlens = torch.cat([torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]).to(torch.int32) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen, _ = seqlens.max(dim=0, keepdim=True) - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format='thd', - ) + + packed_seq_params = { + "cu_seqlens_q": cu_seqlens, + "cu_seqlens_kv": cu_seqlens, + "max_seqlen_q": max_seqlen, + "max_seqlen_kv": max_seqlen, + "qkv_format": 'thd', + } if self.config.create_attention_mask: return { diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 255f498e39..87897b5a0a 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -4,7 +4,7 @@ import os import sys from datetime import datetime -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass import torch import torch.nn.functional as F @@ -311,10 +311,10 @@ def _broadcast(item): torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) def _broadcast_class(item): - if item is not None and item is isinstance(dataclass): - padded_tensors = F.pad(item.cu_seqlens_q, (0, args.seq_length - item.cu_seqlens_q.size(1))) + if item and is_dataclass(item): + padded_tensors = F.pad(item.cu_seqlens_q, 
(0, args.seq_length - item.cu_seqlens_q.size(0))) _broadcast(padded_tensors) - _broadcast(item.max_seqlen) + _broadcast(item.max_seqlen_q) if mpu.get_tensor_model_parallel_rank() == 0: @@ -324,13 +324,21 @@ def _broadcast_class(item): else: data = None + if data["packed_seq_params"]: + data["packed_seq_params"] = PackedSeqParams(**data["packed_seq_params"]) + data["packed_seq_params"].cu_seqlens_q = data["packed_seq_params"].cu_seqlens_q.cuda().squeeze(0) + data["packed_seq_params"].cu_seqlens_kv = data["packed_seq_params"].cu_seqlens_kv.cuda().squeeze(0) + data["packed_seq_params"].max_seqlen_q = data["packed_seq_params"].max_seqlen_q.cuda().squeeze(0) + data["packed_seq_params"].max_seqlen_kv = data["packed_seq_params"].max_seqlen_kv.cuda().squeeze(0) + data["packed_seq_params"].qkv_format = 'thd' + batch = { 'tokens': data["tokens"].cuda(non_blocking = True), 'labels': data["labels"].cuda(non_blocking = True), 'loss_mask': data["loss_mask"].cuda(non_blocking = True), 'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True), 'position_ids': data["position_ids"].cuda(non_blocking = True), - 'packed_seq_params': data["packed_seq_params"] + 'packed_seq_params': data["packed_seq_params"] if data["packed_seq_params"] else None } if args.pipeline_model_parallel_size == 1:
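A note on PATCH 1/3: it replaces the per-EOD in-place zeroing of a full seq_length x seq_length triangular mask with one torch.tril block per document, stitched together by torch.block_diag. The sketch below is a self-contained illustration of that construction, not the exact Megatron-LM function; build_blockdiag_ltor_mask, eod_token and the sample tokens are illustrative names chosen here.

    import torch

    def build_blockdiag_ltor_mask(data: torch.Tensor, eod_token: int) -> torch.Tensor:
        """Per-document lower-triangular masks assembled into one block-diagonal
        matrix, mirroring the reset_attention_mask path added in PATCH 1/3."""
        seq_length = data.numel()
        eod_index = (data == eod_token).nonzero(as_tuple=True)[0]

        mask_list = []
        prev_index = 0
        for i in eod_index.tolist():
            length = i + 1 - prev_index
            mask_list.append(torch.tril(torch.ones((length, length),
                                                   device=data.device, dtype=torch.int8)))
            prev_index = i + 1
        if prev_index < seq_length:  # tail segment after the last EOD, if any
            length = seq_length - prev_index
            mask_list.append(torch.tril(torch.ones((length, length),
                                                   device=data.device, dtype=torch.int8)))

        attention_mask = torch.block_diag(*mask_list).unsqueeze(0)
        # Convert to a boolean mask where True marks positions that must not attend,
        # matching the binary conversion at the end of the surrounding function.
        return attention_mask < 0.5

    tokens = torch.tensor([5, 6, 0, 7, 8, 9, 0, 1, 2])  # 0 plays the role of the EOD id here
    print(build_blockdiag_ltor_mask(tokens, eod_token=0).shape)  # torch.Size([1, 9, 9])

Compared with slicing a pre-built float mask once per EOD token, this builds int8 blocks of exactly the needed sizes and materializes the full matrix once at the end.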
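PATCH 2/3 and the fix-up in PATCH 3/3 derive packed-sequence metadata (cu_seqlens and max_seqlen) from position_ids that were reset at each document boundary, and hand it to the model as PackedSeqParams with qkv_format='thd'. The helper below is an illustrative stand-alone version of that derivation; packed_params_from_position_ids and the sample position_ids are not part of the patch.

    import torch
    from megatron.core.packed_seq_params import PackedSeqParams

    def packed_params_from_position_ids(position_ids: torch.Tensor) -> PackedSeqParams:
        # A position id smaller than its predecessor marks the start of a new document.
        reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1
        cu_seqlens = torch.cat([torch.tensor([0]),
                                reset_points,
                                torch.tensor([position_ids.numel()])]).to(torch.int32)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return PackedSeqParams(
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_kv=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_kv=max_seqlen,
            qkv_format='thd',
        )

    # Three packed documents of lengths 3, 4 and 2:
    position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1])
    params = packed_params_from_position_ids(position_ids)
    print(params.cu_seqlens_q)  # tensor([0, 3, 7, 9], dtype=torch.int32)
    print(params.max_seqlen_q)  # tensor(4, dtype=torch.int32)

PATCH 3/3 stores these fields as a plain dict in __getitem__ so the default PyTorch collate function can batch them, then rebuilds the PackedSeqParams dataclass (squeezing the batch dimension) inside get_batch_on_this_tp_rank.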
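The utils.py hunks ship this metadata from the first tensor-parallel rank to the others with torch.distributed.broadcast, which requires every rank to post a buffer of identical shape; hence cu_seqlens is padded to a fixed length before the collective. Below is a rough stand-alone sketch of that pattern under the assumption of an already initialized process group; broadcast_packed_metadata, buf_len, src, group and is_src are illustrative stand-ins for Megatron's tensor-model-parallel plumbing.

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def broadcast_packed_metadata(cu_seqlens, max_seqlen, buf_len, src, group, is_src):
        device = torch.cuda.current_device()
        if is_src:
            # Pad the variable-length cu_seqlens up to buf_len so all ranks agree on
            # the shape. Padding with the final cumulative length keeps the entries
            # non-decreasing; F.pad's default zero fill (as used in the patch) leaves
            # trailing zeros that the receiver would have to trim before use.
            pad = buf_len - cu_seqlens.numel()
            cu_buf = F.pad(cu_seqlens.to(device), (0, pad), value=int(cu_seqlens[-1]))
            max_buf = max_seqlen.to(device=device, dtype=torch.int32).reshape(1)
        else:
            cu_buf = torch.empty(buf_len, dtype=torch.int32, device=device)
            max_buf = torch.empty(1, dtype=torch.int32, device=device)
        dist.broadcast(cu_buf, src, group=group)
        dist.broadcast(max_buf, src, group=group)
        return cu_buf, max_buf

In the patch itself the buffer length is args.seq_length and the receiving ranks construct PackedSeqParams directly from the padded buffer, so the choice of pad value determines whether the tail entries remain valid cumulative lengths.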