[Functions] Support Packed_seq_params in Megatron-LM #1215

Open · wants to merge 4 commits into base: main
33 changes: 31 additions & 2 deletions megatron/core/datasets/gpt_dataset.py
@@ -16,6 +16,7 @@
from megatron.core.datasets.utils import Split
from megatron.core.datasets.utils_s3 import S3Config, is_s3_path
from megatron.core.utils import log_single_rank
from megatron.core.packed_seq_params import PackedSeqParams

logger = logging.getLogger(__name__)

@@ -213,13 +214,31 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
if idx is None:
loss_mask = torch.zeros_like(loss_mask)

# Packed sequence parameters: default to a falsy sentinel when sequence packing is not active.
packed_seq_params = torch.tensor([False], dtype=torch.bool)
if self.config.reset_attention_mask and self.config.create_attention_mask \
and self.config.reset_position_ids:
# Position ids restart at each document boundary, so a drop marks the start of a new sequence.
reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1
cu_seqlens = torch.cat([torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]).to(torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen, _ = seqlens.max(dim=0, keepdim=True)

packed_seq_params = {
"cu_seqlens_q": cu_seqlens,
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_q": max_seqlen,
"max_seqlen_kv": max_seqlen,
"qkv_format": 'thd',
}

if self.config.create_attention_mask:
return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
"packed_seq_params": packed_seq_params,
}
else:
return {
@@ -646,7 +665,9 @@ def _get_ltor_masks_and_position_ids(
"""
seq_length = data.numel()

if create_attention_mask and reset_attention_mask:
attention_mask = True
elif create_attention_mask:
attention_mask = torch.tril(
torch.ones((seq_length, seq_length), device=data.device)
).unsqueeze(0)
@@ -673,15 +694,23 @@

# Loop through EOD indices:
prev_index = 0
mask_list = []
for j in range(eod_index.numel()):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask and attention_mask is not None:
small_mask = torch.tril(torch.ones((i + 1 - prev_index, i + 1 - prev_index), device=data.device, dtype=torch.int8))
mask_list.append(small_mask)
# Reset positions.
if reset_position_ids:
position_ids[(i + 1) :] -= i + 1 - prev_index
prev_index = i + 1

if prev_index < seq_length:
small_mask = torch.tril(torch.ones((seq_length - prev_index, seq_length - prev_index), device=data.device, dtype=torch.int8))
mask_list.append(small_mask)

attention_mask = torch.block_diag(*mask_list).unsqueeze(0)

if attention_mask is not None:
# Convert attention mask to binary:
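For reference, the snippet below is a standalone sketch (not part of the diff) showing how the logic added above turns the reset points in position_ids into cu_seqlens and max_seqlen, and how the per-document causal masks are assembled into a block-diagonal attention mask. The toy position_ids values are made up for illustration.

import torch

# Toy packed sample: three documents of lengths 3, 2 and 4 packed into nine tokens.
# With reset_position_ids, position ids restart at 0 at every document boundary.
position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

# A reset point is any index where the position id drops, i.e. a new document starts.
reset_points = torch.where(position_ids[1:] < position_ids[:-1])[0] + 1
cu_seqlens = torch.cat(
    [torch.tensor([0]), reset_points, torch.tensor([len(position_ids)])]
).to(torch.int32)                                  # tensor([0, 3, 5, 9], dtype=torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]         # tensor([3, 2, 4], dtype=torch.int32)
max_seqlen, _ = seqlens.max(dim=0, keepdim=True)   # tensor([4], dtype=torch.int32)

# The mask built in _get_ltor_masks_and_position_ids is equivalent to stacking one
# lower-triangular block per document along the diagonal.
mask_list = [torch.tril(torch.ones((n, n), dtype=torch.int8)) for n in seqlens.tolist()]
attention_mask = torch.block_diag(*mask_list).unsqueeze(0)   # shape (1, 9, 9)
print(cu_seqlens, max_seqlen, attention_mask.shape)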
48 changes: 46 additions & 2 deletions megatron/training/utils.py
@@ -4,8 +4,10 @@
import os
import sys
from datetime import datetime
from dataclasses import dataclass, is_dataclass

import torch
import torch.nn.functional as F

try:
from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_l2norm
@@ -37,6 +39,7 @@
from megatron.core import DistributedDataParallel as DDP
from megatron.core import mpu
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared

@@ -306,6 +309,13 @@ def get_batch_on_this_tp_rank(data_iterator):
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

def _broadcast_class(item):
# Broadcast the tensor fields of a PackedSeqParams dataclass. cu_seqlens_q is zero-padded
# to args.seq_length so that receiving ranks can allocate a fixed-size buffer.
if item and is_dataclass(item):
padded_tensors = F.pad(item.cu_seqlens_q, (0, args.seq_length - item.cu_seqlens_q.size(0)))
_broadcast(padded_tensors)
_broadcast(item.max_seqlen_q)


if mpu.get_tensor_model_parallel_rank() == 0:

@@ -314,12 +324,21 @@ def _broadcast(item):
else:
data = None

if data["packed_seq_params"]:
data["packed_seq_params"] = PackedSeqParams(**data["packed_seq_params"])
data["packed_seq_params"].cu_seqlens_q = data["packed_seq_params"].cu_seqlens_q.cuda().squeeze(0)
data["packed_seq_params"].cu_seqlens_kv = data["packed_seq_params"].cu_seqlens_kv.cuda().squeeze(0)
data["packed_seq_params"].max_seqlen_q = data["packed_seq_params"].max_seqlen_q.cuda().squeeze(0)
data["packed_seq_params"].max_seqlen_kv = data["packed_seq_params"].max_seqlen_kv.cuda().squeeze(0)
data["packed_seq_params"].qkv_format = 'thd'

batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
'position_ids': data["position_ids"].cuda(non_blocking = True),
'packed_seq_params': data["packed_seq_params"] if data["packed_seq_params"] else None
}

if args.pipeline_model_parallel_size == 1:
@@ -328,11 +347,13 @@ def _broadcast(item):
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
_broadcast_class(batch['packed_seq_params'])

elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
_broadcast_class(batch['packed_seq_params'])

elif mpu.is_pipeline_last_stage():
_broadcast(batch['labels'])
@@ -351,13 +372,20 @@ def _broadcast(item):
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())

if args.reset_position_ids and args.reset_attention_mask:
cu_seqlens=torch.empty((args.seq_length), dtype = torch.int32 , device = torch.cuda.current_device())
max_seqlen=torch.empty((1), dtype = torch.int32 , device = torch.cuda.current_device())

if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
if args.reset_position_ids and args.reset_attention_mask:
_broadcast(cu_seqlens)
_broadcast(max_seqlen)

elif mpu.is_pipeline_first_stage():
labels=None
@@ -366,21 +394,37 @@ def _broadcast(item):
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
if args.reset_position_ids and args.reset_attention_mask:
_broadcast(cu_seqlens)
_broadcast(max_seqlen)

elif mpu.is_pipeline_last_stage():
tokens=None
position_ids=None
cu_seqlens=None
max_seqlen=None

_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)

packed_seq_params = None
if args.reset_position_ids and args.reset_attention_mask:
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_kv=max_seqlen,
qkv_format='thd',
)

batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids,
'packed_seq_params': packed_seq_params
}

return batch
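Note on the receive path above: the non-source tensor-parallel ranks get back a cu_seqlens buffer that _broadcast_class zero-padded to args.seq_length, plus max_seqlen. Before variable-length attention kernels can consume it, that padding has to be stripped again. A possible way to do this is sketched below; trim_cu_seqlens is a hypothetical helper, not part of this PR, and it assumes cu_seqlens is strictly increasing with its last real entry equal to the total token count.

import torch
from megatron.core.packed_seq_params import PackedSeqParams

def trim_cu_seqlens(padded: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: the last real entry of cu_seqlens is its maximum
    # (the total token count), so everything after it is zero padding.
    last = int(torch.argmax(padded))
    return padded[: last + 1]

# Example buffer as a rank might receive it after the broadcast (values made up):
padded = torch.tensor([0, 3, 5, 9, 0, 0, 0, 0], dtype=torch.int32)
max_seqlen = torch.tensor([4], dtype=torch.int32)
cu_seqlens = trim_cu_seqlens(padded)               # tensor([0, 3, 5, 9], dtype=torch.int32)

packed_seq_params = PackedSeqParams(
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_kv=max_seqlen,
    qkv_format='thd',
)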
4 changes: 2 additions & 2 deletions pretrain_gpt.py
@@ -185,13 +185,13 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(
data_iterator)
timers('batch-generator').stop()

with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels, packed_seq_params=packed_seq_params)

return output_tensor, partial(loss_func, loss_mask)

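Only forward_step is touched in pretrain_gpt.py, which works because get_batch there builds the batch as a dict via get_batch_on_this_tp_rank and returns batch.values(); once the 'packed_seq_params' key is added upstream, a sixth value falls out of the unpacking automatically. The sketch below illustrates that flow under the assumption that get_batch keeps its current upstream structure; the extra None for middle pipeline stages is likewise an assumption, not shown in this diff.

from megatron.core import mpu
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank

def get_batch(data_iterator):
    """Sketch of pretrain_gpt.get_batch, assuming the current upstream structure."""
    # Middle pipeline stages receive no batch; one extra None keeps the
    # six-way unpacking in forward_step consistent (assumption, not in the diff).
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None, None
    # The TP-rank batch now carries 'packed_seq_params'; CP slicing is unchanged.
    batch = get_batch_on_this_tp_rank(data_iterator)
    batch = get_batch_on_this_cp_rank(batch)
    return batch.values()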