interleaved 1F1B seems to work better #21
Comments
Hi, can you share the settings of the pipeline parallelism? For example, the degree of PP, the number of microbatches (which can be calculated as global batch size / micro batch size / DP), and the number of virtual stages of interleaved_1f1b? One possible reason is that the number of microbatches or the number of virtual stages in interleaved 1F1B is very high, so interleaved 1F1B already achieves good efficiency.
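For reference, a minimal sketch of the microbatch calculation mentioned above (global batch size / micro batch size / DP); the function and variable names here are illustrative, not taken from the codebase:

```python
def num_microbatches(global_batch_size: int,
                     micro_batch_size: int,
                     data_parallel_size: int) -> int:
    """Microbatches each pipeline processes per iteration:
    global_batch_size / (micro_batch_size * data_parallel_size)."""
    per_pipeline = global_batch_size // data_parallel_size
    assert per_pipeline % micro_batch_size == 0, "batch sizes must divide evenly"
    return per_pipeline // micro_batch_size

# Example: global batch 512, micro batch 1, DP degree 1 -> 512 microbatches.
print(num_microbatches(512, 1, 1))
```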
Yes, I think Zero Bubble needs to be compared with the optimal performance of interleaved 1F1B to show its value, so the number of virtual stages was set to the maximum. When the number of microbatches is small, ZB and interleaved_1f1b are both better than 1F1B. When the number of microbatches is large enough (e.g. 512), the difference between the three is almost negligible (the performance difference is less than 1%). In actual pre-training, Zero Bubble gives very little improvement over 1F1B. Did I make a mistake somewhere?
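As a rough sanity check of why a large microbatch count makes the schedules converge, here is a back-of-envelope estimate using the standard bubble-fraction approximations (p pipeline stages, m microbatches, v virtual stages per device; the p = 8 value is only an example, not necessarily the setting in this thread):

```python
# Approximate bubble fraction: ~(p-1)/m for 1F1B, ~(p-1)/(v*m) for
# interleaved 1F1B, ignoring communication and load imbalance.
p = 8
for m in (16, 512):            # number of microbatches
    for v in (1, 2):           # v=1 is plain 1F1B, v=2 is interleaved
        bubble = (p - 1) / (v * m)
        print(f"m={m:<4} v={v}: bubble fraction ~ {bubble:.1%}")
# At m = 512 the 1F1B bubble is already below 1.5%, so removing it buys little.
```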
Interleaved 1F1B is a good strategy. That being said, zero bubble is actually a method orthogonal to interleaving, and we can easily produce a schedule that is both interleaved and zero bubble; one example is a schedule combining ZB-H1 and interleaving. We leave this to the community to explore, because the main point of zero bubble is the idea of the B-W split and its ability to eliminate bubbles, not a claim that it is better than everything else.
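To make the B-W split concrete, here is a toy PyTorch sketch (not the schedule implementation from this repository): the backward pass of a stage is divided into B, the input gradient that the upstream stage needs immediately, and W, the weight gradient that can be deferred to fill what would otherwise be bubbles.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(16, 16, requires_grad=True)
x = torch.randn(4, 16, requires_grad=True)

y = x @ weight.t()                 # forward of a toy linear stage
grad_output = torch.randn_like(y)  # gradient arriving from the next stage

# B: gradient w.r.t. the input, computed first so it can be sent upstream right away.
grad_input = torch.autograd.grad(y, x, grad_output, retain_graph=True)[0]

# W: gradient w.r.t. the weight, computed later (e.g. during what would otherwise be a bubble).
grad_weight = torch.autograd.grad(y, weight, grad_output)[0]

print(grad_input.shape, grad_weight.shape)
```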
I tried multiple sets of experiments and found that ZB is indeed better than 1F1B. Interleaved 1F1B seems to be slightly faster than ZB_V and slightly slower than ZB_2P, but saves a lot of GPU memory.
Machine: 8×H800 80 GB
Model: 6.2B

| Schedule | Throughput (samples / 8 GPUs / s) | GPU memory |
| --- | --- | --- |
| 1F1B | 55 | 48 GB |
| Interleaved 1F1B | 66 | 57 GB |
| ZB_2P | 67 | 79 GB |
| ZB_V | 64 | 53 GB |
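For a quick sense of scale, the relative throughput gains over the 1F1B baseline in the table above work out as follows (illustrative arithmetic only, not part of any benchmark harness):

```python
# Relative throughput vs. the 1F1B baseline (samples / 8 GPUs / s).
baseline = 55  # 1F1B
for name, tput in [("Interleaved 1F1B", 66), ("ZB_2P", 67), ("ZB_V", 64)]:
    print(f"{name}: {tput / baseline - 1:+.1%} vs 1F1B")
# -> roughly +20%, +22%, and +16% respectively on this 6.2B / 8xH800 setup.
```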