Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warp specialization as a circular buffering type #3511

Merged
merged 42 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b605a22
Add warp specialization as a circular buffering type
zasdfgbnm Dec 2, 2024
6a11cdd
revert
zasdfgbnm Dec 2, 2024
1436128
revert
zasdfgbnm Dec 2, 2024
b1a873e
fix
zasdfgbnm Dec 2, 2024
5b3399c
save
zasdfgbnm Dec 2, 2024
28a0931
str
zasdfgbnm Dec 2, 2024
08c2357
save
zasdfgbnm Dec 2, 2024
4eadab7
save
zasdfgbnm Dec 2, 2024
f12c433
save
zasdfgbnm Dec 2, 2024
b705372
fix
zasdfgbnm Dec 2, 2024
5dd18b2
save
zasdfgbnm Dec 2, 2024
0eec871
remove assert
zasdfgbnm Dec 2, 2024
ddae9d7
assert back
zasdfgbnm Dec 2, 2024
ae0122a
save
zasdfgbnm Dec 2, 2024
d30ec31
save
zasdfgbnm Dec 2, 2024
e2cb8d1
save
zasdfgbnm Dec 2, 2024
d8233b3
comment
zasdfgbnm Dec 2, 2024
26675e3
save
zasdfgbnm Dec 2, 2024
732072b
save
zasdfgbnm Dec 2, 2024
ae756cc
save
zasdfgbnm Dec 2, 2024
0e09645
save
zasdfgbnm Dec 2, 2024
93b00ee
outer reduction
zasdfgbnm Dec 3, 2024
7a20b02
save
zasdfgbnm Dec 3, 2024
05a3128
save
zasdfgbnm Dec 3, 2024
f839382
save
zasdfgbnm Dec 3, 2024
eaf1ff5
save
zasdfgbnm Dec 3, 2024
0cc1706
save
zasdfgbnm Dec 3, 2024
70075e9
save
zasdfgbnm Dec 3, 2024
32ea65e
save
zasdfgbnm Dec 3, 2024
b34cf05
save
zasdfgbnm Dec 3, 2024
807122b
save
zasdfgbnm Dec 3, 2024
4990ef1
Merge branch 'main' into warp-specialization-submit
zasdfgbnm Dec 3, 2024
6ad545e
Merge branch 'main' into warp-specialization-submit
zasdfgbnm Dec 3, 2024
5a6ecad
Merge branch 'main' into warp-specialization-submit
zasdfgbnm Dec 3, 2024
d40df72
Merge branch 'main' into warp-specialization-submit
zasdfgbnm Dec 3, 2024
9f13ce4
save
zasdfgbnm Dec 3, 2024
258c31d
save
zasdfgbnm Dec 3, 2024
e2dfda7
comment
zasdfgbnm Dec 3, 2024
88f2e3f
doc
zasdfgbnm Dec 3, 2024
07e59de
error message
zasdfgbnm Dec 4, 2024
8eb0dc0
Warp specialization on x (#3525)
zasdfgbnm Dec 4, 2024
e33f8e2
Update tests/cpp/test_circular_buffering.cpp
zasdfgbnm Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ Expr* initializeMbarrier(
// threads in the CTA.
num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
DataType::UInt32,
GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock());
GpuLower::current()
->parallelDimensionMap()
.getNumComputeThreadsEachBlock());
}

// Initialize mbarrier for each circular buffer stage. Use the thread
Expand Down
56 changes: 52 additions & 4 deletions csrc/device_lower/pass/circular_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
SimplifyingIrBuilder::create<Val>(opt.prefetch, DataType::Index));
break;
}
case CircularBufferLoopStage::LoadWarp:
case CircularBufferLoopStage::ComputeWarp: {
break;
}
default: {
NVF_THROW("Unsupported loop mode, got: ", loop_type_);
}
Expand Down Expand Up @@ -1246,11 +1250,22 @@ class CircularBufferInserter : private kir::ExprMutator {
return;
}

