
Support GPU-resident batch sizes #21

Merged: 9 commits into tgale96:main, Oct 31, 2024

Conversation

dfyz (Contributor) commented Oct 15, 2024

This is a follow-up to #14.

The current implementation always expects the batch sizes to be on the CPU. If they are actually computed on the GPU (as happens in MoE training runs), this implies a GPU<->CPU synchronization point to fetch the batch sizes to the CPU, and then immediately send them back to the GPU in a slightly different format.

This PR gets rid of this roundtrip when it is possible (e.g. when using CUTLASS) by using a simple single-threadblock kernel that converts a list of batch sizes to the kernel arguments CUTLASS expects. I suppose 1024 experts (the practical limit on the number of threads in one threadblock) ought to be enough for everybody. :) If not, I can re-write it with a slightly less efficient kernel with multiple threadblocks.
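To make the idea concrete, here is a minimal sketch (not the PR's actual kernel) of what such a single-threadblock conversion kernel could look like; the name FillCutlassArgs, the argument list, and the offset layout are illustrative assumptions:

#include <cub/cub.cuh>
#include <cutlass/gemm/gemm.h>

// One threadblock, one thread per expert: turn device-resident batch sizes
// into the per-problem GemmCoord array and per-expert row offsets that a
// CUTLASS grouped GEMM consumes.
template <int kMaxExperts>
__global__ void FillCutlassArgs(int num_experts,
                                const int64_t* __restrict__ batch_sizes,
                                int n, int k,
                                cutlass::gemm::GemmCoord* problem_sizes,
                                int64_t* row_offsets) {
  using BlockScan = cub::BlockScan<int, kMaxExperts>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  const int e = threadIdx.x;
  const int m = (e < num_experts) ? static_cast<int>(batch_sizes[e]) : 0;

  // An exclusive prefix sum over the batch sizes gives each expert's starting row in A.
  int row_start = 0;
  BlockScan(temp_storage).ExclusiveSum(m, row_start);

  if (e < num_experts) {
    problem_sizes[e] = cutlass::gemm::GemmCoord(m, n, k);
    row_offsets[e] = row_start;
  }
}

// Launched with a single threadblock of kMaxExperts threads, e.g.
// FillCutlassArgs<512><<<1, 512, 0, stream>>>(num_experts, batch_sizes, n, k,
//                                             problem_sizes, row_offsets);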

Here's a slightly contrived synthetic example that illustrates the cost of the roundtrip (if the GEMMs and the context length are large enough, the cost is negligible):

import torch
import grouped_gemm as gg


use_cpu_batch_sizes = True  # flip to False to keep the batch sizes GPU-resident


if __name__ == '__main__':
    # GRIN-MoE sizes.
    M = 1024
    K = 4096
    N = 6400
    E = 8

    torch.manual_seed(0)

    x = torch.rand(M, K, dtype=torch.bfloat16, device='cuda')
    w = torch.rand(E, K, N, dtype=torch.bfloat16, device='cuda')

    x.requires_grad_(True)
    w.requires_grad_(True)

    batch_sizes = torch.tensor([M//E]*E, device='cuda')

    with torch.profiler.profile(activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA
    ]) as prof:
        for _ in range(30):
            if use_cpu_batch_sizes:
                out = gg.ops.gmm(x, w, batch_sizes.cpu())
            else:
                out = gg.ops.gmm(x, w, batch_sizes)
            out.sum().backward()

    torch.cuda.synchronize()
    prof.export_chrome_trace('gmm_trace.json')
| | Total kernel runtime (not including the first warmup kernels), ms | Perfetto trace |
| --- | --- | --- |
| With CPU<->GPU sync | 68 | a100_with_cpu_sync.json |
| Without CPU<->GPU sync | 60 | a100_no_cpu_sync.json |

In this toy example a CPU<->GPU sync is mostly a minor annoyance, but in serious training runs that overlap compute and comms in ZeRO3-like fashion, this sync can be devastating because we can't schedule a NCCL collective to prefetch the parameters for the next layer until we are done with the current one. We saw throughput hits as large as 60%, but unfortunately, it's hard to come up with a minimal example to reproduce this.

The CUTLASS issue with k=0 grouped GEMMs is still not resolved, so I had to resort to ugly hacks on the GPU side as well. Hopefully I can remove them someday. :)

dfyz (Contributor, Author) commented Oct 17, 2024

@tgale96 Hi Trevor! I'm trying to close some to-do items I left there previously. :) If you have the cycles, could you please take a look at this PR (whenever it is convenient for you, it's not urgent)?

template <
// If `k` is dynamic, we sort the problems by `k` in descending order.
// Otherwise, `m` is dynamic, and no sorting happens.
bool dynamic_k,
tgale96 (Owner):

nit - I usually ~follow Google C++ style. For static constants, you can use constant naming to make it clearer which values are static and which are dynamic. Here, 'dynamic_k' would be kDynamicK.

https://google.github.io/styleguide/cppguide.html#Constant_Names

dfyz (Contributor, Author):

Fair, renamed dynamic_k to kDynamicK.

csrc/fill_arguments.cuh (resolved review thread)
dims.m() = batch_size;
}

using BlockScan = cub::BlockScan<int, kMaxExperts>;
tgale96 (Owner):

I am curious whether there is a performance advantage to reducing the maximum number of experts you set for the cub primitives here. I.e., would it be worth it to dynamically dispatch to kMaxExperts in {2, 4, 8, 16, 32, ...} when you call these kernels?

dfyz (Contributor, Author):

Huh, turns out there is (i.e., I wrongly assumed it wouldn't make a difference for a single threadblock)!

For quick testing, I hardcoded the number of experts to 8 in my synthetic example from the PR description (which uses 8 experts). For the scan it doesn't make any difference (the kernel runtime is ≈3 microseconds on an A100 for both kMaxExperts=512 and kMaxExperts=8). However, the sort used for kDynamicK runs in ≈25 microseconds for kMaxExperts=512, and in ≈6 microseconds for kMaxExperts=8, which is very nice.

What if I land this optimization in a separate PR? I think it's not a blocker for this PR (the sorting time should still be negligible compared to the GEMM runtime), but I need some time to investigate this properly and decide how to better incorporate it into the FillArguments kernel (so that the code remains readable).

tgale96 (Owner):

Interesting! SGTM!
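(For readers following along: the dispatch discussed above could look roughly like the hedged sketch below; the helper names and the chosen set of sizes are assumptions, and a real launcher would forward the actual kernel arguments.)

#include <cstdio>

// Each instantiation would size the cub sort/scan primitives to kMaxExperts threads.
template <int kMaxExperts>
void LaunchFillArguments(int num_experts) {
  std::printf("kMaxExperts=%d for %d experts\n", kMaxExperts, num_experts);
}

// Round the expert count up to the nearest supported size so the cub
// primitives never operate on more items than necessary.
void DispatchFillArguments(int num_experts) {
  if (num_experts <= 8)        LaunchFillArguments<8>(num_experts);
  else if (num_experts <= 32)  LaunchFillArguments<32>(num_experts);
  else if (num_experts <= 128) LaunchFillArguments<128>(num_experts);
  else                         LaunchFillArguments<512>(num_experts);
}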

}

template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
tgale96 (Owner):

Q - Do we need a third kernel for this? This looks like it can be done in FillArguments, maybe?

dfyz (Contributor, Author):

As of right now, we do need a separate kernel, because ZeroOutK0Outputs() uses the M and N dimensions to determine how many elements need to be zeroed. So we can only set these dimensions to zero after ZeroOutK0Outputs() has finished.

There probably is some way to fit this in two kernels instead of three, but I'm not sure this is worth it, since eventually this hack should go away entirely (worst-case scenario, we can vendor CUTLASS with a one-line fix backported from my PR).
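(A hedged sketch of the ordering constraint described above; the three functions are stand-ins for the real kernel launches and the argument lists are elided.)

// FillArguments:     builds per-expert GemmCoords from the device-side batch sizes.
// ZeroOutK0Outputs:  zeroes the output rows of experts with k == 0; it reads the
//                    original m()/n() stored in those GemmCoords.
// IgnoreK0Problems:  only afterwards shrinks the k == 0 problems so CUTLASS skips them.
void FillArguments()    {}
void ZeroOutK0Outputs() {}
void IgnoreK0Problems() {}

void PrepareDeviceArguments() {
  FillArguments();
  ZeroOutK0Outputs();  // must run while the problem dimensions are still intact
  IgnoreK0Problems();  // safe to zero the dimensions only after the outputs are zeroed
}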

}

template <
bool dynamic_k,
tgale96 (Owner):

ditto above - kDynamicK

dfyz (Contributor, Author):

Yup, renamed again.


// We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
threadblock_count = Gemm::sufficient();
} else {
tgale96 (Owner):

nit - can we factor these two branches into helper functions to make this easier to parse?

tgale96 (Owner):

i.e., have MakeArgumentsOnDevice and MakeArgumentsOnHost that you just call on either side of this branch.

dfyz (Contributor, Author):

Yeah, makes sense, introduced both MakeArgumentsOnDevice() and MakeArgumentsOnHost().
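(Roughly the shape of that refactoring, as a hedged sketch; Gemm stands in for the real CUTLASS kernel type and the argument lists are elided.)

// Batch sizes on the CPU: the problem sizes can be computed on the host as before.
template <typename Gemm>
typename Gemm::Arguments MakeArgumentsOnHost(/* tensors, host batch_sizes, ... */);

// Batch sizes on the GPU: leave the problem sizes unset and let the
// FillArguments kernel populate them on the device before the grouped GEMM runs.
template <typename Gemm>
typename Gemm::Arguments MakeArgumentsOnDevice(/* tensors, device batch_sizes, ... */);

template <typename Gemm>
typename Gemm::Arguments MakeArguments(bool batch_sizes_on_device /*, ... */) {
  return batch_sizes_on_device ? MakeArgumentsOnDevice<Gemm>()
                               : MakeArgumentsOnHost<Gemm>();
}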

@@ -218,23 +228,76 @@ typename Gemm::Arguments MakeArguments(torch::Tensor a,
/*ldb=*/(int64_t*)ldb.data_ptr(),
/*ldc=*/(int64_t*)ldc.data_ptr(),
/*ldd=*/(int64_t*)ldc.data_ptr(),
(cutlass::gemm::GemmCoord*)problem_sizes_host.data());
// We currently always use `GroupScheduleMode::kDeviceOnly`,
tgale96 (Owner):

nit - this comment breaks ~80 chars per line I think? maybe just put it above the function call rather than inline with the arguments?

dfyz (Contributor, Author):

> put it above the function call rather than inline with the arguments

Fair, moved it above.

int64_t workspace_size = gemm.get_workspace_size(arguments);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
torch::Tensor workspace = torch::empty(workspace_size, options);

if (batch_sizes.is_cuda()) {
tgale96 (Owner):

Can we break the body of this condition into helper functions to make it easier to parse?

dfyz (Contributor, Author):

Yup, introduced a couple of helper functions (not sure it made the code much easier to parse though).

return arguments;
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
torch::Tensor batch_sizes,
::cutlass::gemm::GemmCoord coord_template) {
tgale96 (Owner):

What is this coord_template argument?

dfyz (Contributor, Author):

Added a comment to the line where coord_template is defined. This is essentially a template (hence the name) for the actual GemmCoord that will be used for each element of the batch. The template is filled out later (possibly on the device).
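(To illustrate the idea, a hedged sketch with assumed names; MakeCoordTemplate is not part of the PR.)

#include <cutlass/gemm/gemm.h>

// Host side: for the forward pass, n and k are known, but m (the number of tokens
// routed to each expert) lives on the GPU, so leave a placeholder to fill in later.
cutlass::gemm::GemmCoord MakeCoordTemplate(int n, int k) {
  return cutlass::gemm::GemmCoord(/*m=*/-1, n, k);
}

// Device side, inside the fill kernel: complete the template per expert,
// mirroring the `dims.m() = batch_size;` line quoted earlier in this review.
// __device__ void CompleteCoord(cutlass::gemm::GemmCoord& dims, int batch_size) {
//   dims.m() = batch_size;
// }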

for f in [(False,), (True,)]:
out.append(y + f)
for trans_b in (False, True):
for gpu_batch_sizes in (False, True):
tgale96 (Owner):

nit - maybe batch_sizes_on_device would be more clear than gpu_batch_sizes for this flag? The latter looks like it is actually the array of batch sizes on GPUs, I think.

dfyz (Contributor, Author):

Fair, renamed.

tgale96 (Owner) commented Oct 18, 2024

Looks awesome! Sorry for the delay - I've been traveling for the last couple weeks :)

tgale96 (Owner) commented Oct 18, 2024

@mvpatel2000 for viz.

dfyz force-pushed the support-gpu-batch-sizes branch from 611fefa to b8d11af on October 18, 2024 at 20:54
dfyz (Contributor, Author) commented Oct 18, 2024

I just pushed the fixes for everything except the CUB dynamic dispatch proposal (I suggest we do it in a separate PR). I made each fix in a separate commit so that it's obvious where each comment is addressed, but feel free to squash the commits however you see fit before the merge.

ntoxeg commented Oct 31, 2024

To add more motivation for this, the current need for having some tensors on CPU seems to cause issues with CUDA stream capture (RuntimeError: CUDA error: operation not permitted when stream is capturing). As far as I understand, graph capture needs to be done on a specific device, so moving things to the CPU is not allowed.

tgale96 merged commit ebeae0b into tgale96:main on Oct 31, 2024
tgale96 (Owner) commented Oct 31, 2024

Apologies for the delay! Looks great, thanks for the changes!
