
Commit 64bc560

zasdfgbnm and rdspring1 authored
Add warp specialization as a circular buffering type (#3511)
This PR adds warp specialization as a new type of circular buffering. Today we already support pipelined circular buffering, and we can optionally choose whether to use block-sync or mbarrier for handling WAR hazards. If we choose mbarrier for the WAR hazard, we generate a kernel like the one below:

```python
# Mark buffer[i] as empty and ready to be loaded
for i in range(prefetch + 1):
    arrive(war_mbarrier[i])

# Prologue: thanks to the previous arrives, all the loads will just go
# through and no wait is needed
for i in range(prefetch):
    wait war_mbarrier[i]
    arrive-expect-tx raw_mbarrier[i]
    load data[i] into buffer[i]

# Main loop:
for i in range(data.size - prefetch):
    if elect-sync:
        wait war_mbarrier[(i + prefetch) % stage]
        arrive-expect-tx raw_mbarrier[(i + prefetch) % stage]
        load data[i + prefetch] to buffer[(i + prefetch) % stage]
    wait raw_mbarrier[i % stage]
    mma on buffer[i % stage] for data[i]
    wait until there are at most stage - prefetch - 1 pending mma
    arrive war_mbarrier[(i + prefetch + 1) % stage]

# Epilogue
for i in range(data.size - prefetch, data.size):
    wait raw_mbarrier[i % stage]
    mma on buffer[i % stage] for data[i]
wait until there are at most 0 pending mma
write result back to gmem
```

The kernel above has the following problems:

1. The MMA loop is not clean. One thread does the extra work of loading while the other threads in the warp group just wait for that one thread to finish. (Note that mma is a warp-group collective, so all threads in the warp group must arrive at that instruction for it to start.) Ideally, we should have a for loop with only mma and nothing else; the extra instructions can increase latency.
2. There is a false dependency between loading `data[i + prefetch]` and computing `data[i]`. These two operations do not touch the same data, so in theory they should not depend on each other, and whichever gets its mbarrier cleared first should go first. However, because code executes from top to bottom, the mma has to wait until the load is issued. This further increases latency.

With the above problems observed, it is natural to ask: why not use different warps for load and compute? The load code and the compute code in the main loop are completely independent, and both the RAW and WAR hazards are handled by mbarriers, which live in smem and are accessible across the entire CTA. All the preconditions for warp specialization are therefore in place; we just need to put different IR nodes into different places. This PR adds warp specialization. The generated code is similar to the pipelined code that uses mbarrier for WAR, but actually simpler. It looks like the code below (assuming warp specialization on TIDy):

```python
if threadIdx.y == blockDim.y - 1:
    # If we use warp specialization on TIDy, then the blockDim.y of the
    # kernel will be (whatever_value_inferred_from_schedule + 1), and the
    # last threadIdx.y will be used as the load warp
    for i in range(data.size):
        wait war_mbarrier[i % stage]
        load data[i] to buffer[i % stage]
else:
    # Every threadIdx.y other than the last will be used for compute
    for i in range(prefetch + 1):
        arrive war_mbarrier[i % stage]
    for i in range(data.size):
        wait raw_mbarrier[i % stage]
        compute buffer[i % stage]
        wait until there are at most stage - prefetch - 1 pending mma
        arrive war_mbarrier[(i + prefetch + 1) % stage]
```

This new way of doing circular buffering is intended to be computation-agnostic: it should work on whatever kernel we are scheduling, not just matmuls.
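For orientation, here is a minimal C++ sketch of how a schedule might opt into the new type. The `circularBuffer()` overload taking a `CircularBufferType`, the include path, and the stage/prefetch numbers are assumptions for illustration, not taken from this PR's diff:

```cpp
// Sketch only: assumes nvFuser headers are on the include path and that
// TensorView::circularBuffer() accepts a CircularBufferType argument.
#include <ir/interface_nodes.h>

void scheduleLoadsWithWarpSpecialization(nvfuser::TensorView* smem_tv) {
  using namespace nvfuser;
  // Previously available: pipelined circular buffering, optionally using
  // mbarriers (rather than block syncs) to guard against WAR hazards:
  //   smem_tv->circularBuffer(/*stage=*/6, /*prefetch=*/3,
  //                           Pipelined(/*uses_mbarrier_for_war=*/true));

  // New in this PR: run the TMA loads in a dedicated warp group on TIDy.
  smem_tv->circularBuffer(
      /*stage=*/6, /*prefetch=*/3, WarpSpecialized(ParallelType::TIDy));
}
```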
But note that today, there are some strong limitations that make it less applicable:

1. The computation cannot have a hardcoded `blockDim` in it, so block reduction will not work. I believe this will be easy to fix, but it is beyond the scope of this PR.
2. Because the warp-specialized parallel type is no longer exact, thread predicates will be generated for it. Predicate elimination is not yet smart enough to recognize that code in the compute warp is already guarded by the warp-dispatching branch and does not need to be predicated again. This limitation also means the computation cannot be a tensor core operation (`MmaOp`), so this PR does not yet work with matmul.

Besides the above limitations, I believe this new circular buffer type is pretty generic, and in the future we should be able to try it with TMA in perf tuning.

---------

Co-authored-by: Ryan Spring <[email protected]>
1 parent 1dda106 commit 64bc560

10 files changed (+268 −31 lines)

csrc/device_lower/pass/allocation.cpp
+3 −1

```diff
@@ -89,7 +89,9 @@ Expr* initializeMbarrier(
       // threads in the CTA.
       num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
           DataType::UInt32,
-          GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock());
+          GpuLower::current()
+              ->parallelDimensionMap()
+              .getNumComputeThreadsEachBlock());
     }
 
     // Initialize mbarrier for each circular buffer stage. Use the thread
```

csrc/device_lower/pass/circular_buffer.cpp
+52 −4

```diff
@@ -106,6 +106,10 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
           SimplifyingIrBuilder::create<Val>(opt.prefetch, DataType::Index));
       break;
     }
+    case CircularBufferLoopStage::LoadWarp:
+    case CircularBufferLoopStage::ComputeWarp: {
+      break;
+    }
     default: {
       NVF_THROW("Unsupported loop mode, got: ", loop_type_);
     }
@@ -1246,11 +1250,22 @@ class CircularBufferInserter : private kir::ExprMutator {
       return;
     }
 
-    auto hasCpAsyncBulk = std::any_of(
+    auto has_cp_async_bulk = std::any_of(
        it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
 
-    if (hasCpAsyncBulk) {
-      insertTma(loop, it->second);
+    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(loop->iter_domain())
+            .type);
+    if (use_warp_specialization) {
+      NVF_ERROR(
+          std::all_of(
+              it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+          "In order to use warp specialization, all buffers must be loaded by TMA");
+      insertTmaWarpSpecialized(loop, it->second);
+    } else if (has_cp_async_bulk) {
+      insertTmaPipelined(loop, it->second);
     } else {
       insert(loop, it->second);
     }
@@ -1315,7 +1330,7 @@ class CircularBufferInserter : private kir::ExprMutator {
               .usesMBarrierForWAR();
   }
 
-  void insertTma(
+  void insertTmaPipelined(
       ForLoop* circular_buffer_loop,
       const std::vector<Expr*>& loads) {
     // Arrive on the WAR mbarriers to let the prefetching start.
@@ -1363,6 +1378,39 @@ class CircularBufferInserter : private kir::ExprMutator {
     registerInsertAfter(circular_buffer_loop, epilogue_loop);
   }
 
+  void insertTmaWarpSpecialized(
+      ForLoop* circular_buffer_loop,
+      const std::vector<Expr*>& loads) {
+    const auto& opt =
+        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+            circular_buffer_loop->iter_domain());
+    ParallelType warp_specialize_on = std::get<WarpSpecialized>(opt.type).on;
+
+    kir::IfThenElse* warp_dispatch_ite = IrBuilder::create<kir::IfThenElse>(
+        IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
+            NamedScalar::getParallelIndex(warp_specialize_on),
+            IrBuilder::subExpr(
+                GpuLower::current()->parallelDimensionMap().get(
+                    warp_specialize_on),
+                circular_buffer_loop->fusion()->oneVal()))));
+
+    // Load loop:
+    ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
+    warp_dispatch_ite->thenBody().push_back(load_loop);
+
+    // Prefetch:
+    auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+    warp_dispatch_ite->elseBody().push_back(prefetch_loop);
+
+    // Compute loop:
+    ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop, loads, CircularBufferLoopStage::ComputeWarp);
+    warp_dispatch_ite->elseBody().push_back(compute_loop);
+
+    registerReplace(circular_buffer_loop, warp_dispatch_ite);
+  }
+
   void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
     NVF_ERROR(
         !usesMBarrierForWAR(circular_buffer_loop),
```

csrc/device_lower/pass/insert_syncs.cpp
+12 −0

```diff
@@ -468,6 +468,18 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
       last_writes_.pop_front();
       // Found that a sync is needed
 
+      if (!sync_bitmap.hasBID() &&
+          std::all_of(
+              expr->inputs().begin(), expr->inputs().end(), [](Val* val) {
+                return !val->isA<TensorView>() ||
+                    val->as<TensorView>()->getMemoryType() !=
+                    MemoryType::Shared ||
+                    ir_utils::isCpAsyncBulkLoad(val->definition());
+              })) {
+        // RAW of TMA is handled separately, so skip it here.
+        return;
+      }
+
       // TODO: Explicitly test the 3 cases below
       Expr* sync_expr = nullptr;
       kir::Allocate* maybe_alloc = nullptr;
```
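The intent of the new early return is that a TMA (cp.async.bulk) write into shared memory already has its RAW hazard tracked by an mbarrier inserted by the circular-buffering pass, so the generic block-level sync is redundant. A distilled restatement as a named helper (the free function and its name are hypothetical; the member calls mirror the diff above):

```cpp
// Sketch only: condenses the skip condition added above into a named helper.
bool rawSyncAlreadyHandledByTma(
    const std::vector<Val*>& inputs,
    bool needs_grid_sync) { // corresponds to sync_bitmap.hasBID()
  if (needs_grid_sync) {
    return false; // only the intra-CTA (non-BID) sync may be skipped
  }
  // Skip the generic RAW sync only if every shared-memory input was written
  // by a TMA load, whose completion is already tracked by a raw mbarrier.
  return std::all_of(inputs.begin(), inputs.end(), [](Val* val) {
    return !val->isA<TensorView>() ||
        val->as<TensorView>()->getMemoryType() != MemoryType::Shared ||
        ir_utils::isCpAsyncBulkLoad(val->definition());
  });
}
```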

csrc/device_lower/pass/predicate.cpp
+30 −6

```diff
@@ -246,17 +246,41 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
           IrBuilder::create<UnaryOp>(
               UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
 
+      auto load_warp_loop_it =
+          std::find_if(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) {
+            return fl->circularBufferLoopStage() ==
+                CircularBufferLoopStage::LoadWarp;
+          });
+      ParallelType load_warp_on = ParallelType::Serial;
+      if (load_warp_loop_it != for_loops_.end()) {
+        load_warp_on = std::get<WarpSpecialized>(
+                           GpuLower::current()
+                               ->circularBufferInfo()
+                               .getCircularBufferOptionsFor(
+                                   (*load_warp_loop_it)->iter_domain())
+                               .type)
+                           .on;
+      }
+
+      // If we are in a load warp, then the warp-dispatching IfThenElse
+      // already selects on `load_warp_on`, so we should not generate
+      // predicates for it here.
       const auto& pdim_map = GpuLower::current()->parallelDimensionMap();
-      Val* first_warp = IrBuilder::ltExpr(
-          NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
+      Val* conditional = load_warp_on == ParallelType::TIDx
+          ? pred->fusion()->trueVal()
+          : SimplifyingIrBuilder::logicalAndExpr(
+                elect_sync_val,
+                IrBuilder::ltExpr(
+                    NamedScalar::getParallelIndex(ParallelType::TIDx),
+                    warp_size));
       for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
-        if (pdim_map.has(pt)) {
-          first_warp = SimplifyingIrBuilder::logicalAndExpr(
-              first_warp,
+        if (pdim_map.has(pt) && load_warp_on != pt) {
+          conditional = SimplifyingIrBuilder::logicalAndExpr(
+              conditional,
              IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
         }
       }
-      return SimplifyingIrBuilder::logicalAndExpr(first_warp, elect_sync_val);
+      return conditional;
     }
     default:
       break;
```
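The effect on the generated elect-sync condition can be modeled on the host. A sketch (function and parameter names are illustrative, not nvFuser identifiers; 32 stands in for the warp size):

```cpp
#include <cstdint>

// Models the predicate built above: when inside a LoadWarp loop, the
// warp-dispatching IfThenElse already pinned the specialized axis, so that
// axis is skipped; specializing on TIDx collapses the predicate entirely.
bool electSyncPredicate(
    uint32_t tidx,
    uint32_t tidy,
    uint32_t tidz,
    bool elect_sync,    // result of the ElectSync op within the warp
    bool has_tidy,      // TIDy mapped in the parallel dimension map?
    bool has_tidz,      // TIDz mapped in the parallel dimension map?
    int load_warp_on) { // 0: TIDx, 1: TIDy, 2: TIDz, -1: not in a load warp
  if (load_warp_on == 0) {
    return true;
  }
  bool cond = elect_sync && (tidx < 32);
  if (has_tidy && load_warp_on != 1) {
    cond = cond && (tidy == 0);
  }
  if (has_tidz && load_warp_on != 2) {
    cond = cond && (tidz == 0);
  }
  return cond;
}
```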

csrc/ir/interface_nodes.h
+81 −9

```diff
@@ -113,7 +113,7 @@ class TVDomainGuard;
 
 //
 // /load 0;\ \.
-// / load 1; [prefetch = 3] | [prologue]
+// / load 1; [prefetch = 3] | [prefetching]
 // [stage] load 2;/ /'
 // [ = 6 ] load 3; wait load 0; compute 0; \.
 // \ load 4; wait load 1; compute 1; |
@@ -123,7 +123,7 @@ class TVDomainGuard;
 // load 2; wait load 5; compute 5; wait compute 3; |
 // load 3; wait load 0; compute 0; wait compute 4; |
 // load 4; wait load 1; compute 1; wait compute 5; | [main]
-// load 5; wait load 2; compute 2; wait compute 0; | [loop]
+// load 5; wait load 2; compute 2; wait compute 0; |
 // .................................................. |
 // .................................................. |
 // .................................................. |
@@ -132,7 +132,7 @@ class TVDomainGuard;
 // load ; wait load ; compute ; wait compute ; |
 // load ; wait load ; compute ; /'
 // /wait load ; compute ; \.
-// [same number as prefetch] wait load ; compute ; | [epilogue]
+// [same number as prefetch] wait load ; compute ; | [draining]
 // \wait load ; compute ; wait all computes; /'
 
 // clang-format on
@@ -142,19 +142,37 @@ class TVDomainGuard;
 //   load pipeline depth = prefetch + 1
 //   compute pipeline depth = stage - prefetch
 //
-// The above timeline can be implemented as the following loop structure:
+// There are two ways to implement the above timeline: pipelined, and
+// warp-specialization.
+//
+// In the pipelined way, the prefetching stage is implemented as a prologue
+// loop, and main stage is implemented as a main loop, and the draining stage is
+// implemented as an epilogue loop. That is, we will have the following loop
+// structure:
 //
 // Prologue loop:
 //   for i in range(prefetch):
 //     load data[i] to buffer[i]
 //
-// Main loop:
+// Main loop (using syncthreads to avoid WAR harzard):
 //   for i in range(data.size - prefetch):
 //     load data[i + prefetch] to buffer[(i + prefetch) % stage]
-//     wait buffer[i % stage] to be ready
+//     wait buffer[i % stage] to be loaded
 //     compute buffer[i % stage]
 //     wait until the first compute in the queue is done
 //     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     __syncthreads();
+//
+// Main loop (using mbarrier to avoid WAR harzard):
+//   for i in range(data.size - prefetch):
+//     wait buffer[(i + prefetch) % stage] to be empty
+//     load data[i + prefetch] to buffer[(i + prefetch) % stage]
+//     wait buffer[i % stage] to be loaded
+//     compute buffer[i % stage]
+//     wait until the first compute in the queue is done
+//     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+//     loaded again
 //
 // Epilogue loop:
 //   for i in range(data.size - prefetch, data.size):
@@ -166,6 +184,30 @@ class TVDomainGuard;
 //   stage - prefetch - 1 iterations and last iteration of the main loop is
 //   redundant. We can remove them to further optimize the performance, but
 //   we decide to keep them for simplicity.
+//
+// In the warp-specialized approach, we will use different warp/warp-group
+// for loading and computing. We will generate code like below (assuming warp
+// specialized on TIDy):
+//
+// if (threadIdx.y == blockDim.y - 1) {
+//   // If we use warp specialization on TIDy, then the blockDim.y of the
+//   // kernel will be (whatever_value_inferred_from_schedule + 1), and the
+//   // last threadIdx.y will be used as load warp
+//   for i in range(data.size):
+//     wait buffer[i % stage] to be empty
+//     load data[i] to buffer[i % stage]
+// } else {
+//   // Every threadIdx.y other than the last will be used for compute
+//   for i in range(prefetch + 1):
+//     signal that buffer i % stage is empty and ready to load
+//   for i in range(data.size):
+//     wait buffer[i % stage] to be loaded
+//     compute buffer[i % stage]
+//     wait until the first compute in the queue is done
+//     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+//     loaded again
+// }
 
 struct Pipelined {
   bool uses_mbarrier_for_war = false;
@@ -184,7 +226,36 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
   return os << "Pipelined";
 }
 
-using CircularBufferType = std::variant<Pipelined>;
+struct WarpSpecialized {
+  ParallelType on;
+  explicit WarpSpecialized(ParallelType on) : on(on) {}
+  WarpSpecialized() = default;
+  bool operator==(const WarpSpecialized& other) const {
+    return on == other.on;
+  }
+};
+
+inline std::ostream& operator<<(
+    std::ostream& os,
+    const WarpSpecialized& warp_specialized) {
+  std::string parallel_type_str = "";
+  switch (warp_specialized.on) {
+    case ParallelType::TIDx:
+      parallel_type_str = "TIDx";
+      break;
+    case ParallelType::TIDy:
+      parallel_type_str = "TIDy";
+      break;
+    case ParallelType::TIDz:
+      parallel_type_str = "TIDz";
+      break;
+    default:
+      NVF_THROW("Invalid parallel type");
+  }
+  return os << "WarpSpecializedOn" << parallel_type_str;
+}
+
+using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
 
 inline std::ostream& operator<<(
     std::ostream& os,
@@ -207,8 +278,9 @@ struct CircularBufferOptions {
   }
 
   bool usesMBarrierForWAR() const {
-    return std::holds_alternative<Pipelined>(type) &&
-        std::get<Pipelined>(type).uses_mbarrier_for_war;
+    return (std::holds_alternative<Pipelined>(type) &&
+            std::get<Pipelined>(type).uses_mbarrier_for_war) ||
+        std::holds_alternative<WarpSpecialized>(type);
     return false;
   }
 
```
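With the new variant, a schedule selects warp specialization by storing a `WarpSpecialized` value in `CircularBufferOptions::type`, and `usesMBarrierForWAR()` then reports true unconditionally, since the load and compute warps can only coordinate through mbarriers. A small fragment (assumes this header and `<iostream>` are included; the `stage`/`prefetch` fields are taken from the surrounding struct and the values are illustrative):

```cpp
// Sketch only: exercises the types added in this diff.
CircularBufferOptions options;
options.type = WarpSpecialized(ParallelType::TIDy);
options.stage = 6;     // total circular-buffer stages (hypothetical)
options.prefetch = 3;  // prefetch distance (hypothetical)
NVF_ERROR(options.usesMBarrierForWAR()); // always true for WarpSpecialized
std::cout << options.type << std::endl;  // "WarpSpecializedOnTIDy"
```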

csrc/parallel_dimension_map.cpp
+27 −1

```diff
@@ -38,9 +38,17 @@ struct hash<PAndID> {
 namespace nvfuser {
 
 void ParallelDimensionMap::build(Fusion* fusion) {
+  VectorOfUniqueEntries<ParallelType> warp_specialized_types;
   VectorOfUniqueEntries<PAndID> all_concrete_ids;
   auto all_vals = fusion->usedMathVals();
   for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
+    if (tv->isCircularBuffered() &&
+        std::holds_alternative<WarpSpecialized>(
+            tv->circularBufferOptions().type)) {
+      const auto& warp_specialized =
+          std::get<WarpSpecialized>(tv->circularBufferOptions().type);
+      warp_specialized_types.pushBack(warp_specialized.on);
+    }
     for (auto id : tv->domain()->allIDs()) {
       auto ptype = id->getParallelType();
       if (!isParallelTypeThread(ptype)) {
@@ -83,6 +91,10 @@ void ParallelDimensionMap::build(Fusion* fusion) {
   }
 
   adjustMappingsForWarpPadding();
+
+  for (auto pt : warp_specialized_types) {
+    setWarpSpecializeOn(pt);
+  }
 }
 
 void ParallelDimensionMap::adjustMappingsForWarpPadding() {
@@ -137,6 +149,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
   exact_types_.erase(ParallelType::TIDx);
 }
 
+void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) {
+  auto dim_it = dim_map_.find(pt);
+  if (dim_it == dim_map_.end()) {
+    dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
+  } else {
+    dim_map_[pt] = SimplifyingIrBuilder::addExpr(dim_it->second, 1);
+  }
+  exact_types_.erase(pt);
+  warp_specialized_types_.insert(pt);
+}
+
 Val* ParallelDimensionMap::getRaw(ParallelType pt) const {
   NVF_ERROR(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
   auto it = dim_map_.find(pt);
@@ -159,13 +182,16 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
   return exact_types_.find(pt) != exact_types_.end();
 }
 
-Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
+Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
   Val* num_threads = FusionGuard::getCurFusion()->oneVal();
   for (auto pt : kParallelTypeTIDs) {
     auto dim = getRaw(pt);
     if (dim == nullptr) {
       continue;
     }
+    if (warp_specialized_types_.find(pt) != warp_specialized_types_.end()) {
+      dim = SimplifyingIrBuilder::addExpr(dim, -1);
+    }
     num_threads = SimplifyingIrBuilder::mulExpr(num_threads, dim);
   }
   return num_threads;
```
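The arithmetic above ties the feature together: the warp-specialized axis is grown by one slot (or set to 2 if it was not mapped at all) and marked inexact, and the extra slot is subtracted back out when counting compute threads, which is what the mbarrier arrival count in allocation.cpp now uses. A standalone example with hypothetical dimensions:

```cpp
#include <cstdio>

int main() {
  // Hypothetical schedule: compute region uses TIDx = 128, TIDy = 2, and the
  // circular buffer is warp-specialized on TIDy.
  constexpr int tidx = 128;
  constexpr int tidy_compute = 2;

  // setWarpSpecializeOn(TIDy): the launched dimension gains one extra
  // threadIdx.y slot that holds the load warp group.
  constexpr int block_dim_y = tidy_compute + 1;

  // getNumComputeThreadsEachBlock(): the load slot is subtracted back out, so
  // RAW mbarrier arrival counts cover only the compute threads.
  constexpr int compute_threads = tidx * (block_dim_y - 1);

  std::printf(
      "launch blockDim = (%d, %d, 1); compute threads per CTA = %d\n",
      tidx, block_dim_y, compute_threads);
  return 0;
}
```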

0 commit comments