auto hasCpAsyncBulk = std::any_of(
auto has_cp_async_bulk = std::any_of(
it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);

if (hasCpAsyncBulk) {
insertTma(loop, it->second);
bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(loop->iter_domain())
.type);
if (use_warp_specialization) {
NVF_ERROR(
std::all_of(
it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
"In order to use warp specialization, all buffers must be loaded by TMA");
insertTmaWarpSpecialized(loop, it->second);
} else if (has_cp_async_bulk) {
insertTmaPipelined(loop, it->second);
} else {
insert(loop, it->second);
}
Expand Down Expand Up @@ -1315,7 +1330,7 @@ class CircularBufferInserter : private kir::ExprMutator {
.usesMBarrierForWAR();
}

void insertTma(
void insertTmaPipelined(
ForLoop* circular_buffer_loop,
const std::vector<Expr*>& loads) {
// Arrive on the WAR mbarriers to let the prefetching start.
Expand Down Expand Up @@ -1363,6 +1378,39 @@ class CircularBufferInserter : private kir::ExprMutator {
registerInsertAfter(circular_buffer_loop, epilogue_loop);
}

void insertTmaWarpSpecialized(
ForLoop* circular_buffer_loop,
const std::vector<Expr*>& loads) {
const auto& opt =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
circular_buffer_loop->iter_domain());
ParallelType warp_specialize_on = std::get<WarpSpecialized>(opt.type).on;

kir::IfThenElse* warp_dispatch_ite = IrBuilder::create<kir::IfThenElse>(
IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
NamedScalar::getParallelIndex(warp_specialize_on),
IrBuilder::subExpr(
GpuLower::current()->parallelDimensionMap().get(
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
warp_dispatch_ite->elseBody().push_back(prefetch_loop);

// Compute loop:
ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::ComputeWarp);
warp_dispatch_ite->elseBody().push_back(compute_loop);

registerReplace(circular_buffer_loop, warp_dispatch_ite);
}

void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
NVF_ERROR(
!usesMBarrierForWAR(circular_buffer_loop),
Expand Down
12 changes: 12 additions & 0 deletions csrc/device_lower/pass/insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,18 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
last_writes_.pop_front();
// Found that a sync is needed

if (!sync_bitmap.hasBID() &&
std::all_of(
expr->inputs().begin(), expr->inputs().end(), [](Val* val) {
return !val->isA<TensorView>() ||
val->as<TensorView>()->getMemoryType() !=
MemoryType::Shared ||
ir_utils::isCpAsyncBulkLoad(val->definition());
})) {
// RAW of TMA is handled separately, so skip it here.
return;
}

// TODO: Explicitly test the 3 cases below
Expr* sync_expr = nullptr;
kir::Allocate* maybe_alloc = nullptr;
Expand Down
36 changes: 30 additions & 6 deletions csrc/device_lower/pass/predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,17 +246,41 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
IrBuilder::create<UnaryOp>(
UnaryOpType::ElectSync, elect_sync_val, full_mask_val);

auto load_warp_loop_it =
std::find_if(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) {
return fl->circularBufferLoopStage() ==
CircularBufferLoopStage::LoadWarp;
});
ParallelType load_warp_on = ParallelType::Serial;
if (load_warp_loop_it != for_loops_.end()) {
load_warp_on = std::get<WarpSpecialized>(
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(
(*load_warp_loop_it)->iter_domain())
.type)
.on;
}

// If we are in a load warp, then the warp-dispatching IfThenElse
// already selects on `load_warp_on`, so we should not generate
// predicates for it here.
const auto& pdim_map = GpuLower::current()->parallelDimensionMap();
Val* first_warp = IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
Val* conditional = load_warp_on == ParallelType::TIDx
? pred->fusion()->trueVal()
: SimplifyingIrBuilder::logicalAndExpr(
elect_sync_val,
IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx),
warp_size));
for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (pdim_map.has(pt)) {
first_warp = SimplifyingIrBuilder::logicalAndExpr(
first_warp,
if (pdim_map.has(pt) && load_warp_on != pt) {
conditional = SimplifyingIrBuilder::logicalAndExpr(
conditional,
IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
}
}
return SimplifyingIrBuilder::logicalAndExpr(first_warp, elect_sync_val);
return conditional;
}
default:
break;
Expand Down
90 changes: 81 additions & 9 deletions csrc/ir/interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class TVDomainGuard;

//
// /load 0;\ \.
// / load 1; [prefetch = 3] | [prologue]
// / load 1; [prefetch = 3] | [prefetching]
// [stage] load 2;/ /'
// [ = 6 ] load 3; wait load 0; compute 0; \.
// \ load 4; wait load 1; compute 1; |
Expand All @@ -123,7 +123,7 @@ class TVDomainGuard;
// load 2; wait load 5; compute 5; wait compute 3; |
// load 3; wait load 0; compute 0; wait compute 4; |
// load 4; wait load 1; compute 1; wait compute 5; | [main]
// load 5; wait load 2; compute 2; wait compute 0; | [loop]
// load 5; wait load 2; compute 2; wait compute 0; |
// .................................................. |
// .................................................. |
// .................................................. |
Expand All @@ -132,7 +132,7 @@ class TVDomainGuard;
// load ; wait load ; compute ; wait compute ; |
// load ; wait load ; compute ; /'
// /wait load ; compute ; \.
// [same number as prefetch] wait load ; compute ; | [epilogue]
// [same number as prefetch] wait load ; compute ; | [draining]
// \wait load ; compute ; wait all computes; /'

// clang-format on
Expand All @@ -142,19 +142,37 @@ class TVDomainGuard;
// load pipeline depth = prefetch + 1
// compute pipeline depth = stage - prefetch
//
// The above timeline can be implemented as the following loop structure:
// There are two ways to implement the above timeline: pipelined, and
// warp-specialization.
//
// In the pipelined way, the prefetching stage is implemented as a prologue
// loop, and main stage is implemented as a main loop, and the draining stage is
// implemented as an epilogue loop. That is, we will have the following loop
// structure:
//
// Prologue loop:
// for i in range(prefetch):
// load data[i] to buffer[i]
//
// Main loop:
// Main loop (using syncthreads to avoid WAR harzard):
// for i in range(data.size - prefetch):
// load data[i + prefetch] to buffer[(i + prefetch) % stage]
// wait buffer[i % stage] to be ready
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// __syncthreads();
//
// Main loop (using mbarrier to avoid WAR harzard):
// for i in range(data.size - prefetch):
// wait buffer[(i + prefetch) % stage] to be empty
// load data[i + prefetch] to buffer[(i + prefetch) % stage]
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// signal that buffer (i + prefetch + 1) % stage is empty and ready to be
// loaded again
//
// Epilogue loop:
// for i in range(data.size - prefetch, data.size):
Expand All @@ -166,6 +184,30 @@ class TVDomainGuard;
// stage - prefetch - 1 iterations and last iteration of the main loop is
// redundant. We can remove them to further optimize the performance, but
// we decide to keep them for simplicity.
//
// In the warp-specialized approach, we will use different warp/warp-group
// for loading and computing. We will generate code like below (assuming warp
// specialized on TIDy):
//
// if (threadIdx.y == blockDim.y - 1) {
// // If we use warp specialization on TIDy, then the blockDim.y of the
// // kernel will be (whatever_value_inferred_from_schedule + 1), and the
// // last threadIdx.y will be used as load warp
// for i in range(data.size):
// wait buffer[i % stage] to be empty
// load data[i] to buffer[i % stage]
// } else {
// // Every threadIdx.y other than the last will be used for compute
// for i in range(prefetch + 1):
// signal that buffer i % stage is empty and ready to load
// for i in range(data.size):
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// signal that buffer (i + prefetch + 1) % stage is empty and ready to be
// loaded again
// }

struct Pipelined {
bool uses_mbarrier_for_war = false;
Expand All @@ -184,7 +226,36 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
return os << "Pipelined";
}

using CircularBufferType = std::variant<Pipelined>;
struct WarpSpecialized {
ParallelType on;
explicit WarpSpecialized(ParallelType on) : on(on) {}
WarpSpecialized() = default;
bool operator==(const WarpSpecialized& other) const {
return on == other.on;
}
};

inline std::ostream& operator<<(
std::ostream& os,
const WarpSpecialized& warp_specialized) {
std::string parallel_type_str = "";
switch (warp_specialized.on) {
case ParallelType::TIDx:
parallel_type_str = "TIDx";
break;
case ParallelType::TIDy:
parallel_type_str = "TIDy";
break;
case ParallelType::TIDz:
parallel_type_str = "TIDz";
break;
default:
NVF_THROW("Invalid parallel type");
}
return os << "WarpSpecializedOn" << parallel_type_str;
}

using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;

inline std::ostream& operator<<(
std::ostream& os,
Expand All @@ -207,8 +278,9 @@ struct CircularBufferOptions {
}

bool usesMBarrierForWAR() const {
return std::holds_alternative<Pipelined>(type) &&
std::get<Pipelined>(type).uses_mbarrier_for_war;
return (std::holds_alternative<Pipelined>(type) &&
std::get<Pipelined>(type).uses_mbarrier_for_war) ||
std::holds_alternative<WarpSpecialized>(type);
return false;
}

Expand Down
28 changes: 27 additions & 1 deletion csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ struct hash<PAndID> {
namespace nvfuser {

void ParallelDimensionMap::build(Fusion* fusion) {
VectorOfUniqueEntries<ParallelType> warp_specialized_types;
VectorOfUniqueEntries<PAndID> all_concrete_ids;
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
if (tv->isCircularBuffered() &&
std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type)) {
const auto& warp_specialized =
std::get<WarpSpecialized>(tv->circularBufferOptions().type);
warp_specialized_types.pushBack(warp_specialized.on);
}
for (auto id : tv->domain()->allIDs()) {
auto ptype = id->getParallelType();
if (!isParallelTypeThread(ptype)) {
Expand Down Expand Up @@ -83,6 +91,10 @@ void ParallelDimensionMap::build(Fusion* fusion) {
}

adjustMappingsForWarpPadding();

for (auto pt : warp_specialized_types) {
setWarpSpecializeOn(pt);
}
}

void ParallelDimensionMap::adjustMappingsForWarpPadding() {
Expand Down Expand Up @@ -137,6 +149,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
exact_types_.erase(ParallelType::TIDx);
}

void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) {
auto dim_it = dim_map_.find(pt);
if (dim_it == dim_map_.end()) {
dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
} else {
dim_map_[pt] = SimplifyingIrBuilder::addExpr(dim_it->second, 1);
}
exact_types_.erase(pt);
warp_specialized_types_.insert(pt);
}

Val* ParallelDimensionMap::getRaw(ParallelType pt) const {
NVF_ERROR(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
auto it = dim_map_.find(pt);
Expand All @@ -159,13 +182,16 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
return exact_types_.find(pt) != exact_types_.end();
}

Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
Val* num_threads = FusionGuard::getCurFusion()->oneVal();
for (auto pt : kParallelTypeTIDs) {
auto dim = getRaw(pt);
if (dim == nullptr) {
continue;
}
if (warp_specialized_types_.find(pt) != warp_specialized_types_.end()) {
dim = SimplifyingIrBuilder::addExpr(dim, -1);
}
num_threads = SimplifyingIrBuilder::mulExpr(num_threads, dim);
}
return num_threads;
Expand Down
Loading
Loading