Use all DTensor for Qwen3 and llama4 through TP region #2149
base: main
Conversation
If this change looks good, I will go ahead and change the TP plan of deepseek_v3, llama3, and llama4.
tianyu-l left a comment
nice, it seems it failed with MoE. Could you try to adjust MoE so that only the boundaries (input and output) are DTensor? We can discuss if you hit issues.
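If it helps, here is a minimal sketch of that idea, assuming the MoE internals keep operating on plain local tensors; `route_and_compute` is a hypothetical placeholder, not an existing torchtitan method:

```python
from torch.distributed.tensor import DTensor


def moe_boundary_forward(moe, x):
    # Unwrap at the MoE input so routing / expert compute sees plain local tensors.
    if isinstance(x, DTensor):
        mesh, placements = x.device_mesh, x.placements
        x_local = x.to_local()
    else:
        mesh, placements, x_local = None, None, x

    # Hypothetical stand-in for the MoE internals (router + grouped experts).
    out_local = moe.route_and_compute(x_local)

    # Re-wrap only at the output so the surrounding TP region stays all-DTensor.
    if mesh is not None:
        out_local = DTensor.from_local(out_local, mesh, placements)
    return out_local
```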
Updated! As qwen3 and llama4 share the
test didn't pass
Updated dsv3 as well, because dsv3 reuses llama4. Ran loss comparison on dsv3 and qwen3 as well; they both passed.
Could you also modify the test_generate script in torchtitan to use sp instead of no-sp? Btw, I think FlexAttention doesn't work with DTensor yet. Does your PR only enforce DTensor in SP regions (norm) but not TP regions (embedding, attention, mlp)? Specifically, have you tried your change with FlexAttn?
No, it actually enforces DTensor in all TP regions. I tested dsv3 with FlexAttention and it works because of the special
But I don't think qwen3 / llama4 will work with FlexAttn out of the box; let me fix that. Thanks for catching!
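For readers following along, a rough sketch (not this PR's exact plan) of what enforcing DTensor through the TP regions can look like: `use_local_output=False` keeps module outputs as DTensors, while the norms sit in the SP region. Module names are illustrative llama-style names, and the PrepareModuleInput/Output hooks are omitted:

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
)

layer_plan = {
    # SP region: norms operate on sequence-sharded DTensors.
    "attention_norm": SequenceParallel(),
    "ffn_norm": SequenceParallel(),
    # TP regions: use_local_output=False keeps module outputs as DTensors
    # instead of converting them back to plain local tensors.
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1), use_local_output=False),
    "feed_forward.w1": ColwiseParallel(use_local_output=False),
    "feed_forward.w3": ColwiseParallel(use_local_output=False),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1), use_local_output=False),
}

# Applied per transformer block, e.g.:
# parallelize_module(transformer_block, tp_mesh, layer_plan)
```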
Nice, but the DSV3 integration test fails. I believe it is because
```diff
-    input_layouts=(Shard(1), None, None, None),
-    desired_input_layouts=(Replicate(), None, None, None),
+    input_layouts=(Shard(1), Replicate(), None, None),
+    desired_input_layouts=(Replicate(), Replicate(), None, None),
```
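A hedged illustration of where the changed layouts above would sit, assuming they belong to a PrepareModuleInput on the attention module whose second positional argument is freqs_cis (the remaining None entries are pass-through arguments such as masks):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

attention_input_prep = PrepareModuleInput(
    # x arrives sequence-sharded from the SP region; freqs_cis is now declared
    # Replicate() so the input hook consistently turns it into a DTensor too.
    input_layouts=(Shard(1), Replicate(), None, None),
    desired_input_layouts=(Replicate(), Replicate(), None, None),
)
```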
Nice, now we consistently make freqs_cis a DTensor. The only model that still uses a plain tensor for freqs_cis is llama3.
Thank you so much! I plan to add the change for llama3 as well.
From the error stack, the error is raised from the attention() part, which might not be related to
The error with TP + PP might be because the original TP plan specified Replicate() for freqs_cis, which caused the TP input hook to try to redistribute it as a DTensor. When PP is enabled, the freqs_cis buffer goes through various transformations (deep copy during the PP split, to_empty, init_weights) that can corrupt any DTensor metadata.
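One hypothetical way to guard against that (names assumed, not from this PR) is to re-materialize freqs_cis after the PP split and to_empty()/init_weights() have run, so the TP input hook never sees stale DTensor metadata:

```python
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor


def ensure_replicated_freqs_cis(model, tp_mesh):
    freqs = model.freqs_cis
    if isinstance(freqs, DTensor):
        # Drop any mesh/placement metadata the PP copy path may have corrupted.
        freqs = freqs.full_tensor()
    # Re-register as a clean, explicitly replicated DTensor on the TP mesh.
    model.freqs_cis = distribute_tensor(freqs, tp_mesh, [Replicate()])
```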
Numerics (thanks to @fegin for the script):
```
python scripts/loss_compare.py main qwen3-dtensor --baseline-config=torchtitan/models/qwen3/train_configs/qwen3_0.6b.toml --baseline-ngpus=8 --test-ngpus=8 --baseline-options='--parallelism.tensor_parallel_degree=2' --test-options='--parallelism.tensor_parallel_degree=2'
```