diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py
index 115727de92..4aa895099b 100644
--- a/megatron/core/datasets/gpt_dataset.py
+++ b/megatron/core/datasets/gpt_dataset.py
@@ -16,6 +16,7 @@
 from megatron.core.datasets.utils import Split
 from megatron.core.datasets.utils_s3 import S3Config, is_s3_path
 from megatron.core.utils import log_single_rank
+from megatron.core.packed_seq_params import PackedSeqParams
 
 logger = logging.getLogger(__name__)
 
@@ -213,6 +214,23 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
         if idx is None:
             loss_mask = torch.zeros_like(loss_mask)
 
+        # Packed-sequence metadata (TransformerEngine THD layout); a bool tensor means "not packed".
+        packed_seq_params = torch.tensor([False], dtype=torch.bool)
+        if self.config.reset_attention_mask and self.config.create_attention_mask \
+            and self.config.reset_position_ids:
+            reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1
+            cu_seqlens = torch.cat([torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]).to(torch.int32)
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+            max_seqlen, _ = seqlens.max(dim=0, keepdim=True)
+
+            packed_seq_params = {
+                "cu_seqlens_q": cu_seqlens,
+                "cu_seqlens_kv": cu_seqlens,
+                "max_seqlen_q": max_seqlen,
+                "max_seqlen_kv": max_seqlen,
+                "qkv_format": 'thd',
+            }
+
         if self.config.create_attention_mask:
             return {
                 "tokens": tokens,
@@ -220,6 +238,7 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
                 "attention_mask": attention_mask,
                 "loss_mask": loss_mask,
                 "position_ids": position_ids,
+                "packed_seq_params": packed_seq_params,
             }
         else:
             return {
@@ -646,7 +665,9 @@ def _get_ltor_masks_and_position_ids(
     """
     seq_length = data.numel()
 
-    if create_attention_mask:
+    if create_attention_mask and reset_attention_mask:
+        attention_mask = True  # placeholder; replaced by a block-diagonal mask below
+    elif create_attention_mask:
         attention_mask = torch.tril(
             torch.ones((seq_length, seq_length), device=data.device)
         ).unsqueeze(0)
@@ -673,15 +694,23 @@ def _get_ltor_masks_and_position_ids(
 
         # Loop through EOD indices:
         prev_index = 0
+        mask_list = []
         for j in range(eod_index.numel()):
             i = eod_index[j]
             # Mask attention loss.
             if reset_attention_mask and attention_mask is not None:
-                attention_mask[0, (i + 1) :, : (i + 1)] = 0
+                small_mask = torch.tril(torch.ones((i + 1 - prev_index, i + 1 - prev_index), device=data.device, dtype=torch.int8))
+                mask_list.append(small_mask)
             # Reset positions.
             if reset_position_ids:
                 position_ids[(i + 1) :] -= i + 1 - prev_index
                 prev_index = i + 1
+
+        if reset_attention_mask and attention_mask is not None:
+            if prev_index < seq_length:
+                mask_list.append(torch.tril(torch.ones((seq_length - prev_index, seq_length - prev_index), device=data.device, dtype=torch.int8)))
+            # One causal block per sub-sequence delimited by EOD tokens.
+            attention_mask = torch.block_diag(*mask_list).unsqueeze(0)
 
     if attention_mask is not None:
         # Convert attention mask to binary:
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index 4c3223d0de..87897b5a0a 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -4,8 +4,10 @@
 import os
 import sys
 from datetime import datetime
+from dataclasses import is_dataclass
 
 import torch
+import torch.nn.functional as F
 
 try:
     from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_l2norm
@@ -37,6 +39,7 @@
 from megatron.core import DistributedDataParallel as DDP
 from megatron.core import mpu
 from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
+from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.legacy.model import Float16Module
 from megatron.legacy.model.module import param_is_not_shared
 
@@ -306,6 +309,13 @@ def get_batch_on_this_tp_rank(data_iterator):
     def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+
+    def _broadcast_class(item):
+        if item and is_dataclass(item):
+            padded_tensors = F.pad(item.cu_seqlens_q, (0, args.seq_length - item.cu_seqlens_q.size(0)))  # pad so every rank posts a same-shaped recv
+            _broadcast(padded_tensors)
+            _broadcast(item.max_seqlen_q)
+
 
    if mpu.get_tensor_model_parallel_rank() == 0:
 
@@ -314,12 +324,21 @@ def _broadcast(item):
       else:
           data = None
 
+       if data.get("packed_seq_params"):  # a dict only when sequence packing is active
+           data["packed_seq_params"] = PackedSeqParams(**data["packed_seq_params"])
+           data["packed_seq_params"].cu_seqlens_q = data["packed_seq_params"].cu_seqlens_q.cuda().squeeze(0)
+           data["packed_seq_params"].cu_seqlens_kv = data["packed_seq_params"].cu_seqlens_kv.cuda().squeeze(0)
+           data["packed_seq_params"].max_seqlen_q = data["packed_seq_params"].max_seqlen_q.cuda().squeeze(0)
+           data["packed_seq_params"].max_seqlen_kv = data["packed_seq_params"].max_seqlen_kv.cuda().squeeze(0)
+           data["packed_seq_params"].qkv_format = 'thd'
+
       batch = {
           'tokens': data["tokens"].cuda(non_blocking = True),
           'labels': data["labels"].cuda(non_blocking = True),
           'loss_mask': data["loss_mask"].cuda(non_blocking = True),
           'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
-          'position_ids': data["position_ids"].cuda(non_blocking = True)
+          'position_ids': data["position_ids"].cuda(non_blocking = True),
+          'packed_seq_params': data["packed_seq_params"] if data.get("packed_seq_params") else None
       }
 
       if args.pipeline_model_parallel_size == 1:
@@ -328,11 +347,13 @@ def _broadcast(item):
           _broadcast(batch['loss_mask'])
           _broadcast(batch['attention_mask'])
           _broadcast(batch['position_ids'])
+          _broadcast_class(batch['packed_seq_params'])
 
       elif mpu.is_pipeline_first_stage():
           _broadcast(batch['tokens'])
           _broadcast(batch['attention_mask'])
           _broadcast(batch['position_ids'])
+          _broadcast_class(batch['packed_seq_params'])
 
       elif mpu.is_pipeline_last_stage():
           _broadcast(batch['labels'])
@@ -351,6 +372,10 @@ def _broadcast(item):
       else:
           attention_mask=None
       position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+
+      if args.reset_position_ids and args.reset_attention_mask:
+          cu_seqlens=torch.empty((args.seq_length), dtype = torch.int32 , device = torch.cuda.current_device())
+          max_seqlen=torch.empty((1), dtype = torch.int32 , device = torch.cuda.current_device())
 
       if args.pipeline_model_parallel_size == 1:
           _broadcast(tokens)
@@ -358,6 +383,9 @@ def _broadcast(item):
           _broadcast(loss_mask)
           _broadcast(attention_mask)
           _broadcast(position_ids)
+          if args.reset_position_ids and args.reset_attention_mask:
+              _broadcast(cu_seqlens)
+              _broadcast(max_seqlen)
 
       elif mpu.is_pipeline_first_stage():
           labels=None
@@ -366,21 +394,37 @@ def _broadcast(item):
           _broadcast(tokens)
           _broadcast(attention_mask)
           _broadcast(position_ids)
+          if args.reset_position_ids and args.reset_attention_mask:
+              _broadcast(cu_seqlens)
+              _broadcast(max_seqlen)
 
       elif mpu.is_pipeline_last_stage():
           tokens=None
           position_ids=None
+          cu_seqlens=None
+          max_seqlen=None
 
           _broadcast(labels)
           _broadcast(loss_mask)
           _broadcast(attention_mask)
+
+      packed_seq_params = None
+      if args.reset_position_ids and args.reset_attention_mask:
+          packed_seq_params = PackedSeqParams(
+              cu_seqlens_q=cu_seqlens,
+              cu_seqlens_kv=cu_seqlens,
+              max_seqlen_q=max_seqlen,
+              max_seqlen_kv=max_seqlen,
+              qkv_format='thd',
+          )
 
       batch = {
           'tokens': tokens,
           'labels': labels,
           'loss_mask': loss_mask,
           'attention_mask': attention_mask,
-          'position_ids': position_ids
+          'position_ids': position_ids,
+          'packed_seq_params': packed_seq_params
       }
 
    return batch
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 3b7f8db012..c19aab8843 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -185,13 +185,13 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator', log_level=2).start()
     global stimer
     with stimer(bdata=True):
-        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+        tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(
             data_iterator)
     timers('batch-generator').stop()
 
     with stimer:
         output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+                              labels=labels, packed_seq_params=packed_seq_params)
 
     return output_tensor, partial(loss_func, loss_mask)
 
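
For reference, the standalone sketch below reproduces how the patch derives packed-sequence metadata from position_ids that restart at every EOD boundary: cu_seqlens comes from the points where the position counter drops, max_seqlen is the longest segment, and the attention mask is the block-diagonal of per-segment causal masks. The helper name build_packed_seq_metadata and the toy input are illustrative only (they are not part of the patch), and, like the patch, the sketch assumes a single packed sample (micro_batch_size == 1).

import torch


def build_packed_seq_metadata(position_ids: torch.Tensor):
    """Derive (cu_seqlens, max_seqlen, block-diagonal causal mask) for one packed
    sample whose position_ids restart from 0 at each document boundary."""
    seq_length = position_ids.numel()

    # A new document starts wherever the position counter drops.
    reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1
    cu_seqlens = torch.cat(
        [torch.tensor([0]), reset_points, torch.tensor([seq_length])]
    ).to(torch.int32)

    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    max_seqlen = seqlens.max()

    # Same block-diagonal construction as in _get_ltor_masks_and_position_ids above.
    blocks = [torch.tril(torch.ones((int(n), int(n)), dtype=torch.int8)) for n in seqlens]
    attention_mask = torch.block_diag(*blocks).unsqueeze(0)

    return cu_seqlens, max_seqlen, attention_mask


if __name__ == "__main__":
    # Two packed documents of lengths 4 and 3 in a 7-token sample.
    position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2])
    cu_seqlens, max_seqlen, mask = build_packed_seq_metadata(position_ids)
    print(cu_seqlens)        # tensor([0, 4, 7], dtype=torch.int32)
    print(int(max_seqlen))   # 4
    print(mask.shape)        # torch.Size([1, 7, 7])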