24 changes: 17 additions & 7 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -196,8 +196,9 @@ def apply_non_moe_tp(
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
+                use_local_output=False,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(use_local_output=False),
             "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
@@ -223,7 +224,9 @@ def apply_non_moe_tp(
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for transformer_block in model.layers.values():
         layer_plan = {
-            "attention_norm": SequenceParallel(),
+            "attention_norm": SequenceParallel(
+                use_local_output=False,
+            ),
             # NOTE: when the fourth argument (positions) is not None, its input layout
             # and desired input layout should be Replicate()
             "attention": prepare_module_input(
@@ -238,8 +241,13 @@ def apply_non_moe_tp(
             "attention.kv_norm": NoParallel(use_local_output=False),
             # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors
             "attention.inner_attention": attention_kernel_plan,
-            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
-            "ffn_norm": SequenceParallel(),
+            "attention.wo": rowwise_parallel(
+                output_layouts=Shard(1),
+                use_local_output=False,
+            ),
+            "ffn_norm": SequenceParallel(
+                use_local_output=False,
+            ),
         }

         if transformer_block.attention.q_lora_rank == 0:
@@ -266,9 +274,11 @@ def apply_non_moe_tp(
                         input_layouts=(Shard(1),),
                         desired_input_layouts=(Replicate(),),
                     ),
-                    "feed_forward.w1": colwise_parallel(),
-                    "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-                    "feed_forward.w3": colwise_parallel(),
+                    "feed_forward.w1": colwise_parallel(use_local_output=False),
+                    "feed_forward.w2": rowwise_parallel(
+                        output_layouts=Shard(1), use_local_output=False
+                    ),
+                    "feed_forward.w3": colwise_parallel(use_local_output=False),
                 }
             )

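The pattern is the same for every plan entry touched by this PR: adding use_local_output=False makes each parallelized module hand a DTensor to the next module rather than a plain local tensor, so sharding information propagates through the whole block. Below is a minimal, self-contained sketch of that behavior, not torchtitan code: ToyMLP, the tensor sizes, and the launch command are illustrative assumptions, and the public torch.distributed.tensor namespace of recent PyTorch (2.4+) is assumed.

```python
# Sketch only (not torchtitan code): two chained linears parallelized
# colwise then rowwise, with use_local_output=False so both the activation
# between them and the block output stay DTensors.
# Launch with: torchrun --nproc-per-node=2 toy_tp.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard  # public namespace in PyTorch 2.4+
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):  # hypothetical module, stands in for feed_forward
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))
mlp = parallelize_module(
    ToyMLP(),
    mesh,
    {
        # keep the intermediate activation as a DTensor between w1 and w2
        "w1": ColwiseParallel(use_local_output=False),
        # reduce-scatter the result onto the sequence dim, still as a DTensor
        "w2": RowwiseParallel(output_layouts=Shard(1), use_local_output=False),
    },
)

out = mlp(torch.randn(2, 8, 16))  # (batch, seq, dim), replicated input
assert isinstance(out, DTensor) and out.placements == (Shard(1),)
```

With the default use_local_output=True on ColwiseParallel and RowwiseParallel, the same plan would return plain local tensors at each module boundary; keeping DTensors avoids re-wrapping them downstream, which is what the plans in this diff rely on.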
34 changes: 22 additions & 12 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -204,8 +204,11 @@ def apply_non_moe_tp(
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
+                use_local_output=False,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(
+                use_local_output=False,
+            ),
             "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
@@ -239,18 +242,22 @@ def apply_non_moe_tp(
     # Apply tensor + sequence parallelism to every transformer block
     for transformer_block in model.layers.values():
         layer_plan = {
-            "attention_norm": SequenceParallel(),
+            "attention_norm": SequenceParallel(
+                use_local_output=False,
+            ),
             # NOTE: when the fourth argument (positions) is not None, its input layout
             # and desired input layout should be Replicate()
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None, None, None),
-                desired_input_layouts=(Replicate(), None, None, None),
+                input_layouts=(Shard(1), Replicate(), None, None),
+                desired_input_layouts=(Replicate(), Replicate(), None, None),
[Review comment] @fegin (Contributor), Dec 13, 2025:
Nice, now we consistently make freqs_cis a DTensor. The only model that still uses a plain tensor for freqs_cis is llama3.

[Reply from the PR author]
Thank you so much! I plan to add the change for llama3 as well.
             ),
-            "attention.wq": colwise_parallel(),
-            "attention.wk": colwise_parallel(),
-            "attention.wv": colwise_parallel(),
-            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
-            "ffn_norm": SequenceParallel(),
+            "attention.wq": colwise_parallel(use_local_output=False),
+            "attention.wk": colwise_parallel(use_local_output=False),
+            "attention.wv": colwise_parallel(use_local_output=False),
+            "attention.wo": rowwise_parallel(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+            "ffn_norm": SequenceParallel(use_local_output=False),
         }
         if not transformer_block.moe_enabled:
             layer_plan.update(
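The review exchange above notes that freqs_cis now reaches attention as a DTensor, because the second positional input gets an explicit Replicate() layout in prepare_module_input. A hedged sketch of that mechanism follows, written against the upstream PrepareModuleInput style directly (torchtitan's prepare_module_input is a thin wrapper and may differ in detail); ToyAttention and the shapes are invented for illustration.

```python
# Sketch only: a plain freqs_cis tensor is wrapped into a replicated DTensor
# before the module runs, while the sequence-sharded activation is gathered.
# Launch with: torchrun --nproc-per-node=2 toy_prepare_input.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module


class ToyAttention(nn.Module):  # hypothetical stand-in for the attention block
    def forward(self, x, freqs_cis):
        # with the plan below, both arguments arrive as DTensors
        assert isinstance(x, DTensor) and isinstance(freqs_cis, DTensor)
        return x


mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))
attn = parallelize_module(
    ToyAttention(),
    mesh,
    PrepareModuleInput(
        # x comes in sequence-sharded; freqs_cis comes in as a plain tensor
        input_layouts=(Shard(1), Replicate()),
        # all-gather x, keep freqs_cis replicated; hand both over as DTensors
        desired_input_layouts=(Replicate(), Replicate()),
        use_local_output=False,
    ),
)

x = torch.randn(2, 4, 8)       # local shard of a sequence-sharded activation
freqs_cis = torch.randn(8, 8)  # plain tensor -> replicated DTensor
attn(x, freqs_cis)
```

In the torchtitan plans above, a None entry means that argument is left untouched; the relevant change is that the freqs_cis slot moved from None to Replicate().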
@@ -259,9 +266,11 @@ def apply_non_moe_tp(
                         input_layouts=(Shard(1),),
                         desired_input_layouts=(Replicate(),),
                     ),
-                    "feed_forward.w1": colwise_parallel(),
-                    "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-                    "feed_forward.w3": colwise_parallel(),
+                    "feed_forward.w1": colwise_parallel(use_local_output=False),
+                    "feed_forward.w2": rowwise_parallel(
+                        output_layouts=Shard(1), use_local_output=False
+                    ),
+                    "feed_forward.w3": colwise_parallel(use_local_output=False),
                 }
             )

@@ -461,6 +470,7 @@ def apply_moe_ep_tp(
                 use_local_input=True,
                 output_layouts=(Partial(),),
                 desired_output_layouts=(Shard(1),),
+                use_local_output=False,
             ),
             # replicate computation for the router
             "moe.router.gate": NoParallel(),
27 changes: 18 additions & 9 deletions torchtitan/models/qwen3/infra/parallelize.py
@@ -202,8 +202,9 @@ def apply_non_moe_tp(
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
+                use_local_output=False,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(use_local_output=False),
             "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
@@ -240,7 +241,7 @@ def apply_non_moe_tp(
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for transformer_block in model.layers.values():
         layer_plan = {
-            "attention_norm": SequenceParallel(),
+            "attention_norm": SequenceParallel(use_local_output=False),
             # NOTE: when the fourth argument (positions) is not None, its input layout
             # and desired input layout should be Replicate()
             "attention": prepare_module_input(
@@ -250,10 +251,16 @@ def apply_non_moe_tp(
             "attention.wq": colwise_parallel(use_local_output=False),
             "attention.wk": colwise_parallel(use_local_output=False),
             "attention.wv": colwise_parallel(use_local_output=False),
-            "attention.q_norm": SequenceParallel(sequence_dim=2),
-            "attention.k_norm": SequenceParallel(sequence_dim=2),
-            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
-            "ffn_norm": SequenceParallel(),
+            "attention.q_norm": SequenceParallel(
+                sequence_dim=2, use_local_output=False
+            ),
+            "attention.k_norm": SequenceParallel(
+                sequence_dim=2, use_local_output=False
+            ),
+            "attention.wo": rowwise_parallel(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+            "ffn_norm": SequenceParallel(use_local_output=False),
         }

         if not transformer_block.moe_enabled:
@@ -263,9 +270,11 @@ def apply_non_moe_tp(
                         input_layouts=(Shard(1),),
                         desired_input_layouts=(Replicate(),),
                     ),
-                    "feed_forward.w1": colwise_parallel(),
-                    "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-                    "feed_forward.w3": colwise_parallel(),
+                    "feed_forward.w1": colwise_parallel(use_local_output=False),
+                    "feed_forward.w2": rowwise_parallel(
+                        output_layouts=Shard(1), use_local_output=False
+                    ),
+                    "feed_forward.w3": colwise_parallel(use_local_output=False),
                 }
             )
