[QUESTION] Whether to split bw when send_backward_recv_forward is not enabled #17

Open
AndSonder opened this issue Apr 2, 2024 · 4 comments

Comments

AndSonder commented Apr 2, 2024

Hi, I really appreciate your work. I have a question about zbh1 mode.

This is one part of your code:

# For BWF pattern or in rank 0, we don't split W and B for reasons below.
#   1. to leverage batched p2p op (send_backward_recv_forward)
#   2. to overlap grad all-reduce for tensor parallel
#   3. to avoid redoing grad all-gather for sequence parallel
# Note that the order of grad accumulation is changed by this behavior,
# thus causing a minor precision error compared to 1F1B even it's mathematically correct.
WeightGradStore.split_bw = (i < rank or last_iteration) and rank > 0

You said that there is no need to split bw for BWF pattern.

My question is: if we do not enable send_backward_recv_forward, is it better to split bw? A finer granularity makes a smaller bubble, doesn't it?
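To make the quoted condition easier to read, here is a minimal sketch that evaluates it for a few ranks and iterations. It only tabulates the boolean expression from the snippet above. The helper names `split_bw` and `split_table`, and the assumption that `last_iteration` is true only at the final index of the loop, are hypothetical and are not taken from the repository.

```python
from typing import Dict, List


def split_bw(i: int, rank: int, last_iteration: bool) -> bool:
    """Same boolean expression as the quoted line, as a pure function."""
    return (i < rank or last_iteration) and rank > 0


def split_table(num_ranks: int, num_iters: int) -> Dict[int, List[bool]]:
    """Tabulate the condition per rank.

    Assumption (hypothetical): last_iteration is True only when
    i == num_iters - 1. The real loop may define it differently.
    """
    return {
        rank: [split_bw(i, rank, i == num_iters - 1) for i in range(num_iters)]
        for rank in range(num_ranks)
    }


if __name__ == "__main__":
    for rank, row in split_table(num_ranks=3, num_iters=4).items():
        print(rank, ["split" if s else "fused" for s in row])
```

Under these assumptions, rank 0 never splits, and a rank r > 0 splits only for the first r indices and the last one.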

@ufotalent

Hi @AndSonder, thanks for your interest. I'm not sure whether I understand your question correctly. For zbh1 rank 0, the schedule pattern is that W always comes after B, and there is no communication after a B, so a B-W split won't make the bubble smaller here.

@AndSonder
Author

@ufotalent Thanks for your reply. If there is communication after a B (as in the figure in your paper), will the bubble be smaller?

@ufotalent

@AndSonder Hi, on other ranks, a [B, send_B, W] schedule (with split) is theoretically better than [B, W, send_B] (without split). One thing to note, however, is that the send_B implementation in Megatron-LM is synchronous and may wait until its recv peer completes. This means send_B can delay W significantly if B and W are split. If we really want to split, we should use an async send here.
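The trade-off described above can be illustrated with a toy timeline model. This is only a sketch under stated assumptions, not a measurement and not Megatron-LM code: a send can complete only once the peer is ready to receive (`peer_ready`), a blocking send stalls the rank until then, and an async send lets W run in the meantime. All function names, parameters, and time units are hypothetical.

```python
from typing import NamedTuple


class Outcome(NamedTuple):
    grad_delivered: int  # time at which the peer has received the B gradient
    rank_free: int       # time at which this rank has finished B, W and the send


def no_split(b: int, w: int, xfer: int, peer_ready: int) -> Outcome:
    """[B, W, send_B]: the send starts only after both B and W are done."""
    done = max(b + w, peer_ready) + xfer
    return Outcome(done, done)


def split_blocking(b: int, w: int, xfer: int, peer_ready: int) -> Outcome:
    """[B, send_B, W] with a blocking send: W waits for the send to finish."""
    done = max(b, peer_ready) + xfer
    return Outcome(done, done + w)


def split_async(b: int, w: int, xfer: int, peer_ready: int) -> Outcome:
    """[B, send_B, W] with an async send: W overlaps with the pending send."""
    done = max(b, peer_ready) + xfer
    return Outcome(done, max(b + w, done))


if __name__ == "__main__":
    for peer_ready in (0, 30):  # peer ready immediately vs. ready late
        print("peer_ready =", peer_ready)
        for fn in (no_split, split_blocking, split_async):
            print(" ", fn.__name__, fn(b=10, w=10, xfer=1, peer_ready=peer_ready))
```

In this model, when the peer is ready early, splitting delivers the gradient sooner. When the peer is ready late, the blocking split finishes later than the unsplit schedule because W is stuck behind the send, while the async variant avoids that delay.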

@AndSonder
Author

OK, I understand. Thank you very much for your answer.

@AndSonder AndSonder reopened this Apr 22, 2024