Allow scalar broadcasting in VisitorRowBroadcast and VisitorColBroadcast #1539

Open
wants to merge 2 commits into
base: main
from

Conversation

tlrmchlsmth

This PR addresses an inconsistency between the VisitorRowBroadcast/VisitorColBroadcast epilogues and the Sm90RowBroadcast/Sm90ColBroadcast epilogues.

The inconsistency is that the SM90 epilogues can handle either row/column broadcasting or scalar broadcasting (the latter by passing in a nullptr for the first argument and a float for the second), while the visitor epilogues can only handle row/column broadcasting. This PR adds scalar broadcasting to the visitor epilogues.

I am using this for quantized GEMMs that can handle either per-token/per-channel quantization or per-tensor quantization without compiling and distributing multiple kernels to handle all cases.

For reference, I ran into this issue when developing vllm-project/vllm#4749
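
For readers following along, here is a minimal host-side sketch of the argument convention described above: a null pointer plus a float means "broadcast this scalar", and a non-null pointer means "broadcast this row vector". This is plain C++ written for illustration, not CUTLASS code; the names RowBroadcastArgs, broadcast_value, and scale_tile are assumptions made up for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative model only (not a CUTLASS type): the two arguments mentioned
// in the PR description, a row pointer and a fallback scalar.
struct RowBroadcastArgs {
  float const* ptr_row = nullptr;  // length-N row vector, or nullptr for scalar mode
  float null_default = 0.0f;       // broadcast to every element when ptr_row is null
};

// Value the broadcast would supply for output column n.
inline float broadcast_value(RowBroadcastArgs const& args, std::size_t n) {
  return args.ptr_row != nullptr ? args.ptr_row[n] : args.null_default;
}

// Scale an M x N row-major accumulator tile in place by the broadcast value.
inline void scale_tile(std::vector<float>& acc, std::size_t M, std::size_t N,
                       RowBroadcastArgs const& args) {
  assert(acc.size() == M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      acc[m * N + n] *= broadcast_value(args, n);
    }
  }
}
```

With ptr_row set, each column n is scaled by its own value; with ptr_row left null, every element is scaled by null_default, so one code path serves both per-channel and per-tensor scales.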

@tlrmchlsmth
Author

Very happy to add unit tests and put in the work to get this PR into a landable state. But first I'm hoping to get some high-level feedback on whether this is the right approach and a reasonable thing to do. Thanks!

@tlrmchlsmth
Author

cc @mnicely

@mnicely added the "feature request" label on May 24, 2024
@Hongbosherlock

Hongbosherlock commented Jun 4, 2024

Hi @tlrmchlsmth thanks for your contribution. I'm working on int8 GEMM with dequant fusion.
Can the following code work with the original VisitorRowBroadcast/VisitorColBroadcast epilogues?

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int8
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32
    // outputs
    //     mat [M, N]            fp32

    // alphaCol    [M, 1]    fp32
    using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<int32_t, _1, _0>  // StrideMNL
    >;

    // alphaRow    [1, N]    fp32
    using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<_0, _1, int32_t>  // StrideMNL
    >;

> The inconsistency is that the SM90 epilogues can handle either row/column broadcasting by passing in a nullptr for the first argument, and a float for the second, while the visitor epilogues cannot. This PR adds this functionality to the visitor epilogues.

I don’t quite understand this PR. Regarding this issue, could you please provide some examples? In what situations won’t it work, and in what situations will it work based on this PR?

@tlrmchlsmth
Author

@Hongbosherlock
Compare the cutlass 2.0 epilogues in include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp and
the cutlass 3.0 epilogues in include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp

In the second, the row and column broadcast epilogues (Sm90RowBroadcast and Sm90ColBroadcast) have null_default arguments that are used to provide scalar broadcast functionality. In the first file, the similar row and column broadcast epilogues also have null_default arguments but they simply aren't used.

I tried the approach you suggest for cutlass 2.0 but couldn't get it to compile. If you have a full working example, I'd like to see it :)

Anyway, the same approach won't work for cutlass 3.0, as you will fail this static assert: cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>. The problem this PR is addressing is the inconsistency between these two very similar types.

@thakkarV
Collaborator

@hwu36 can we ask Zhaodong to merge this? I don't know his GitHub username

@tlrmchlsmth
Author

JFYI I did end up going in a different direction with these epilogue changes. See vllm-project/vllm#5137 -- I found that it was much nicer for a variety of reasons if both the scalar and the vector broadcast cases take a float * that points to device memory as an argument.
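
As a rough illustration of the alternative mentioned here, the sketch below models an interface where the scalar and vector cases both take a float pointer, and a flag says how many values it addresses. It is host-side C++ under stated assumptions, not the code from the linked vLLM change; RowOrScalarArgs, row_broadcast, and load_scale are hypothetical names, and in a real kernel the pointer would refer to device memory. The comment above does not list the reasons this was nicer, so the sketch only shows the shape of the interface.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical argument struct: the pointer is always valid, and a flag says
// whether it addresses N values (row broadcast) or one value (scalar broadcast).
struct RowOrScalarArgs {
  float const* ptr = nullptr;  // would point to device memory in a real kernel
  bool row_broadcast = true;   // true: read ptr[n]; false: read ptr[0] for every n
};

// Value supplied for output column n under this convention.
inline float load_scale(RowOrScalarArgs const& args, std::size_t n) {
  assert(args.ptr != nullptr);  // unlike the nullptr convention, null is never valid
  return args.row_broadcast ? args.ptr[n] : args.ptr[0];
}
```

Compared with the nullptr-plus-float convention, the scalar here lives behind the pointer rather than in the argument struct itself.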

bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
}
} else {
Collaborator

Nit: New line after branch close.

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
{
Collaborator

Nit: no new line before brace open


CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
Collaborator

Nit: spacing if (get

copy_if(pred, tC_gCol, tC_rCol);

if (params_ptr->ptr_col) {
// In this case we are loading from a column vector and broadcasting
Collaborator

A design question: this isn't really a scalar operation anymore. Does it make sense to extend this visitor, or to replace this with a vector broadcast instead that then has a broadcasting layout for its data?

Author

> replace this with a vector broadcast instead that then has a broadcasting layout for its data

Could you point me to an example of this?

Author

thanks for taking a look at the PR BTW :)

@hwu36
Collaborator

hwu36 commented Jun 13, 2024

@apuaaChen

@apuaaChen

Hi @tlrmchlsmth, thanks for the PR! One question: could we use VisitorScalarBroadcast to achieve the same goal? It also takes a scalar (e.g. float) and broadcasts it to the whole epilogue tile.

@tlrmchlsmth
Author

@apuaaChen We could totally do that, but then supporting every case of fp8 quantized GEMM that we need would take 4x the number of kernels: the activations can have per-tensor or per-token scales, and the weights can have per-tensor or per-output-channel scales. This PR lets us pick another point in the binary size/performance tradeoff space.
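
To make the 4x count concrete: two activation modes (per-tensor, per-token) times two weight modes (per-tensor, per-output-channel) give four combinations. The reference sketch below, written in plain host-side C++ rather than as a CUTLASS epilogue, shows how one function can cover all four when each scale is either a vector or a null pointer plus a scalar. The names Scale and dequant are assumptions for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scale descriptor: a vector of scales, or nullptr plus a
// scalar for the per-tensor case.
struct Scale {
  float const* ptr = nullptr;
  float scalar = 1.0f;
  float at(std::size_t i) const { return ptr != nullptr ? ptr[i] : scalar; }
};

// Reference dequantization: out[m][n] = a_scale(m) * b_scale(n) * acc[m][n],
// where acc is a row-major M x N integer accumulator.
inline std::vector<float> dequant(std::vector<std::int32_t> const& acc,
                                  std::size_t M, std::size_t N,
                                  Scale const& a_scale, Scale const& b_scale) {
  assert(acc.size() == M * N);
  std::vector<float> out(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      out[m * N + n] = a_scale.at(m) * b_scale.at(n) *
                       static_cast<float>(acc[m * N + n]);
    }
  }
  return out;
}
```

The branch inside at() is the runtime cost paid in exchange for compiling one kernel instead of four, which is the tradeoff the comment above refers to.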

@apuaaChen

@tlrmchlsmth Got it! Let me merge it. Thanks for the explanation.

@ProExpertProg

@apuaaChen while @tlrmchlsmth ended up using a custom visitor that loads both a row and a scalar from the float* argument, I ran into this use-case when the scalar is a known constant (0 bias). Again the benefit is reducing code size by having one kernel handle both cases. Could we get this PR merged?

@ProExpertProg

Could you let me know whether you plan to merge it, or whether there's any cleanup you'd like me to do first? I also have a version with a boolean EnableNullptr parameter (default false) that enables the new scalar behavior, which is consistent with the c3x epilogues. Let me know if I should push that to this branch.
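
A minimal sketch of what an opt-in boolean parameter like this could look like, assuming it gates the null-pointer check at compile time. This is illustrative host-side C++, not the code on the branch; RowBroadcastModel and value are hypothetical names, and only the EnableNullptr name and its default of false come from the comment above.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model: with EnableNullptr = false the pointer is always read;
// with EnableNullptr = true a null pointer falls back to null_default.
template <bool EnableNullptr = false>
struct RowBroadcastModel {
  float const* ptr_row = nullptr;
  float null_default = 0.0f;

  float value(std::size_t n) const {
    // EnableNullptr is a compile-time constant, so the check can be removed
    // entirely by the compiler when the feature is disabled.
    if (EnableNullptr && ptr_row == nullptr) {
      return null_default;
    }
    assert(ptr_row != nullptr);  // a valid pointer is required on this path
    return ptr_row[n];
  }
};
```

Defaulting the flag to false keeps existing users on the original path, where null_default is ignored, while callers that want the scalar behavior (for example a constant zero bias) opt in explicitly.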

@apuaaChen

@ProExpertProg Please push your changes to this branch. I will first merge your updates to our internal repo. After the CI is passed, I can get your PR merged, thanks!

@ProExpertProg force-pushed the tms/2x_scalar_broadcast branch from 80a5654 to 0b6c76e on July 17, 2024 22:16
@ProExpertProg

Perfect, thank you!!

@ProExpertProg

And please don't hesitate to ask for any changes or improved comments, and feel free to make edits yourself if there are any style/formatting issues.


This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@ProExpertProg

@apuaaChen were you able to get the PR run on the internal CI?

@apuaaChen

> @apuaaChen were you able to get the PR run on the internal CI?

Yes, it passed the internal CI. I'm combining it with a few other fixes right now.


This PR has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates.

Labels
feature request New feature or request inactive-90d

7 participants