Enable huggingface tokenizer #1268

Open

wants to merge 6 commits into base: main

Changes from all commits
67 changes: 67 additions & 0 deletions Jenkinsfile
@@ -0,0 +1,67 @@
import org.apache.commons.io.FilenameUtils
import groovy.json.JsonOutput


def show_node_info() {
    sh """
        echo "NODE_NAME = \$NODE_NAME" || true
        lsb_release -sd || true
        uname -r || true
        cat /sys/module/amdgpu/version || true
        ls /opt/ -la || true
    """
}

def clean_up_docker() {
    sh 'docker ps -a || true' // "|| true" suppresses errors
    sh 'docker kill $(docker ps -q) || true'
    sh 'docker rm $(docker ps -a -q) || true'
    sh 'docker rmi $(docker images -q) || true'
    sh 'docker system prune -af --volumes || true'
}

def clean_up_docker_container() {
    sh 'docker ps -a || true' // "|| true" suppresses errors
    sh 'docker kill $(docker ps -q) || true'
}

// Makes sure multiple builds are not triggered for branch indexing
def resetbuild() {
    if(currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) {
        def milestonesList = []
        def build = currentBuild

        while(build != null) {
            if(build.getBuildCauses().toString().contains('BranchIndexingCause')) {
                milestonesList.add(0, build.number)
            }
            build = build.previousBuildInProgress
        }

        for (buildNum in milestonesList) {
            milestone(buildNum)
        }
    }
}

pipeline {
    agent any

    stages {
        stage('Build') {
            steps {
                echo 'Building..'
            }
        }
        stage('Test') {
            steps {
                echo 'Testing..'
            }
        }
        stage('Deploy') {
            steps {
                show_node_info()
            }
        }
    }
}
194 changes: 159 additions & 35 deletions megatron/core/models/common/embeddings/rope_utils.py
100644 → 100755
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple, Union

if TYPE_CHECKING:
    from megatron.core.transformer.transformer_config import TransformerConfig
@@ -26,6 +26,11 @@
except ImportError:
    HAVE_APPLY_ROPE_FUSION = False

try:
    import transformer_engine.pytorch.cpp_extensions as tex
    HAVE_TE = True
except ImportError:
    HAVE_TE = False

def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor:
    """Get the position embedding on the current context parallel rank.
@@ -149,43 +154,162 @@ def apply_rotary_pos_emb(
    Reroute to the appropriate apply_rotary_pos_emb function depending on
    fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
    """
    if not config.disable_te_fused_rope and HAVE_TE and torch.cuda.is_available() and torch.version.hip:
        return apply_rotary_pos_emb_fused_te(t = t, freqs = freqs, config = config, cu_seqlens = cu_seqlens)
    else:
        if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION:
            # setting apply_rope_fusion in config to False
            # so that subsequent queries to this config also return False
            config.apply_rope_fusion = False
            if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False):
                logger.warning(
                    "Setting apply_rope_fusion to false because its implementation"
                    " is not included in Apex. Try upgrading to the latest version"
                )
                apply_rotary_pos_emb.printed_fused_warning = True

        if getattr(config, "multi_latent_attention", False) and config.rotary_interleaved:
            logger.warning(
                "rotary_interleaved is not supported with multi_latent_attention, setting it to False"
            )
            config.rotary_interleaved = False

        if config.apply_rope_fusion:
            if cu_seqlens is None:
                return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
            else:
                return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
        else:
            if cu_seqlens is None:
                return _apply_rotary_pos_emb_bshd(
                    t,
                    freqs,
                    rotary_interleaved=config.rotary_interleaved,
                    multi_latent_attention=config.multi_latent_attention,
                    mscale=mscale,
                )
            else:
                return _apply_rotary_pos_emb_thd(
                    t,
                    cu_seqlens,
                    freqs,
                    rotary_interleaved=config.rotary_interleaved,
                    multi_latent_attention=config.multi_latent_attention,
                    mscale=mscale,
                )


class FusedRoPEFunc(torch.autograd.Function):
    """
    Function for FusedRoPE

    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
    """

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
        cu_seqlens: Union[torch.Tensor, None] = None,
    ) -> torch.Tensor:
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
        if tensor_format == "sbhd":
            output = tex.fused_rope_forward(t, freqs, False)
        elif tensor_format == "bshd":
            output = tex.fused_rope_forward(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif tensor_format == "thd":
            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        freqs, cu_seqlens = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = tex.fused_rope_backward(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        elif ctx.tensor_format == "thd":
            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")

        return grad_input, None, None, None, None


def apply_rotary_pos_emb_fused_te(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    config: TransformerConfig = None,
    fused: bool = True,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = True
        Whether to use the fused RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
    """

    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)
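
A quick way to exercise the new helper in isolation (not part of the diff; shapes and dtypes are illustrative, and it assumes a GPU build of PyTorch with Transformer Engine's fused RoPE extensions available):

import torch
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb_fused_te

# Illustrative sbhd shapes: seq=128, batch=2, heads=8, head_dim=64
t = torch.randn(128, 2, 8, 64, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(128, 1, 1, 64, device="cuda", dtype=torch.float32)

# The fused path goes through FusedRoPEFunc (tex.fused_rope_forward/backward);
# the unfused fallback uses cos/sin plus _rotate_half in plain PyTorch.
fused_out = apply_rotary_pos_emb_fused_te(t, freqs, tensor_format="sbhd")
ref_out = apply_rotary_pos_emb_fused_te(t, freqs, tensor_format="sbhd", fused=False)
print(torch.allclose(fused_out.float(), ref_out.float(), atol=1e-2))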
2 changes: 1 addition & 1 deletion megatron/core/models/common/embeddings/rotary_pos_embedding.py
100644 → 100755
@@ -183,4 +183,4 @@ def get_rotary_seq_len(

rotary_seq_len *= transformer_config.context_parallel_size

return rotary_seq_len
return rotary_seq_len
3 changes: 3 additions & 0 deletions megatron/core/transformer/transformer_config.py
100644 → 100755
@@ -165,6 +165,9 @@ class TransformerConfig(ModelParallelConfig):
    apply_rope_fusion: bool = False
    """If True, use fused RoPE kernel."""

    disable_te_fused_rope: bool = False
    """If True, disable the fused RoPE kernel from Transformer Engine."""

    ####################
    # activation recomputation
    ####################
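
For reference, a minimal sketch of how the new field would be set when building a config directly in Python (the other TransformerConfig arguments shown are just the usual required ones and are illustrative):

from megatron.core.transformer.transformer_config import TransformerConfig

# Opt out of the TE fused-RoPE dispatch added to rope_utils.apply_rotary_pos_emb.
config = TransformerConfig(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=8,
    disable_te_fused_rope=True,
)
print(config.disable_te_fused_rope)  # True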
32 changes: 17 additions & 15 deletions megatron/legacy/fused_kernels/__init__.py
100644 → 100755
@@ -3,9 +3,10 @@
import os
import pathlib
import subprocess

import torch
from torch.utils import cpp_extension


# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
@@ -16,22 +17,23 @@

def load(args):

    if torch.cuda.is_available() and torch.version.cuda:
        # Check if cuda 11 is installed for compute capability 8.0
        cc_flag = []
        _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
            cpp_extension.CUDA_HOME
        )
        if int(bare_metal_major) >= 11:
            cc_flag.append('-gencode')
            cc_flag.append('arch=compute_80,code=sm_80')
            if int(bare_metal_minor) >= 8:
                cc_flag.append('-gencode')
                cc_flag.append('arch=compute_90,code=sm_90')

        # Build path
        srcpath = pathlib.Path(__file__).parent.absolute()
        buildpath = srcpath / "build"
        _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
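
For context, the new guard distinguishes CUDA wheels of PyTorch from ROCm wheels, so the CUDA-only compute-capability probing is skipped on ROCm. A quick way to see which branch a given environment takes (illustrative only):

import torch

# CUDA builds report a version string in torch.version.cuda and None in torch.version.hip;
# ROCm builds are the other way around.
print(torch.cuda.is_available(), torch.version.cuda, torch.version.hip)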
4 changes: 4 additions & 0 deletions megatron/training/arguments.py
100644 → 100755
@@ -695,6 +695,8 @@ def core_transformer_config_from_args(args, config_class=None):
    else:
        kw_args['num_query_groups'] = None
    kw_args['config_logger_dir'] = args.config_logger_dir
    if args.disable_te_fused_rope:
        kw_args['disable_te_fused_rope'] = args.disable_te_fused_rope

    # Return config.
    return config_class(**kw_args)
@@ -853,6 +855,8 @@ def _add_network_size_args(parser):
                       action='store_false',
                       help='Disable position embedding. Deprecated: use --position-embedding-type',
                       dest='add_position_embedding')
    group.add_argument('--disable-te-fused-rope', action='store_true', default=False,
                       help='Disable the fused RoPE kernel from Transformer Engine.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                            'This is added for computational efficieny reasons.')
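
A self-contained emulation of the new option and the kw_args guard above (the parser here is local to the example; the names mirror the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--disable-te-fused-rope', action='store_true', default=False,
                    help='Disable the fused RoPE kernel from Transformer Engine.')
args = parser.parse_args(['--disable-te-fused-rope'])

# Mirrors core_transformer_config_from_args: the key is only forwarded when the flag is set.
kw_args = {}
if args.disable_te_fused_rope:
    kw_args['disable_te_fused_rope'] = args.disable_te_fused_rope
print(kw_args)  # {'disable_te_fused_rope': True}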
2 changes: 1 addition & 1 deletion tools/preprocess_data.py
@@ -203,7 +203,7 @@ def get_args():
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer',
'GPTSentencePieceTokenizer', 'Llama2Tokenizer',
'Llama3Tokenizer', 'MistralTokenizer', 'NullTokenizer'],
'Llama3Tokenizer', 'MistralTokenizer', 'HuggingFaceTokenizer', 'NullTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
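
The new HuggingFaceTokenizer choice defers tokenization to the Hugging Face transformers library; a rough sketch of what it is expected to wrap (the wrapper itself lives in Megatron's tokenizer module and is not shown in this diff, and the model name below is illustrative):

from transformers import AutoTokenizer

# --tokenizer-model would be a HF hub name or a local path (illustrative value).
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
ids = tok("Megatron-LM preprocessing sample")["input_ids"]
print(len(ids), tok.decode(ids))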