Commit

Merge branch 'transformer-config-doc' into 'main'
Clean up transformer config docs.

See merge request ADLR/megatron-lm!1231
jaredcasper committed Mar 25, 2024
2 parents a746c16 + e1ca51b commit 3a70d14
Showing 2 changed files with 334 additions and 215 deletions.
324 changes: 169 additions & 155 deletions megatron/core/model_parallel_config.py
@@ -10,216 +10,230 @@
class ModelParallelConfig:
"""Base configuration for Megatron Core

The initialization function has an argument for each parameter.
"""

###################
# Model parallelism
###################
tensor_model_parallel_size: int = 1
"""Intra-layer model parallelism. Splits tensors across GPU ranks."""

pipeline_model_parallel_size: int = 1
"""Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

virtual_pipeline_model_parallel_size: Optional[int] = None
"""Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
arxiv.org/pdf/2104.04473.pdf for more details.
"""

sequence_parallel: bool = False
"""Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
(https://arxiv.org/abs/2205.05198) for more details.
"""

context_parallel_size: int = 1
"""Splits network input along sequence dimension across GPU ranks."""

expert_model_parallel_size: int = 1
"""Distributes Moe Experts across sub data parallel dimension."""

###################
# Initialization
###################
perform_initialization: bool = True
"""If true, weights are initialized. This option can be useful when you know you are going to
load values from a checkpoint.
"""

use_cpu_initialization: bool = False
"""When set to False, we initialize the weights directly on the GPU. CPU initialization is the
same regardless of tensor model parallelism, but GPU initialization is not. Transferring
weights from CPU to GPU can take a significant amount of time for large models.
"""

###################
# Training
###################
fp16: bool = False
"""If true, train with fp16 mixed precision training."""

bf16: bool = False
"""If true, train with bf16 mixed precision training."""

params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights."""

timers: Callable = None
"""Timers object to call for various timing functions. See megatron.core.timers.Timers"""

finalize_model_grads_func: Callable = None
"""Function that finalizes gradients on all workers. Could include ensuring that grads are
all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
dimensions.
"""

grad_scale_func: Callable = None
"""If using loss scaling, this function should take the loss and return the scaled loss. If
None, no function is called on the loss.
"""

no_sync_func: Callable = None
"""Function that creates a context that suppresses asynchronous data-parallel communication. If
the model is an instance of core.distributed.DistributedDataParallel, the default is to use
core.distributed.DistributedDataParallel.no_sync.
"""

grad_sync_func: Callable = None
"""Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an iterable of parameters whose
gradients are to be synchronized.
"""

param_sync_func: Callable = None
"""Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument: an iterable of parameters to
be synchronized.
"""

enable_autocast: bool = False
"""If true runs the forward step function inside torch.autocast context."""

autocast_dtype: torch.dtype = None
"""dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""

num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
"""If int, set the number of microbatches where not all of the layers will be checkpointed and
recomputed. The rest of the microbatches within the window of maximum outstanding
microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
"""

###################
# Optimizations
###################
gradient_accumulation_fusion: bool = False
"""If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
--global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion.
"""

async_tensor_model_parallel_allreduce: bool = False
"""If true, enables asynchronous execution of tensor-model-parallel all-reduce with weight
gradient computation of a column-linear layer.
"""
tp_comm_overlap: bool = False
"""If true, allows overlapping of Linear layer execution with tensor parallel communication
collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
possible during the forward and the backward pass.
"""

tp_comm_split_ag: bool = True
"""If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
splits. Don't care if tp_comm_overlap is False.
"""

tp_comm_atomic_ag: bool = False
"""If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather both
done atomically. Don't care if tp_comm_overlap is False.
"""

tp_comm_split_rs: bool = True
"""If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
"""

tp_comm_atomic_rs: bool = False
"""If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
"""

tp_comm_bulk_wgrad: bool = True
"""If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
tp_comm_overlap is False.
"""

tp_comm_bulk_dgrad: bool = True
"""If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
tp_comm_overlap is False.
"""

###################
# Pipeline Parallel
###################
pipeline_dtype: torch.dtype = None
"""dtype used in p2p communication, usually params_dtype"""

variable_seq_lengths: bool = False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
"""

overlap_p2p_comm: bool = False
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
computation. Must be False if batch_p2p_comm is true.
"""

batch_p2p_comm: bool = True
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
overlap_p2p_comm is True.
"""

batch_p2p_sync: bool = True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
older versions of PyTorch.
"""

use_ring_exchange_p2p: bool = False
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
custom built torch with torch.distributed.ring_exchange.
"""

deallocate_pipeline_outputs: bool = False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
Helps with saving memory, does nothing when pipeline parallel is not used.
"""

pipeline_model_parallel_split_rank: Optional[int] = None
"""If int, rank where encoder and decoder should be split in cases where the model has both an
encoder and decoder (e.g., T5). Ignored if None.
"""

###################
# CPU Offloading
###################
cpu_offloading: bool = False
"""When set to True, all the activations are offloaded to the CPU asynchronously."""

cpu_offloading_num_layers: int = 0
"""Tells the number of transformer layers for which activations has to be offloaded."""

_cpu_offloading_context: ContextManager = None # Used for internal use only, not to be set by the user. TODO: Need to move to the 'right' place when possible.
"""For internal use only, do not set."""

cpu_offloading_activations: bool = True
"""If True, offloads the activations to CPU."""

cpu_offloading_weights: bool = True
"""If True, offloads the weights to CPU."""

###################
# Timing
###################
barrier_with_L1_time: bool = True
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
the user adds a level 1 timer that is not called by all ranks.
"""

def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
225 changes: 165 additions & 60 deletions megatron/core/transformer/transformer_config.py
@@ -15,138 +15,243 @@
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.

The initialization function has an argument for each parameter, including those in ModelParallelConfig.
"""

####################
# model architecture
####################
num_layers: int = 0
"""Number of transformer layers in a transformer block."""

hidden_size: int = 0
"""Transformer hidden size."""

num_attention_heads: int = 0
"""Number of transformer attention heads."""

num_query_groups: int = None
"""Number of query groups for group query attention. If None, normal attention is used."""

ffn_hidden_size: int = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size if not provided."""

kv_channels: int = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
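To make the two derived defaults above concrete (illustrative numbers only):

hidden_size = 1024
num_attention_heads = 16

ffn_hidden_size = 4 * hidden_size                 # 4096 when not provided explicitly
kv_channels = hidden_size // num_attention_heads  # 64 when not provided explicitly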

hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""

attention_dropout: float = 0.1
"""Post attention dropout probability."""

fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""

# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""

layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""

layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""

add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""

add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""

gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""

activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""

num_moe_experts: int = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""

rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""

window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is a special value meaning "infinite window size"."""

normalization: str = "LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""

qk_layernorm: bool = False
"""Whether to apply LayerNorm to the query and key embeddings."""

test_mode: bool = False
"""Whether to run real-time tests."""

####################
# initialization
####################
init_method: Callable = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""

output_layer_init_method: Callable = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""

init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
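For illustration, a custom init_method only needs to be a callable that initializes a single tensor in place (hypothetical example mirroring the defaults described above):

import math
import torch

def my_init_method(tensor: torch.Tensor) -> None:
    # Zero-mean normal initialization, analogous to the default with std=init_method_std.
    torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

def my_output_layer_init_method(tensor: torch.Tensor, num_layers: int = 24) -> None:
    # Scaled variant for output layers, as described above.
    torch.nn.init.normal_(tensor, mean=0.0, std=0.02 / math.sqrt(2.0 * num_layers))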

####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improves numerical stability when training
with fp16."""

attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""

####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""

masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""

persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""

memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""

bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""

apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""

####################
# activation recomputation
####################
recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""

recompute_method: str = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""

recompute_num_layers: int = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""

distribute_saved_activations: bool = None
"""If True, distribute recomputed activations across the model parallel group."""

####################
# fp8 related
####################
fp8: str = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices: (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""

fp8_margin: int = 0
"""Margin for the scaling factor computation."""

fp8_interval: int = 1
"""Controls how often the scaling factor is recomputed."""

fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""

fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""

fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation in higher precision."""

####################
# MoE related
####################
moe_router_load_balancing_type: str = "aux_loss"
"""Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing."""

moe_router_topk: int = 2
"""Number of experts to route to for each token."""

moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""

moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""

moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""

moe_input_jitter_eps: float = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""

moe_token_dropping: bool = False # TODO: Support token dropping.
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
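A sketch of a small MoE setup using the fields above (illustrative values; the aux-loss coefficient follows the recommendation in its docstring):

moe_settings = dict(
    num_moe_experts=8,
    moe_router_topk=2,
    moe_router_load_balancing_type="aux_loss",
    moe_aux_loss_coeff=1e-2,
    moe_grouped_gemm=True,
    expert_model_parallel_size=2,  # from ModelParallelConfig: experts spread over 2 ranks
)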

####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""

disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""

# These 2 attributes are WAR for TRTLLM export. DO NOT USE!! WILL BE DEPRECATED SOON!!
max_position_embeddings: int = 0
"""Deprecated. Do not use."""

rotary_percent: float = 0
"""Deprecated. Do not use."""

def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
