From 3f7c06ac343ae437abb59fafb240e931f5c7f8b4 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 16:55:15 +0900 Subject: [PATCH 1/6] Fix number of parameters calculation logic --- megatron/training/theoretical_memory_usage.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index f9b75031ae..48966f0915 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -32,10 +32,11 @@ def compute_weight_and_optimizer_memory(args, verbose=False): # MLP. + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier) # Transformer layernorms. - + (2 / args.hidden_size) - # Final layernorm. - + (1 / (args.num_layers * args.hidden_size)) + + (1 / args.hidden_size) ) + ) + ( + # final layernorm + args.hidden_size ) embedding_size = args.hidden_size * args.padded_vocab_size if args.untie_embeddings_and_output_weights: From 03d1e9dd484f5ec197e07e63687181c8894d5550 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 16:58:55 +0900 Subject: [PATCH 2/6] Fix calc logic for num of param on first pipeline stage --- megatron/training/theoretical_memory_usage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 48966f0915..83c66e29b6 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -57,12 +57,15 @@ def compute_weight_and_optimizer_memory(args, verbose=False): # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size. num_parameters_on_most_loaded_model_shard = ( - (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size + ( # - last layer norm + (num_parameters_in_transformer_layers - args.hidden_size) / args.pipeline_model_parallel_size + ) + embedding_size # (embedding layer) for first stage ) / args.tensor_model_parallel_size if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1: num_parameters_on_most_loaded_model_shard += ( embedding_size / args.tensor_model_parallel_size ) + num_parameters_on_most_loaded_model_shard += args.hidden_size # last layer norm if verbose: print( f"Number of parameters in most loaded shard in billions: " From fa2cafbd9a6c9130dd1e1bf964ab32c5f1ecaafc Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 17:01:33 +0900 Subject: [PATCH 3/6] Fix calc logic for number of param when accumulate_allreduce_grads_in_fp32 is False --- megatron/training/theoretical_memory_usage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 83c66e29b6..9d5d75e31b 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -83,8 +83,11 @@ def compute_weight_and_optimizer_memory(args, verbose=False): f"{num_parameters_on_other_model_shards / 10**9:.4f}" ) + gradient_accumulation_factor = 4 if args.accumulate_allreduce_grads_in_fp32 else 2 + # parameters: bf16 + gradients: fp32 + # optimizer states: param: fp32, momentum: fp32, variance: fp32 num_bytes_per_parameter = ( - 18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size) + (2 + gradient_accumulation_factor + (4 + 4 + 4)) if not args.use_distributed_optimizer 
else (2 + gradient_accumulation_factor) + (12 / args.data_parallel_size / args.context_parallel_size) ) weight_and_optimizer_memory = ( num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter From 1c11f757138b1c4a3d11bef3a2136e573fd78189 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 17:03:16 +0900 Subject: [PATCH 4/6] Change the display unit for memory footprints from MB to GB --- megatron/training/theoretical_memory_usage.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 9d5d75e31b..d6164df069 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -5,7 +5,7 @@ import math -NUM_BYTES_IN_MEGABYTE = 1024 * 1024 +NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024 def compute_weight_and_optimizer_memory(args, verbose=False): @@ -111,7 +111,7 @@ def compute_activation_memory(args, num_microbatches, verbose=False): if verbose: print( f"Activation memory footprint per transformer layer: " - f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB" + f"{activation_memory / NUM_BYTES_IN_GIGABYTE / args.tensor_model_parallel_size:.1f} GB" ) activation_memory *= args.num_layers @@ -172,23 +172,23 @@ def compute_activation_memory(args, num_microbatches, verbose=False): def report_theoretical_memory(args, num_microbatches=None, verbose=False): weight_and_optimizer_memory = ( - compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE + compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_GIGABYTE ) # Formulae here assume sequence parallelism and selective activation recomputation. if not args.sequence_parallel or args.recompute_granularity != 'selective': print( - f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB" + f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB" ) return activation_memory = ( compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose) - / NUM_BYTES_IN_MEGABYTE + / NUM_BYTES_IN_GIGABYTE ) total_memory = weight_and_optimizer_memory + activation_memory print( - f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, " - f"activation={activation_memory:.2f} MB, total={total_memory:.2f} MB\n" + f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB, " + f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n" ) From 1242f36ba38e09610bf3b232e133f051728c4505 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 17:10:56 +0900 Subject: [PATCH 5/6] Change the condition so that memory footprint is output not only when selective recomputation is used, but also when flash attention is used --- megatron/training/theoretical_memory_usage.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index d6164df069..626fb54f44 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -175,8 +175,10 @@ def report_theoretical_memory(args, num_microbatches=None, verbose=False): compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_GIGABYTE ) - # Formulae here assume sequence parallelism and selective activation recomputation. 
- if not args.sequence_parallel or args.recompute_granularity != 'selective': + # Formulae here assume sequence parallelism and selective activation recomputation or flash-attention. + if not args.sequence_parallel or not ( + args.recompute_granularity == 'selective' or args.use_flash_attn is True + ): print( f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB" ) From 07df2052c59650c11e69009c8a0553424ebdade4 Mon Sep 17 00:00:00 2001 From: kazuki Date: Fri, 3 Jan 2025 20:58:41 +0900 Subject: [PATCH 6/6] Update to an activation memory footprint formula that takes into account swiglu, GQA, and CP --- megatron/training/theoretical_memory_usage.py | 77 +++++++++++++++---- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/megatron/training/theoretical_memory_usage.py b/megatron/training/theoretical_memory_usage.py index 626fb54f44..01331e5969 100644 --- a/megatron/training/theoretical_memory_usage.py +++ b/megatron/training/theoretical_memory_usage.py @@ -105,24 +105,77 @@ def compute_activation_memory(args, num_microbatches, verbose=False): # different from hidden_size. # Memory footprint from transformer layer (self-attention and MLP). - activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * ( - 18 + (4 * (args.ffn_hidden_size / args.hidden_size)) + s = args.seq_length + h = args.hidden_size + b = args.micro_batch_size + a = args.num_attention_heads + k = args.num_query_groups + + s = s / args.context_parallel_size + args.selective_activation_recomputation = ( + args.recompute_granularity == 'selective' + ) + + activation_memory = ( + 2 * s * b * h # LayerNorm + + ( + # attention + 2 * s * b * h # input + + (2 * s * b * h) # Q + + (2 * s * b * h) * (k / a) # K + + ( + 0 if ( + args.selective_activation_recomputation or args.use_flash_attn + ) else (2 * b * s * s * a) # QK^T + ) + + ( + 0 if args.attention_dropout == 0.0 else 0 if ( + args.selective_activation_recomputation or args.use_flash_attn + ) else (1 * b * s * s * a) # Dropout + ) + + ((2 * b * s * h) * (k / a)) # V + + ( + 0 if ( + args.selective_activation_recomputation or args.use_flash_attn + ) else (2 * b * a * s * s) # Dropout(softmax(QK^T)) * V + ) + + (2 * b * s * h) # linear + ) + + (0 if args.hidden_dropout == 0.0 else (b * s * h)) # Dropout + + (2 * b * s * h) # LayerNorm + + ( + ( + # SwiGLU + 2 * b * s * h # input + + 2 * b * s * args.ffn_hidden_size # up_proj + + 2 * b * s * args.ffn_hidden_size # gate_proj + + 2 * b * s * args.ffn_hidden_size # act_fn + + 2 * b * s * args.ffn_hidden_size # down_proj + ) if args.swiglu else ( + 2 * b * s * h # h -> ffn_h + + 2 * b * s * args.ffn_hidden_size # act + + 2 * b * s * args.ffn_hidden_size # ffn_h -> h + ) + ) + + (0 if args.hidden_dropout == 0.0 else (b * s * h)) # Dropout ) if verbose: print( f"Activation memory footprint per transformer layer: " f"{activation_memory / NUM_BYTES_IN_GIGABYTE / args.tensor_model_parallel_size:.1f} GB" ) - activation_memory *= args.num_layers + activation_memory = activation_memory * args.num_layers / args.tensor_model_parallel_size # Now add activation memory required for input embeddings, last LayerNorm and output layer. # Input to embedding (pp_size microbatches in flight). activation_memory += ( - 8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size - ) + # 8 bytes (int64) + 8 * s * b * h * args.pipeline_model_parallel_size + ) / args.tensor_model_parallel_size + # Dropout in embedding layer (pp_size microbatches in flight). 
- activation_memory += ( + activation_memory += 0 if args.hidden_dropout == 0 else ( args.seq_length * args.micro_batch_size * args.hidden_size @@ -159,15 +212,13 @@ def compute_activation_memory(args, num_microbatches, verbose=False): if args.pipeline_model_parallel_size == 1: # Inputs to output layer and CE loss. activation_memory += ( - args.seq_length - * args.micro_batch_size - * args.hidden_size - * 4 - * (1 + (args.padded_vocab_size / args.hidden_size)) - ) + # lm-head cross entropy (FP32) + # output layer (layer norm) + output layer (linear) + 4 * s * b * h * (1 + args.padded_vocab_size / h) + ) / args.tensor_model_parallel_size # Activation memory is partitioned by TP size due to tensor and sequence model parallelism. - return activation_memory / args.tensor_model_parallel_size + return activation_memory def report_theoretical_memory(args, num_microbatches=None, verbose=False):
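
Below is a small standalone sanity check of the two formulas this series touches: bytes per parameter (PATCH 3) and per-layer activation memory (PATCH 6). It is not part of the patch; the SimpleNamespace config, its numeric values, and the helper names are illustrative assumptions for a hypothetical GQA + SwiGLU model with flash attention and zero dropout, and the arithmetic simply mirrors the patched theoretical_memory_usage.py for that case.

from types import SimpleNamespace

NUM_BYTES_IN_GIGABYTE = 1024 ** 3

# Hypothetical 7B-class configuration (illustrative values, not Megatron defaults).
args = SimpleNamespace(
    seq_length=4096, micro_batch_size=1, hidden_size=4096, ffn_hidden_size=11008,
    num_attention_heads=32, num_query_groups=8,
    tensor_model_parallel_size=2, context_parallel_size=1, data_parallel_size=8,
    swiglu=True, use_flash_attn=True, attention_dropout=0.0, hidden_dropout=0.0,
    accumulate_allreduce_grads_in_fp32=True, use_distributed_optimizer=True,
)

# Bytes per parameter (PATCH 3): bf16 weights (2) and gradients (4 with fp32 grad
# accumulation, else 2) live on every rank; the fp32 optimizer states
# (param + momentum + variance = 12) are sharded over DP * CP ranks when the
# distributed optimizer is enabled.
grad_factor = 4 if args.accumulate_allreduce_grads_in_fp32 else 2
if args.use_distributed_optimizer:
    num_bytes_per_parameter = (2 + grad_factor) + 12 / (
        args.data_parallel_size * args.context_parallel_size
    )
else:
    num_bytes_per_parameter = 2 + grad_factor + 12

# Per-layer activation memory (PATCH 6) for the flash-attention / zero-dropout case:
# the quadratic b*s*s*a attention buffers and the dropout masks all drop out, leaving
# only terms linear in the CP-sharded sequence length.
s = args.seq_length / args.context_parallel_size
b, h = args.micro_batch_size, args.hidden_size
kq = args.num_query_groups / args.num_attention_heads  # GQA ratio k / a
attention = (
    2 * s * b * h            # attention input
    + 2 * s * b * h          # Q
    + 2 * s * b * h * kq     # K
    + 2 * s * b * h * kq     # V
    + 2 * s * b * h          # input to the output projection
)
if args.swiglu:
    mlp = 2 * s * b * h + 8 * s * b * args.ffn_hidden_size   # input + up/gate/act/down
else:
    mlp = 2 * s * b * h + 4 * s * b * args.ffn_hidden_size   # h->ffn_h, act, ffn_h->h
per_layer = 2 * s * b * h + attention + 2 * s * b * h + mlp  # plus the two LayerNorm inputs

print(f"bytes per parameter: {num_bytes_per_parameter:.2f}")
print(
    "activation per layer per TP rank: "
    f"{per_layer / args.tensor_model_parallel_size / NUM_BYTES_IN_GIGABYTE:.3f} GB"
)

The printed figures are per tensor-parallel rank and cover a single transformer layer only; the embedding, output-layer, and pipeline in-flight terms are handled separately inside compute_activation_memory as in PATCH 6.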