Update theoretical memory footprint formula #1345

Open · wants to merge 6 commits into base: main
114 changes: 87 additions & 27 deletions megatron/training/theoretical_memory_usage.py
@@ -5,7 +5,7 @@
 
 import math
 
-NUM_BYTES_IN_MEGABYTE = 1024 * 1024
+NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
 
 
 def compute_weight_and_optimizer_memory(args, verbose=False):
@@ -32,10 +32,11 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
             # MLP.
             + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
             # Transformer layernorms.
-            + (2 / args.hidden_size)
-            # Final layernorm.
-            + (1 / (args.num_layers * args.hidden_size))
+            + (1 / args.hidden_size)
         )
-    )
+    ) + (
+        # final layernorm
+        args.hidden_size
+    )
     embedding_size = args.hidden_size * args.padded_vocab_size
     if args.untie_embeddings_and_output_weights:
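Note (illustrative, not part of the diff): a toy check of the layernorm accounting in the hunk above, assuming weight-only (bias-free) norms and hypothetical sizes.

num_layers, hidden_size = 4, 8  # hypothetical toy sizes

# Inside the 2 * num_layers * hidden_size**2 * (...) product, the per-layer norm term is
# 1 / hidden_size, i.e. two weight-only norms of hidden_size parameters each per layer.
per_layer_norm_term = 1 / hidden_size
norm_params_from_product = 2 * num_layers * hidden_size**2 * per_layer_norm_term
assert norm_params_from_product == 2 * num_layers * hidden_size

# The final layernorm is now counted once, outside the per-layer product.
final_norm_params = hidden_size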
@@ -56,12 +57,15 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
 
     # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
     num_parameters_on_most_loaded_model_shard = (
-        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
+        ( # - last layer norm
+            (num_parameters_in_transformer_layers - args.hidden_size) / args.pipeline_model_parallel_size
+        ) + embedding_size # (embedding layer) for first stage
     ) / args.tensor_model_parallel_size
     if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
         num_parameters_on_most_loaded_model_shard += (
             embedding_size / args.tensor_model_parallel_size
         )
+    num_parameters_on_most_loaded_model_shard += args.hidden_size # last layer norm
     if verbose:
         print(
             f"Number of parameters in most loaded shard in billions: "
@@ -79,8 +83,11 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
             f"{num_parameters_on_other_model_shards / 10**9:.4f}"
         )
 
+    gradient_accumulation_factor = 4 if args.accumulate_allreduce_grads_in_fp32 else 2
+    # parameters: bf16 + gradients: fp32
+    # optimizer states: param: fp32, momentum: fp32, variance: fp32
     num_bytes_per_parameter = (
-        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
+        (2 + gradient_accumulation_factor + (4 + 4 + 4)) if not args.use_distributed_optimizer else (2 + gradient_accumulation_factor) + (12 / args.data_parallel_size / args.context_parallel_size)
     )
     weight_and_optimizer_memory = (
         num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
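Note (illustrative, not part of the diff): a minimal standalone sketch of the per-parameter byte count encoded above; the function name and example numbers are hypothetical.

def bytes_per_parameter(use_distributed_optimizer, accumulate_allreduce_grads_in_fp32,
                        data_parallel_size=1, context_parallel_size=1):
    grad_bytes = 4 if accumulate_allreduce_grads_in_fp32 else 2  # fp32 vs bf16 gradient buffer
    optimizer_bytes = 4 + 4 + 4                                  # fp32 main params, momentum, variance
    if not use_distributed_optimizer:
        return 2 + grad_bytes + optimizer_bytes                  # bf16 params + grads + optimizer states
    # the distributed optimizer shards the 12 optimizer-state bytes across DP x CP ranks
    return (2 + grad_bytes) + optimizer_bytes / (data_parallel_size * context_parallel_size)

print(bytes_per_parameter(False, True))       # 18
print(bytes_per_parameter(True, True, 8, 2))  # 6.75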
@@ -98,24 +105,77 @@ def compute_activation_memory(args, num_microbatches, verbose=False):
     # different from hidden_size.
 
     # Memory footprint from transformer layer (self-attention and MLP).
-    activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * (
-        18 + (4 * (args.ffn_hidden_size / args.hidden_size))
-    )
+    s = args.seq_length
+    h = args.hidden_size
+    b = args.micro_batch_size
+    a = args.num_attention_heads
+    k = args.num_query_groups
+
+    s = s / args.context_parallel_size
+    args.selective_activation_recomputation = (
+        args.recompute_granularity == 'selective'
+    )
+
+    activation_memory = (
+        2 * s * b * h # LayerNorm
+        + (
+            # attention
+            2 * s * b * h # input
+            + (2 * s * b * h) # Q
+            + (2 * s * b * h) * (k / a) # K
+            + (
+                0 if (
+                    args.selective_activation_recomputation or args.use_flash_attn
+                ) else (2 * b * s * s * a) # QK^T
+            )
+            + (
+                0 if args.attention_dropout == 0.0 else 0 if (
+                    args.selective_activation_recomputation or args.use_flash_attn
+                ) else (1 * b * s * s * a) # Dropout
+            )
+            + ((2 * b * s * h) * (k / a)) # V
+            + (
+                0 if (
+                    args.selective_activation_recomputation or args.use_flash_attn
+                ) else (2 * b * a * s * s) # Dropout(softmax(QK^T)) * V
+            )
+            + (2 * b * s * h) # linear
+        )
+        + (0 if args.hidden_dropout == 0.0 else (b * s * h)) # Dropout
+        + (2 * b * s * h) # LayerNorm
+        + (
+            (
+                # SwiGLU
+                2 * b * s * h # input
+                + 2 * b * s * args.ffn_hidden_size # up_proj
+                + 2 * b * s * args.ffn_hidden_size # gate_proj
+                + 2 * b * s * args.ffn_hidden_size # act_fn
+                + 2 * b * s * args.ffn_hidden_size # down_proj
+            ) if args.swiglu else (
+                2 * b * s * h # h -> ffn_h
+                + 2 * b * s * args.ffn_hidden_size # act
+                + 2 * b * s * args.ffn_hidden_size # ffn_h -> h
+            )
+        )
+        + (0 if args.hidden_dropout == 0.0 else (b * s * h)) # Dropout
+    )
     if verbose:
         print(
             f"Activation memory footprint per transformer layer: "
-            f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB"
+            f"{activation_memory / NUM_BYTES_IN_GIGABYTE / args.tensor_model_parallel_size:.1f} GB"
         )
-    activation_memory *= args.num_layers
+    activation_memory = activation_memory * args.num_layers / args.tensor_model_parallel_size
 
     # Now add activation memory required for input embeddings, last LayerNorm and output layer.
 
     # Input to embedding (pp_size microbatches in flight).
     activation_memory += (
-        8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size
-    )
+        # 8 bytes (int64)
+        8 * s * b * h * args.pipeline_model_parallel_size
+    ) / args.tensor_model_parallel_size
+
     # Dropout in embedding layer (pp_size microbatches in flight).
-    activation_memory += (
+    activation_memory += 0 if args.hidden_dropout == 0 else (
         args.seq_length
         * args.micro_batch_size
         * args.hidden_size
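Note (illustrative, not part of the diff): a hedged toy evaluation of the new per-layer activation expression, assuming flash attention (no stored score matrices), SwiGLU, all dropouts disabled, and num_query_groups == num_attention_heads, with hypothetical sizes.

s, b, h, ffn = 2048, 1, 4096, 11008  # hypothetical sequence, batch, hidden, FFN sizes
k_over_a = 1                         # num_query_groups / num_attention_heads

attention = 2*s*b*h + 2*s*b*h + 2*s*b*h*k_over_a + 2*s*b*h*k_over_a + 2*s*b*h  # input, Q, K, V, linear
mlp_swiglu = 2*s*b*h + 4 * (2*s*b*ffn)                                         # input + up/gate/act/down
per_layer = 2*s*b*h + attention + 2*s*b*h + mlp_swiglu                         # two layernorms + both blocks

# Under these assumptions the per-layer total reduces to s*b*(16*h + 8*ffn) bytes,
# before multiplying by num_layers and dividing by the tensor-parallel size.
assert per_layer == s * b * (16 * h + 8 * ffn)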
@@ -152,36 +212,36 @@ def compute_activation_memory(args, num_microbatches, verbose=False):
     if args.pipeline_model_parallel_size == 1:
         # Inputs to output layer and CE loss.
         activation_memory += (
-            args.seq_length
-            * args.micro_batch_size
-            * args.hidden_size
-            * 4
-            * (1 + (args.padded_vocab_size / args.hidden_size))
-        )
+            # lm-head cross entropy (FP32)
+            # output layer (layer norm) + output layer (linear)
+            4 * s * b * h * (1 + args.padded_vocab_size / h)
+        ) / args.tensor_model_parallel_size
 
     # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
-    return activation_memory / args.tensor_model_parallel_size
+    return activation_memory
 
 
 def report_theoretical_memory(args, num_microbatches=None, verbose=False):
     weight_and_optimizer_memory = (
-        compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE
+        compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_GIGABYTE
     )
 
-    # Formulae here assume sequence parallelism and selective activation recomputation.
-    if not args.sequence_parallel or args.recompute_granularity != 'selective':
+    # Formulae here assume sequence parallelism and selective activation recomputation or flash-attention.
+    if not args.sequence_parallel or not (
+        args.recompute_granularity == 'selective' or args.use_flash_attn is True
+    ):
         print(
-            f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB"
+            f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB"
         )
         return
 
     activation_memory = (
         compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose)
-        / NUM_BYTES_IN_MEGABYTE
+        / NUM_BYTES_IN_GIGABYTE
     )
     total_memory = weight_and_optimizer_memory + activation_memory
 
     print(
-        f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, "
-        f"activation={activation_memory:.2f} MB, total={total_memory:.2f} MB\n"
+        f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB, "
+        f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
     )
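Note (illustrative, not part of the diff): a hedged sketch of the GB-based reporting above; the byte counts are made-up placeholders rather than output of this module.

NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024

# Hypothetical shard: 1.2e9 parameters at 18 bytes each (non-distributed optimizer, fp32 grads),
# plus a made-up 60 GB activation estimate.
weight_and_optimizer_memory = 1.2e9 * 18 / NUM_BYTES_IN_GIGABYTE
activation_memory = 60.0
total_memory = weight_and_optimizer_memory + activation_memory

print(
    f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} GB, "
    f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB"
)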