Accept Hopper matmuls and update default heuristic #3579
base: main
Conversation
This enables Hopper matmuls in our automatic scheduler by translating them without introducing new broadcasts. Specifically:

1. Update `mma_utils::MatmulPattern::translateToMmaOp` to optionally avoid intermediates by using an `MmaOp::AxisMapping`. Enable this option when the target arch is not Ampere or Turing.
3. Unguard some tests in `test_translate_mma.cpp`.

This does not update the default heuristic or change the `canSchedule` checks. See #3579 for that follow-up PR.

---------

Co-authored-by: Ryan Spring <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
Co-authored-by: Jingyue Wu <[email protected]>
Co-authored-by: nsarka <[email protected]>
Co-authored-by: Protonu <[email protected]>
Co-authored-by: samnordmann <[email protected]>
Must have been a broken merge
I'm still skipping the ones with batch dimensions on A, since those currently hit an error. I'll investigate later, but we only need 2D A for now.
```cpp
  axis_mapping.a_axes.push_back(d);
}
axis_mapping.a_axes.reserve(out_dim);
for (size_t d : c10::irange(out_dim - 2)) {
```
I think this was just due to a busted merge.
```cpp
macro_encode.n = 256;
while (macro_encode.n >= 8) {
  if (n_extent % macro_encode.n != 0) {
    macro_encode.n /= 2;
```
Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example, if `n_extent == 72`, then we should probably use that size.
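To illustrate the suggestion above: here is a hypothetical sketch (not the PR's actual heuristic) that, instead of starting at 256 and halving, walks down through every multiple of 8 and takes the largest one that evenly divides `n_extent`. Hopper wgmma macros support N in multiples of 8 up to 256, so a problem with `n_extent == 72` could use `n = 72` directly rather than falling back to 8. `pickMacroN` is an invented name for this sketch.

```cpp
#include <cassert>
#include <cstdint>

// Sketch: pick the largest supported macro N (multiple of 8, at most 256)
// that evenly divides the problem's N extent, instead of only powers of two.
int64_t pickMacroN(int64_t n_extent) {
  for (int64_t n = 256; n >= 8; n -= 8) {
    if (n_extent % n == 0) {
      return n;
    }
  }
  // No supported size divides n_extent evenly; fall back to the smallest.
  return 8;
}
```

With this approach `pickMacroN(72)` yields 72, while the power-of-two loop above would step 256 → 128 → 64 → ... down to 8, since no power of two ≥ 8 divides 72.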
```cpp
const auto tryIncreaseM = [&]() {
  if (ratiosValid(m_ratio + 1, n_ratio)) {
    m_ratio++;
```
Should these also be powers of two? Currently this will choose sizes like 192. We should fix this for both matmul and linear, and for both the `avoid_intermediates_` and non-`avoid_intermediates_` paths.
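As a sketch of the power-of-two alternative being asked about here (not the PR's code): growing the ratio by doubling instead of incrementing keeps every resulting size at `base * 2^k` (e.g. 64, 128, 256) and never produces values like 192. `growRatioPow2` and the `valid` callback are stand-ins for `tryIncreaseM` and the heuristic's `ratiosValid` check.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Sketch: grow a tile ratio by doubling rather than incrementing, so the
// final ratio (and hence the tile size) is always a power of two.
// `valid` stands in for the real ratiosValid() check.
int64_t growRatioPow2(int64_t ratio, const std::function<bool(int64_t)>& valid) {
  while (valid(ratio * 2)) {
    ratio *= 2;  // double instead of ratio++ to stay a power of two
  }
  return ratio;
}
```

For example, with a 64-wide base tile and a validity limit of 6, incrementing would stop at a ratio of 6 (size 384, passing through 192), while doubling stops at 4 (size 256).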
The dtype for stmatrix should never have been constrained to only Half. The only constraint we have is that the dtype is 16 bits wide. This PR is needed for us to actually use stmatrix in bfloat16 matmuls.
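A hedged sketch of the check this change describes; the enum, `dataTypeSize`, and `stMatrixSupportsDtype` below are stand-ins for nvFuser's real `DataType` machinery, not its actual API. The point is that the guard should accept any 2-byte dtype (Half or BFloat16) rather than comparing against Half alone.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for nvFuser's DataType (illustrative only).
enum class DataType { Half, BFloat16, Float, Double };

// Stand-in element-size helper.
int64_t dataTypeSize(DataType dt) {
  switch (dt) {
    case DataType::Half:
    case DataType::BFloat16:
      return 2;
    case DataType::Float:
      return 4;
    case DataType::Double:
      return 8;
  }
  return -1;
}

// Accept any dtype whose elements are 16 bits wide, instead of dt == Half.
bool stMatrixSupportsDtype(DataType dt) {
  return dataTypeSize(dt) == 2;
}
```

Under this check, bfloat16 matmuls qualify for stmatrix exactly like half-precision ones, while 32- and 64-bit dtypes are still rejected.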