redo register sharing PR-3972 #3993
base: main
Conversation
!test
!test
!test
!test
Revise the number of padded threads for warp specialization with register sharing to ensure both the loading branch and the computation branch have 128*N threads.
The test coverage only supports a direct multiple of 128*N threads; there probably should have been an assertion for this. This PR should expand the test coverage to handle that case.
Is there a direct use case for supporting all combinations of CTA shapes that are a multiple of 128?
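As a quick worked example of the 128*N constraint (a sketch only, based on the cta = [128, 2, 1] case mentioned later in this thread; the padding rule shown here is an assumption for illustration, not the PR's actual implementation):

// Illustrative sketch: reasoning about padded thread counts for warp
// specialization with register sharing. Assumption: the load branch is
// appended as extra TIDy rows, and setmaxnreg requires both branches to
// contain whole warp groups (multiples of 128 threads).
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t warp_group = 128;
  // Example from this thread: cta = [128, 2, 1].
  int64_t bdimx = 128, bdimy = 2, bdimz = 1;

  // Pad TIDy with enough extra rows so the load branch holds 128*N threads.
  int64_t pad = (warp_group + bdimx - 1) / bdimx;   // 1 when bdimx == 128
  int64_t compute_threads = bdimx * bdimy * bdimz;  // 256
  int64_t load_threads = bdimx * pad;               // 128

  assert(compute_threads % warp_group == 0);
  assert(load_threads % warp_group == 0);
  std::printf("pad=%lld compute=%lld load=%lld\n",
              (long long)pad, (long long)compute_threads,
              (long long)(load_threads));
  return 0;
}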
warp_dispatch_ite->thenBody().push_back(load_loop);

// Nest load loop inside the warp dispatch if-then-else
if (warp_specialization_pad > 1) {
IIUC, this nested IfThenElse should be merged with the ElectSync predicate logic at https://github.com/NVIDIA/Fuser/blob/main/csrc/predicate_compute.cpp#L652-L663.
// select 1 thread from the last warp to do TMA load
if (Hopper::electSync(4294967295U) && threadIdx.y == 19) {
We can do that by adding the following two changes:
(1) We don't need to nest the load loop inside the warp dispatch if-then-else; basically, remove the changes at
Fuser/csrc/device_lower/pass/circular_buffer.cpp
Line 1445 in e19e529
if (warp_specialization_pad > 1) {
(2) Revise createMultipleExpressionElectSync to add the extra predicate.
for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
  if (!pdim_map.has(pt)) {
    continue;
  }
  if (load_warp_on != pt) {
    conditional = SimplifyingIrBuilder::logicalAndExpr(
        conditional,
        IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
  } else {
    Val* raw =
        GpuLower::current()->parallelDimensionMap().get(load_warp_on);
    conditional = SimplifyingIrBuilder::logicalAndExpr(
        conditional,
        IrBuilder::eqExpr(
            NamedScalar::getParallelIndex(pt),
            IrBuilder::subExpr(raw, IrBuilder::create<Val>(1, DataType::Index))));
  }
}
Then, the generated code changes from:
if (threadIdx.y == 19) {
  Grid-Stride For-loop {
    if (Hopper::electSync(4294967295U)) {
      // TMA Load
    }
  }
}
to
Grid-Stride For-loop {
  if (Hopper::electSync(4294967295U) && threadIdx.y == 19) {
    // TMA Load
  }
}
What's the benefit of moving threadIdx.y == 19 to the inside of the ForLoop? Warp divergence is not an issue, since bdimx = 32/42/128. Is it due to better loop handling, or to keep consistent with other electSync predicates? For example, we have
bool b18;
b18 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
bool b19;
b19 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
#pragma unroll
for (nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
  if (((Hopper::electSync(4294967295U) && b18) && b19)) {
    mbarrier::init(toSmem((&T12[i22])), 2U);
  }
}
instead of
if (b19) {
  #pragma unroll
  for (nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
    if (((Hopper::electSync(4294967295U) && b18))) {
      mbarrier::init(toSmem((&T12[i22])), 2U);
    }
  }
}
What's the benefit of moving threadIdx.y == 19 to the inside of the ForLoop?
It isn't a CUDA kernel benefit, but an NvFuser lowering refactor. IfThenElse nodes are inserted in the UnrollPass without an actual predicate. IfThenElse nodes are also added in CircularBufferPass because of warp specialization and to handle mbarriers and TMA operations. i.e., these IfThenElse nodes do not guard against OOB memory access, but control how the CTA executes these instructions. Then, the predicate is generated during the generateConditionalFromPredicate pass.
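As a rough illustration of that two-phase flow (stand-in types only; this is not the actual NvFuser pass API, and every name below other than the passes mentioned above is invented):

// Stand-in sketch of the two-phase predicate flow described above.
// Phase 1 (UnrollPass / CircularBufferPass): create IfThenElse nodes that only
// record which kind of guard they need, with no concrete condition yet.
// Phase 2 (generateConditionalFromPredicate): turn that tag into a boolean
// expression over threadIdx, e.g. electSync plus the load-warp check.
#include <string>

enum class PredicateKind { ElectSync, LoadWarpElectSync, Inline };

struct IfThenElseSketch {
  PredicateKind kind;      // set in phase 1
  std::string conditional; // filled in during phase 2
};

std::string generateConditionalFromPredicateSketch(PredicateKind kind) {
  switch (kind) {
    case PredicateKind::ElectSync:
      return "Hopper::electSync(4294967295U)";
    case PredicateKind::LoadWarpElectSync:
      // Merged form discussed above: elect one thread of the load warp group.
      return "Hopper::electSync(4294967295U) && threadIdx.y == 19";
    default:
      return "true";
  }
}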
csrc/codegen.cpp (outdated)
kernel_->hasManaged("increased_register_count"),
"Decreased and increased register count must be set for register sharing warp specialization");

int64_t decreased_reg_count =
Can this check occur when compiling the fusion, like a vectorization check? We can have multiple circular buffered loops. You can look up the register count through the kernel summary.
int64_t prefetch = kernel_->summary()
                       .circular_buffer_info
                       .getCircularBufferOptionsFor(loop->iter_domain())
                       .prefetch;
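For illustration, such a compile-time check could be hung off the kernel summary along these lines (a hypothetical sketch with simplified stand-in types; only the kernel-summary lookup idea above is taken from the thread, everything else is invented):

// Hypothetical sketch of a compile-time validation, in the spirit of the
// vectorization checks. CircularBufferOptionsSketch and KernelSummarySketch
// are simplified stand-ins, not the real NvFuser classes.
#include <cstdint>
#include <stdexcept>
#include <vector>

struct CircularBufferOptionsSketch {
  int64_t decreased_reg_count = 0;  // registers released by the load branch
  int64_t increased_reg_count = 0;  // registers claimed by the compute branch
  bool has_register_sharing = false;
};

struct KernelSummarySketch {
  std::vector<CircularBufferOptionsSketch> circular_buffer_loops;
};

// Run once when the fusion is compiled, instead of asserting in codegen.
void validateRegisterSharing(const KernelSummarySketch& summary) {
  for (const auto& opt : summary.circular_buffer_loops) {
    if (!opt.has_register_sharing) {
      continue;
    }
    // Per the PTX ISA, setmaxnreg operands must be multiples of 8 in [24, 256].
    for (int64_t count : {opt.decreased_reg_count, opt.increased_reg_count}) {
      if (count < 24 || count > 256 || count % 8 != 0) {
        throw std::runtime_error(
            "Register counts for register sharing must be multiples of 8 in [24, 256]");
      }
    }
  }
}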
Moved from fusion-managed data to the kernel summary.
Currently, only matmul uses warp specialization with register sharing, test case
My main objection is that I don't believe this PR has enough test coverage to enable all the CTA shape combinations. Test ideas:
For example:
If this is easier, you can break out features 1, hard-code
!test
Major changes after previous review:
For example, cta = [128, 2, 1]
(2) Refactored
!test
redo #3972
MIN_BLOCKS_PER_SM = 1, to ensure register sharing is not ignored by the compiler.
128*N threads.
ref for setmaxnreg
Generated code sample:
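The actual generated sample is not reproduced above. For readers unfamiliar with setmaxnreg, the warp-specialized register sharing pattern typically looks roughly like the following (an illustrative sketch, not this PR's generated code; the 40/232 register counts, the threadIdx.y split, and the 384-thread launch bound are example values chosen to match the cta = [128, 2, 1] plus one padded warp group discussed above):

// Illustrative sketch of warp-specialized register sharing via setmaxnreg
// (requires sm_90a). Not the PR's generated code; values are examples only.
__global__ void __launch_bounds__(384, /*MIN_BLOCKS_PER_SM=*/1)
warp_specialized_kernel_sketch() {
  if (threadIdx.y == 2) {
    // Load branch (padded warp group): give registers back so the compute
    // warp groups can take more.
    asm volatile("setmaxnreg.dec.sync.aligned.u32 40;\n");
    // ... issue TMA loads and arrive on mbarriers ...
  } else {
    // Compute branch: claim the released registers.
    asm volatile("setmaxnreg.inc.sync.aligned.u32 232;\n");
    // ... wait on mbarriers and run the math ...
  }
}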