Conversation

wwwjn (Contributor) commented on Dec 12, 2025

Numerics (thanks to @fegin's script): python scripts/loss_compare.py main qwen3-dtensor --baseline-config=torchtitan/models/qwen3/train_configs/qwen3_0.6b.toml --baseline-ngpus=8 --test-ngpus=8 --baseline-options='--parallelism.tensor_parallel_degree=2' --test-options='--parallelism.tensor_parallel_degree=2'

[Two screenshots of the loss comparison results, captured 2025-12-12]

The meta-cla bot added the CLA Signed label on Dec 12, 2025.
wwwjn (Contributor, Author) commented on Dec 12, 2025

If this change looks good, I will go ahead and change the TP plans of deepseek_v3, llama3, and llama4 as well.

tianyu-l (Contributor) left a comment:


nice, it seems it failed with MoE. Could you try to adjust MoE so that only the boundaries (input and output) are DTensor? We can discuss if you hit issues.

wwwjn changed the title from "Use all DTensor for Qwen3 model TP" to "Use all DTensor for Qwen3 and llama4 through TP region" on Dec 12, 2025.
wwwjn (Contributor, Author) commented on Dec 12, 2025

> nice, it seems it failed with MoE. Could you try to adjust MoE so that only the boundaries (input and output) are DTensor? We can discuss if you hit issues.

Updated! Since qwen3 and llama4 share the apply_moe_ep_tp function, I updated the llama4 TP plan in this PR as well.
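
For readers following the thread, here is a minimal sketch of what "DTensor only at the MoE boundaries" can look like using the PyTorch tensor-parallel styles. Everything here is illustrative: the module path (transformer_block.moe), the tp_mesh variable, and the chosen layouts are assumptions, not the exact plan that apply_moe_ep_tp implements.

    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        PrepareModuleInput,
        PrepareModuleOutput,
        parallelize_module,
    )

    # Assumed to exist: tp_mesh (the TP DeviceMesh) and transformer_block.moe.
    # On entry, redistribute the DTensor activations and unwrap them so the
    # expert computation inside the MoE block runs on plain torch.Tensors.
    parallelize_module(
        transformer_block.moe,
        tp_mesh,
        PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_output=True,  # experts see local tensors, not DTensor
        ),
    )

    # On exit, re-wrap the MoE output as a DTensor so the surrounding
    # sequence-parallel region keeps operating on DTensors.
    parallelize_module(
        transformer_block.moe,
        tp_mesh,
        PrepareModuleOutput(
            output_layouts=(Replicate(),),
            desired_output_layouts=(Shard(1),),
            use_local_output=False,  # keep the output wrapped as a DTensor
        ),
    )

With this shape, only the boundary hooks deal with DTensor; nothing inside the MoE block (router, experts, token dispatch) needs to be DTensor-aware.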

tianyu-l (Contributor) commented

test didn't pass

wwwjn (Contributor, Author) commented on Dec 13, 2025

> test didn't pass

Updated dsv3 as well, because dsv3 reuses llama4's apply_moe_ep_tp.

Ran the loss comparison on dsv3 and qwen3 as well; they both passed.

tianyu-l (Contributor) commented

Could you also modify the test_generate script in torchtitan to use sp instead of no-sp?

Btw, I think FlexAttention doesn't work with DTensor yet. Does your PR only enforce DTensor in SP regions (norm) but not TP regions (embedding, attention, mlp)? Specifically, have you tried your change with FlexAttention?

wwwjn (Contributor, Author) commented on Dec 13, 2025

> Does your PR only enforce DTensor in SP regions (norm) but not TP regions (embedding, attention, mlp)? Specifically, have you tried your change with FlexAttention?

No, it actually enforces DTensor in all TP regions. I tested dsv3 with FlexAttention, and it works because of the special attention_kernel_plan:

attention_kernel_plan = prepare_module_input(

which turns the DTensor into a plain tensor before the attention kernel.

But I don't think qwen3 / llama4 will work with FlexAttention out of the box; let me fix that. Thanks for catching it!
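
For context, here is a hedged sketch of that kind of input preparation using the public PrepareModuleInput style; the layouts and the assumption that dim 1 is the TP-sharded head dim are illustrative, not the actual dsv3 plan. With use_local_output=True the hook converts the inputs to DTensor, redistributes them, and then unwraps them, so the attention kernel (e.g. FlexAttention) only ever sees plain torch.Tensors.

    from torch.distributed.tensor import Shard
    from torch.distributed.tensor.parallel import PrepareModuleInput

    # Illustrative: treat the incoming activations as head-sharded DTensors and
    # unwrap them to local tensors right before the attention kernel, since
    # FlexAttention does not operate on DTensor inputs.
    attention_kernel_plan = PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Shard(1),),
        use_local_output=True,  # hand plain torch.Tensors to FlexAttention
    )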

fegin (Contributor) commented on Dec 13, 2025

Nice, but the DSV3 integration test fails. I believe it is because tok_embeddings doesn't exist for some stages when PP is enabled. You will need to check whether tok_embeddings exists before applying the plan.
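
A minimal sketch of that guard, assuming the common pattern where a pipeline stage's missing submodules are set to None; the names model and tp_mesh and the chosen layouts are illustrative:

    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

    # With pipeline parallelism, non-first stages may not own tok_embeddings,
    # so only add it to the TP plan when the submodule actually exists.
    tp_plan = {}
    if getattr(model, "tok_embeddings", None) is not None:
        tp_plan["tok_embeddings"] = RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),  # hand sequence-sharded activations to SP
        )
    parallelize_module(model, tp_mesh, tp_plan)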

-    input_layouts=(Shard(1), None, None, None),
-    desired_input_layouts=(Replicate(), None, None, None),
+    input_layouts=(Shard(1), Replicate(), None, None),
+    desired_input_layouts=(Replicate(), Replicate(), None, None),
fegin (Contributor) commented on Dec 13, 2025:


Nice, now we consistently make freqs_cis a DTensor. The only model that still uses a plain tensor for freqs_cis is llama3.

wwwjn (Contributor, Author) replied:


Thank you so much! I plan to add the change for llama3 as well.

wwwjn (Contributor, Author) commented on Dec 14, 2025

> Nice, but the DSV3 integration test fails. I believe it is because tok_embeddings doesn't exist for some stages when PP is enabled. You will need to check whether tok_embeddings exists before applying the plan.

From the error stack, the error is raised from the attention() part, which might not be related to tok_embeddings? Is it related to freqs_cis? cc @H-Huang

    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 497, in forward
      h = layer(h, self.freqs_cis, attention_masks, positions)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
      return inner()
             ^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1830, in inner
      result = forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/pytorch/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 384, in forward
      x = x + self.attention(
              ^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
      return inner()
             ^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1819, in inner
      args_result = hook(self, args)
                    ^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py", line 580, in <lambda>
      lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py", line 552, in _prepare_input_fn
      self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/parallel/style.py", line 529, in _prepare_input_arg
      dt_inp = dt_inp.redistribute(placements=(desired_layout,))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 566, in redistribute
      return Redistribute.apply(
             ^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/autograd/function.py", line 583, in apply
      return super().apply(*args, **kwargs)  # type: ignore[misc]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 949, in forward
      output = redistribute_local_tensor(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/_redistribute.py", line 840, in redistribute_local_tensor
      new_local_tensor = current_placement._to_replicate_tensor(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/tensor/placement_types.py", line 303, in _to_replicate_tensor
      result = funcol.all_gather_tensor(
               ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py", line 200, in all_gather_tensor
      group_size = c10d._get_group_size_by_name(group_name)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 1118, in _get_group_size_by_name
      group = _resolve_process_group(group_name)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: Could not resolve the process group registered under the name 18
  

wwwjn (Contributor, Author) commented on Dec 14, 2025

The error with TP + PP might be because the original TP plan specified Replicate() for freqs_cis, which caused the TP input hook to try to redistribute it as a DTensor. When PP is enabled, the freqs_cis buffer goes through various transformations (deep copy during the PP split, to_empty, init_weights) that can corrupt any DTensor metadata.
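
One illustrative way to address that (an assumption, not necessarily what this PR ends up doing): after the PP split and init_weights have run, rebuild freqs_cis as a replicated DTensor on the TP mesh, so that when the TP input hook later calls redistribute() it operates on a buffer whose mesh and process-group metadata are still valid.

    import torch
    from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

    # Assumed to exist: model (this PP stage's module) and tp_mesh.
    with torch.no_grad():
        freqs = model.freqs_cis
        if isinstance(freqs, DTensor):
            # Drop potentially stale DTensor metadata left over from the PP
            # split / to_empty / init_weights sequence.
            freqs = freqs.to_local()
        # Re-register the buffer as a replicated DTensor on the current TP mesh.
        model.freqs_cis = distribute_tensor(freqs, tp_mesh, [Replicate()])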
