diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 0000000000..7f73a1ce70
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,67 @@
+import org.apache.commons.io.FilenameUtils
+import groovy.json.JsonOutput
+
+
+def show_node_info() {
+    sh """
+        echo "NODE_NAME = \$NODE_NAME" || true
+        lsb_release -sd || true
+        uname -r || true
+        cat /sys/module/amdgpu/version || true
+        ls /opt/ -la || true
+    """
+}
+
+def clean_up_docker() {
+    sh 'docker ps -a || true' // "|| true" suppresses errors
+    sh 'docker kill $(docker ps -q) || true'
+    sh 'docker rm $(docker ps -a -q) || true'
+    sh 'docker rmi $(docker images -q) || true'
+    sh 'docker system prune -af --volumes || true'
+}
+
+def clean_up_docker_container() {
+    sh 'docker ps -a || true' // "|| true" suppresses errors
+    sh 'docker kill $(docker ps -q) || true'
+}
+
+// Makes sure multiple builds are not triggered for branch indexing
+def resetbuild() {
+    if(currentBuild.getBuildCauses().toString().contains('BranchIndexingCause')) {
+        def milestonesList = []
+        def build = currentBuild
+
+        while(build != null) {
+            if(build.getBuildCauses().toString().contains('BranchIndexingCause')) {
+                milestonesList.add(0, build.number)
+            }
+            build = build.previousBuildInProgress
+        }
+
+        for (buildNum in milestonesList) {
+            milestone(buildNum)
+        }
+    }
+}
+
+pipeline {
+    agent any
+
+    stages {
+        stage('Build') {
+            steps {
+                echo 'Building..'
+            }
+        }
+        stage('Test') {
+            steps {
+                echo 'Testing..'
+            }
+        }
+        stage('Deploy') {
+            steps {
+                show_node_info()
+            }
+        }
+    }
+}
diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py
old mode 100644
new mode 100755
index accb251961..ece04492af
--- a/megatron/core/models/common/embeddings/rope_utils.py
+++ b/megatron/core/models/common/embeddings/rope_utils.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 if TYPE_CHECKING:
     from megatron.core.transformer.transformer_config import TransformerConfig
@@ -26,6 +26,11 @@
 except ImportError:
     HAVE_APPLY_ROPE_FUSION = False
 
+try:
+    import transformer_engine.pytorch.cpp_extensions as tex
+    HAVE_TE = True
+except ImportError:
+    HAVE_TE = False
 
 def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor:
     """Get the position embedding on the current context parallel rank.
@@ -149,43 +154,162 @@ def apply_rotary_pos_emb(
     Reroute to the appropriate apply_rotary_pos_emb function depending on
     fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
     """
-    if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION:
-        # setting apply_rope_fusion in config to False
-        # so that subsequent queries to this config also return False
-        config.apply_rope_fusion = False
-        if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False):
+    if not config.disable_te_fused_rope and HAVE_TE and torch.cuda.is_available() and torch.version.hip:
+        return apply_rotary_pos_emb_fused_te(t=t, freqs=freqs, config=config, cu_seqlens=cu_seqlens)
+    else:
+        if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION:
+            # setting apply_rope_fusion in config to False
+            # so that subsequent queries to this config also return False
+            config.apply_rope_fusion = False
+            if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False):
+                logger.warning(
+                    "Setting apply_rope_fusion to false because its implementation"
+                    " is not included in Apex. Try upgrading to the latest version"
+                )
+                apply_rotary_pos_emb.printed_fused_warning = True
+
+        if getattr(config, "multi_latent_attention", False) and config.rotary_interleaved:
             logger.warning(
-                "Setting apply_rope_fusion to false because its implementation"
-                " is not included in Apex. Try upgrading to the latest version"
+                "rotary_interleaved is not supported with multi_latent_attention, setting it to False"
             )
-            apply_rotary_pos_emb.printed_fused_warning = True
+            config.rotary_interleaved = False
+
+        if config.apply_rope_fusion:
+            if cu_seqlens is None:
+                return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
+            else:
+                return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
+        else:
+            if cu_seqlens is None:
+                return _apply_rotary_pos_emb_bshd(
+                    t,
+                    freqs,
+                    rotary_interleaved=config.rotary_interleaved,
+                    multi_latent_attention=config.multi_latent_attention,
+                    mscale=mscale,
+                )
+            else:
+                return _apply_rotary_pos_emb_thd(
+                    t,
+                    cu_seqlens,
+                    freqs,
+                    rotary_interleaved=config.rotary_interleaved,
+                    multi_latent_attention=config.multi_latent_attention,
+                    mscale=mscale,
+                )
+
+class FusedRoPEFunc(torch.autograd.Function):
+    """
+    Function for FusedRoPE
-    if getattr(config, "multi_latent_attention", False) and config.rotary_interleaved:
-        logger.warning(
-            "rotary_interleaved is not supported with multi_latent_attention, setting it to False"
-        )
-        config.rotary_interleaved = False
+    This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
+    the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
+    the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
+    """
-    if config.apply_rope_fusion:
-        if cu_seqlens is None:
-            return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
+    @staticmethod
+    def forward(
+        ctx,
+        t: torch.Tensor,
+        freqs: torch.Tensor,
+        tensor_format: str = "sbhd",
+        cu_seqlens: Union[torch.Tensor, None] = None,
+    ) -> torch.Tensor:
+        if freqs.dtype != torch.float32:
+            freqs = freqs.float()
+        if tensor_format == "sbhd":
+            output = tex.fused_rope_forward(t, freqs, False)
+        elif tensor_format == "bshd":
+            output = tex.fused_rope_forward(
+                t.transpose(0, 1), freqs, True
+            ).transpose(0, 1)
+        elif tensor_format == "thd":
+            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
         else:
-            return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
-    else:
-        if cu_seqlens is None:
-            return _apply_rotary_pos_emb_bshd(
-                t,
-                freqs,
-                rotary_interleaved=config.rotary_interleaved,
-                multi_latent_attention=config.multi_latent_attention,
-                mscale=mscale,
-            )
+            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
+        ctx.save_for_backward(freqs, cu_seqlens)
+        ctx.tensor_format = tensor_format
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[Union[torch.Tensor, None], ...]:
+        freqs, cu_seqlens = ctx.saved_tensors
+        if ctx.tensor_format == "sbhd":
+            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
+        elif ctx.tensor_format == "bshd":
+            grad_input = tex.fused_rope_backward(
+                grad_output.transpose(0, 1), freqs, True
+            ).transpose(0, 1)
+        elif ctx.tensor_format == "thd":
+            grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
         else:
-            return _apply_rotary_pos_emb_thd(
-                t,
-                cu_seqlens,
-                freqs,
-                rotary_interleaved=config.rotary_interleaved,
-                multi_latent_attention=config.multi_latent_attention,
-                mscale=mscale,
-            )
+            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
tensor_format: {ctx.tensor_format}.") + + return grad_input, None, None, None, None + + +def apply_rotary_pos_emb_fused_te( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + config: TransformerConfig = None, + fused: bool = True, + cu_seqlens: Union[torch.Tensor, None] = None, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which + rotary positional embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + fused: bool, default = False + Whether to use a fused applying RoPE implementation. + tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' + is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is + of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. + cu_seqlens: torch.Tensor, default = None. + Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and + dtype torch.int32. Only valid when `tensor_format` is 'thd'. + """ + + if fused: + assert ( + tensor_format != "thd" or cu_seqlens is not None + ), "cu_seqlens must not be None when tensor_format is 'thd'." + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens) + + assert tensor_format in ("sbhd", "bshd"), ( + "Only formats `sbhd` or `bshd` are supported for input tensor `t` " + f"when fused is False, got {tensor_format}." + ) + + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert cur_seq_len <= max_seq_len, ( + f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
+    )
+    freqs = freqs[:cur_seq_len]
+    if tensor_format == "bshd":
+        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
+    # cos/sin first then dtype conversion for better precision
+    cos_ = torch.cos(freqs).to(t.dtype)
+    sin_ = torch.sin(freqs).to(t.dtype)
+
+    rot_dim = freqs.shape[-1]
+    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+
+    # first part is cosine component
+    # second part is sine component, need to change signs with _rotate_half method
+    t = (t * cos_) + (_rotate_half(t) * sin_)
+    return torch.cat((t, t_pass), dim=-1)
diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
old mode 100644
new mode 100755
index 5232faec60..d16ae79cdb
--- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py
@@ -183,4 +183,4 @@ def get_rotary_seq_len(
 
     rotary_seq_len *= transformer_config.context_parallel_size
 
-    return rotary_seq_len
+    return rotary_seq_len
\ No newline at end of file
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
old mode 100644
new mode 100755
index a63171686a..b8968d6cf5
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -165,6 +165,9 @@ class TransformerConfig(ModelParallelConfig):
     apply_rope_fusion: bool = False
     """If True, use fused RoPE kernel."""
 
+    disable_te_fused_rope: bool = False
+    """If True, disable the fused RoPE kernel from Transformer Engine."""
+
     ####################
     # activation recomputation
     ####################
diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py
old mode 100644
new mode 100755
index 87cceac3e3..f01088fd5a
--- a/megatron/legacy/fused_kernels/__init__.py
+++ b/megatron/legacy/fused_kernels/__init__.py
@@ -3,9 +3,10 @@
 import os
 import pathlib
 import subprocess
-
+import torch
 from torch.utils import cpp_extension
 
+
 # Setting this param to a list has a problem of generating different
 # compilation commands (with diferent order of architectures) and
 # leading to recompilation of fused kernels. Set it to empty string
@@ -16,22 +17,23 @@
 
 def load(args):
 
-    # Check if cuda 11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
-        cpp_extension.CUDA_HOME
-    )
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-        if int(bare_metal_minor) >= 8:
+    if torch.cuda.is_available() and torch.version.cuda:
+        # Check if cuda 11 is installed for compute capability 8.0
+        cc_flag = []
+        _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
+            cpp_extension.CUDA_HOME
+        )
+        if int(bare_metal_major) >= 11:
             cc_flag.append('-gencode')
-            cc_flag.append('arch=compute_90,code=sm_90')
+            cc_flag.append('arch=compute_80,code=sm_80')
+            if int(bare_metal_minor) >= 8:
+                cc_flag.append('-gencode')
+                cc_flag.append('arch=compute_90,code=sm_90')
 
-    # Build path
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / "build"
-    _create_build_dir(buildpath)
+        # Build path
+        srcpath = pathlib.Path(__file__).parent.absolute()
+        buildpath = srcpath / "build"
+        _create_build_dir(buildpath)
 
     # Helper function to build the kernels.
     def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
old mode 100644
new mode 100755
index e3d876a5f2..9411223126
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -695,6 +695,8 @@ def core_transformer_config_from_args(args, config_class=None):
     else:
         kw_args['num_query_groups'] = None
     kw_args['config_logger_dir'] = args.config_logger_dir
+    if args.disable_te_fused_rope:
+        kw_args['disable_te_fused_rope'] = args.disable_te_fused_rope
 
     # Return config.
     return config_class(**kw_args)
@@ -853,6 +855,8 @@ def _add_network_size_args(parser):
                        action='store_false',
                        help='Disable position embedding. Deprecated: use --position-embedding-type',
                        dest='add_position_embedding')
+    group.add_argument('--disable-te-fused-rope', action='store_true', default=False,
+                       help='Disable the fused RoPE kernel from Transformer Engine.')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index a81fe8ca7e..a9575707b9 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -203,7 +203,7 @@ def get_args():
                        choices=['BertWordPieceLowerCase','BertWordPieceCase',
                                 'GPT2BPETokenizer', 'SentencePieceTokenizer',
                                 'GPTSentencePieceTokenizer', 'Llama2Tokenizer',
-                                'Llama3Tokenizer', 'MistralTokenizer', 'NullTokenizer'],
+                                'Llama3Tokenizer', 'MistralTokenizer', 'HuggingFaceTokenizer', 'NullTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='YTTM tokenizer model.')
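
Usage note (reviewer sketch, not part of the patch): the snippet below illustrates how the new routing is expected to be exercised. apply_rotary_pos_emb() takes the Transformer Engine fused path on ROCm builds when transformer_engine is importable and config.disable_te_fused_rope is left False; the CLI equivalent is simply not passing --disable-te-fused-rope. The config values and tensor shapes are illustrative assumptions, and the freqs tensor stands in for the output of RotaryEmbedding.

# Reviewer sketch; assumes a ROCm PyTorch build with Transformer Engine installed.
# On non-HIP builds, apply_rotary_pos_emb() falls back to the Apex-fused or unfused path.
import torch

from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    disable_te_fused_rope=False,  # new field from this patch; set True via --disable-te-fused-rope
)

s, b, h, d = 128, 2, 4, 16                       # illustrative sbhd shapes
t = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(s, 1, 1, d, device="cuda")   # stand-in for RotaryEmbedding output

out = apply_rotary_pos_emb(t, freqs, config=config)  # routes to tex.fused_rope_forward on HIP
print(out.shape)  # same shape as the input tensor t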