Support GPU-resident batch sizes #21
Conversation
@tgale96 Hi Trevor! I'm trying to close some to-do items I left there previously. :) If you have the cycles, could you please take a look at this PR (whenever it is convenient for you, it's not urgent)?
csrc/fill_arguments.cuh
Outdated
template <
    // If `k` is dynamic, we sort the problems by `k` in descending order.
    // Otherwise, `m` is dynamic, and no sorting happens.
    bool dynamic_k,
nit - I usually ~follow google cpp style. For static constants, you can use constant naming to make it more clear which values are static and which are dynamic. Here, 'dynamic_k' would be kDynamicK.
https://google.github.io/styleguide/cppguide.html#Constant_Names
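For illustration, a minimal sketch of that convention (the function and its parameters are hypothetical, not code from this PR):

```cpp
// Compile-time constants get a leading 'k' (kDynamicK, kMaxExperts), which
// makes them easy to tell apart from runtime values such as num_experts.
template <bool kDynamicK, int kMaxExperts>
constexpr int MaxSortedProblems(int num_experts) {
  // Sorting only happens when k is dynamic; otherwise nothing is sorted.
  return kDynamicK ? (num_experts < kMaxExperts ? num_experts : kMaxExperts) : 0;
}
```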
Fair, renamed `dynamic_k` to `kDynamicK`.
  dims.m() = batch_size;
}

using BlockScan = cub::BlockScan<int, kMaxExperts>;
I am curious if there is a performance advantage of reducing the maximum number of experts you set for the cub primitives here? i.e., would it be worth it to dynamically dispatch to kMaxExperts in {2, 4, 8, 16, 32, ...} when you call these kernels?
Huh, turns out there is (i.e., I wrongly assumed it wouldn't make a difference for a single threadblock)!
For quick testing, I hardcoded the number of experts to 8 in my synthetic example from the PR description (which uses 8 experts). For the scan it doesn't make any difference (the kernel runtime is ≈3 microseconds on an A100 for both `kMaxExperts=512` and `kMaxExperts=8`). However, the sort used for `kDynamicK` runs in ≈25 microseconds for `kMaxExperts=512` and in ≈6 microseconds for `kMaxExperts=8`, which is very nice.
What if I land this optimization in a separate PR? I think it's not a blocker for this PR (the sorting time should still be negligible compared to the GEMM runtime), but I need some time to investigate this properly and decide how to better incorporate it into the `FillArguments` kernel (so that the code remains readable).
Interesting! SGTM!
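For reference, here is a rough sketch of the dynamic-dispatch idea discussed above. The kernel body is a simplified placeholder (just a block-wide prefix sum over the batch sizes), and the names `FillArgumentsKernel`/`LaunchFillArguments` are illustrative rather than the actual code in this PR:

```cpp
#include <cstdint>
#include <cub/cub.cuh>

// Simplified stand-in for the real per-expert work (prefix sum of batch sizes).
template <int kMaxExperts>
__global__ void FillArgumentsKernel(const int64_t* batch_sizes,
                                    int num_experts,
                                    int64_t* prefix_sums) {
  using BlockScan = cub::BlockScan<int64_t, kMaxExperts>;
  __shared__ typename BlockScan::TempStorage temp_storage;
  int64_t value = threadIdx.x < num_experts ? batch_sizes[threadIdx.x] : 0;
  BlockScan(temp_storage).ExclusiveSum(value, value);
  if (threadIdx.x < num_experts) prefix_sums[threadIdx.x] = value;
}

// Dispatch to the smallest instantiation that covers num_experts, so the CUB
// block primitives only pay for the threads they actually need.
void LaunchFillArguments(const int64_t* batch_sizes, int num_experts,
                         int64_t* prefix_sums, cudaStream_t stream) {
  if (num_experts <= 8) {
    FillArgumentsKernel<8><<<1, 8, 0, stream>>>(batch_sizes, num_experts, prefix_sums);
  } else if (num_experts <= 64) {
    FillArgumentsKernel<64><<<1, 64, 0, stream>>>(batch_sizes, num_experts, prefix_sums);
  } else {
    FillArgumentsKernel<512><<<1, 512, 0, stream>>>(batch_sizes, num_experts, prefix_sums);
  }
}
```

The cut points (8, 64, 512) are arbitrary here; in practice they would be chosen to match the expert counts that actually occur.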
}

template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
Q - Do we need a third kernel for this? This looks like it can be done in FillArguments, maybe?
As of right now, we do need a separate kernel, because `ZeroOutK0Outputs()` uses the `M` and `N` dimensions to determine how many elements need to be zeroed. So we can only set these dimensions to zero once `ZeroOutK0Outputs()` has finished.
There probably is some way to fit this into two kernels instead of three, but I'm not sure it's worth it, since eventually this hack should go away entirely (worst-case scenario, we can vendor CUTLASS with a one-line fix backported from my PR).
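For context, a minimal sketch of what this ordering constraint means in code; the kernel body and the `problem_sizes` member of `Args` are assumptions for illustration, not necessarily the exact code in this PR:

```cpp
#include <cutlass/gemm_coord.h>

// Runs after ZeroOutK0Outputs(), which still needs each problem's original
// m and n to know how much of the output to clear.
template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
  int expert = threadIdx.x;
  if (expert < num_experts) {
    cutlass::gemm::GemmCoord& problem = args.problem_sizes[expert];
    if (problem.k() == 0) {
      // Turn a k=0 problem into an empty (m=0) one so that the grouped GEMM
      // simply skips it instead of tripping over the CUTLASS k=0 issue.
      problem.m() = 0;
    }
  }
}
```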
csrc/grouped_gemm.cu
Outdated
}

template <
    bool dynamic_k,
ditto above - kDynamicK
Yup, renamed again.
csrc/grouped_gemm.cu
Outdated
  // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
  threadblock_count = Gemm::sufficient();
} else {
nit - can we factor these two branches into helper functions to make this easier to parse?
i.e., have MakeArgumentsOnDevice and MakeArgumentsOnHost that you just call on either side of this branch.
Yeah, makes sense, introduced both `MakeArgumentsOnDevice()` and `MakeArgumentsOnHost()`.
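A hedged sketch of how that factoring might look; the helper bodies below are schematic placeholders rather than the actual code in this PR:

```cpp
#include <torch/torch.h>

template <typename Gemm>
typename Gemm::Arguments MakeArgumentsOnHost(torch::Tensor a, torch::Tensor b,
                                             torch::Tensor c,
                                             torch::Tensor batch_sizes) {
  typename Gemm::Arguments arguments;
  // Batch sizes are on the CPU: compute the per-expert problem sizes here
  // and size the grid from the actual problems.
  return arguments;
}

template <typename Gemm>
typename Gemm::Arguments MakeArgumentsOnDevice(torch::Tensor a, torch::Tensor b,
                                               torch::Tensor c,
                                               torch::Tensor batch_sizes) {
  typename Gemm::Arguments arguments;
  // Problem sizes are only known on the GPU: leave them to be filled in by a
  // kernel later and base the threadblock count on occupancy alone.
  return arguments;
}

template <typename Gemm>
typename Gemm::Arguments MakeArguments(torch::Tensor a, torch::Tensor b,
                                       torch::Tensor c,
                                       torch::Tensor batch_sizes) {
  return batch_sizes.is_cuda()
             ? MakeArgumentsOnDevice<Gemm>(a, b, c, batch_sizes)
             : MakeArgumentsOnHost<Gemm>(a, b, c, batch_sizes);
}
```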
csrc/grouped_gemm.cu
Outdated
@@ -218,23 +228,76 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
    /*ldb=*/(int64_t*)ldb.data_ptr(),
    /*ldc=*/(int64_t*)ldc.data_ptr(),
    /*ldd=*/(int64_t*)ldc.data_ptr(),
    (cutlass::gemm::GemmCoord*)problem_sizes_host.data());
// We currently always use `GroupScheduleMode::kDeviceOnly`,
nit - this comment breaks ~80 chars per line I think? maybe just put it above the function call rather than inline with the arguments?
> put it above the function call rather than inline with the arguments

Fair, moved it above.
int64_t workspace_size = gemm.get_workspace_size(arguments);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
torch::Tensor workspace = torch::empty(workspace_size, options);

if (batch_sizes.is_cuda()) {
Can we break the body of this condition into helper functions to make it easier to parse?
Yup, introduced a couple of helper functions (not sure it made the code much easier to parse though).
  return arguments;
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
                                 torch::Tensor b,
                                 torch::Tensor c,
-                                torch::Tensor batch_sizes) {
+                                torch::Tensor batch_sizes,
+                                ::cutlass::gemm::GemmCoord coord_template) {
What is this coord_template argument?
Added a comment to the line where `coord_template` is defined. This is essentially a template (hence the name) of the actual `GemmCoord` that will be used for each element of the batch. This template is filled out later (possibly on the device).
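A brief illustration of the idea, under the assumption that the per-expert batch size fills `m` in the normal case and `k` when the A operand is transposed; the helper name and its parameters are hypothetical:

```cpp
#include <cutlass/gemm_coord.h>

// The dimension that depends on the per-expert batch size is left as a
// placeholder (0) and filled out later, possibly by a kernel on the device.
cutlass::gemm::GemmCoord MakeCoordTemplate(bool trans_a, int hidden_in,
                                           int hidden_out) {
  if (trans_a) {
    // Weight-gradient-style GEMM: k is the dynamic (per-expert) dimension.
    return cutlass::gemm::GemmCoord(hidden_in, hidden_out, /*k=*/0);
  }
  // Forward-style GEMM: m is the dynamic (per-expert) dimension.
  return cutlass::gemm::GemmCoord(/*m=*/0, hidden_out, hidden_in);
}
```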
grouped_gemm/ops_test.py
Outdated
-        for f in [(False,), (True,)]:
-            out.append(y + f)
+    for trans_b in (False, True):
+        for gpu_batch_sizes in (False, True):
nit - maybe `batch_sizes_on_device` would be more clear than `gpu_batch_sizes` for this flag? The latter looks like it is actually the array of batch sizes on GPUs, I think.
Fair, renamed.
Looks awesome! Sorry for the delay - I've been traveling for the last couple weeks :)
@mvpatel2000 for viz.
Force-pushed from 611fefa to b8d11af.
I just pushed the fixes for everything except the CUB dynamic dispatch proposal (I suggest we do it in a separate PR). I made each fix in a separate commit so that it's obvious where each comment is addressed, but feel free to squash the commits however you see fit before the merge.
To add more motivation for this, the current need for having some tensors on the CPU seems to cause issues with CUDA stream capture (…).
Apologies for the delay! Looks great, thanks for the changes!
This is a follow-up to #14.
The current implementation always expects the batch sizes to be on the CPU. If they are actually computed on the GPU (as happens in MoE training runs), this implies a GPU<->CPU synchronization point to fetch the batch sizes to the CPU, and then immediately send them back to the GPU in a slightly different format.
This PR gets rid of this roundtrip when possible (e.g., when using CUTLASS) by using a simple single-threadblock kernel that converts a list of batch sizes to the kernel arguments CUTLASS expects. I suppose 1024 experts (the practical limit on the number of threads in one threadblock) ought to be enough for everybody. :) If not, I can rewrite it with a slightly less efficient kernel that uses multiple threadblocks.
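A simplified sketch of what such a kernel looks like (not the exact `FillArguments` kernel from this PR; the kernel name and surrounding details are illustrative):

```cpp
#include <cstdint>
#include <cutlass/gemm_coord.h>

constexpr int kMaxExperts = 1024;  // bounded by the max threads per threadblock

// One threadblock converts the GPU-resident batch sizes into the per-problem
// GemmCoords that the CUTLASS grouped GEMM expects, so the batch sizes never
// have to travel to the host.
__global__ void FillProblemSizes(const int64_t* __restrict__ batch_sizes,
                                 int num_experts,
                                 cutlass::gemm::GemmCoord coord_template,
                                 cutlass::gemm::GemmCoord* problem_sizes) {
  int expert = threadIdx.x;
  if (expert < num_experts) {
    cutlass::gemm::GemmCoord coord = coord_template;
    // The dynamic dimension (m here) comes from the batch size computed on
    // the device; the other dimensions come from the template.
    coord.m() = static_cast<int>(batch_sizes[expert]);
    problem_sizes[expert] = coord;
  }
}
```

Such a kernel would be launched with a single threadblock, e.g. `FillProblemSizes<<<1, kMaxExperts, 0, stream>>>(...)`.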
Here's a slightly contrived synthetic example that illustrates the cost of the roundtrip (if the GEMMs and the context length are large enough, the cost is negligible):
In this toy example a CPU<->GPU sync is mostly a minor annoyance, but in serious training runs that overlap compute and comms in a ZeRO-3-like fashion, this sync can be devastating, because we can't schedule an NCCL collective to prefetch the parameters for the next layer until we are done with the current one. We saw throughput hits as large as 60%, but unfortunately, it's hard to come up with a minimal example that reproduces this.
The CUTLASS issue with `k=0` grouped GEMMs is still not resolved, so I had to resort to ugly hacks on the GPU side as well. Hopefully I can remove them someday. :)