Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ struct TransposeViewPropagator : public MaxInfoSpanningTree::Propagator {
bool should_reject = false;
};

bool hasSmallTransposeDimensions(
const std::unique_ptr<TransposeParams>& params) {
bool hasSmallTransposeDimensions(const TransposeParams* params) {
return !params->split_before_tiling.empty() ||
!params->dims_merged_with_1.empty() ||
!params->dims_merged_with_2.empty();
Expand Down Expand Up @@ -579,7 +578,7 @@ std::string getTransposeRuntimeRejectReason(
// 1. view op; and
// 2. small transpose transformation
// See note [Supporting small transpose dimensions]
if (hasSmallTransposeDimensions(params)) {
if (hasSmallTransposeDimensions(params.get())) {
return "Small transpose dimensions and view op cannot be currently be "
"handled by transpose scheduler. See: "
"https://github.com/NVIDIA/Fuser/pull/592";
Expand Down Expand Up @@ -677,7 +676,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
inner_most_pos2_in_ref1);

NVF_ERROR(
!hasSmallTransposeDimensions(tparams) ||
!hasSmallTransposeDimensions(tparams.get()) ||
scheduler_utils::getViewTVs(fusion).empty(),
"combination of view op with small transpose dimensions are not "
"supported by transpose scheduler");
Expand Down Expand Up @@ -722,6 +721,39 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
scan_max_dtype_size(fusion->inputs());
scan_max_dtype_size(fusion->outputs());

// Double tile_size2 if the default configuration doesn't keep enough data in
// flight to saturate memory bandwidth. This is based on Little's law:
// data_in_flight = bandwidth * latency. We estimate the bits in flight per SM
// as: (sum of input tensor element sizes in bits) * elements_per_tile *
// blocks_per_sm. If this is less than the required bits in flight (derived
// from hardware bandwidth and memory latency), we double tile_size2 to
// increase the data in flight. tile_size2 is doubled rather than tile_size1
// because doubling tile_size1 would also double the shared memory bank
// conflicts: e.g., assuming a vectorization factor of 4, increasing it from
// 32 to 64 raises the number of threads loading per column from 8 to 16,
// i.e. from 8-way to 16-way bank conflicts.
const auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t max_blocks_per_sm = dev_prop->maxThreadsPerMultiProcessor /
TransposeParams::getMaxThreadsPerBlock();
const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2;
const int64_t required_bits_per_sm =
scheduler_utils::getRequiredBitsInFlight();
int64_t total_input_bits_per_elem = 0;
for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
total_input_bits_per_elem +=
dataTypeSizeBit(tv->getDataType().value(), index_type);
}
const int64_t bits_in_flight_per_sm =
total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm;
std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem
<< std::endl;
std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl;
std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl;
std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl;
std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl;
if (bits_in_flight_per_sm < required_bits_per_sm) {
tparams->tile_size2 *= 2;
}

auto max_unroll_factor = ceilDiv(
// Available unrolling based on size of data type
kSixteen / max_io_dtype_size,
Expand Down Expand Up @@ -834,7 +866,7 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
<< "reference2: " << reference2->toString() << "\n"
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1
<< " (in reference 1)" << std::endl;
if (hasSmallTransposeDimensions(tparams)) {
if (hasSmallTransposeDimensions(tparams.get())) {
debug() << "small transposed dim, needs virtual inner-most dim"
<< std::endl;
}
Expand Down Expand Up @@ -912,26 +944,39 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
*/
std::unordered_set<TensorView*> group2_and_cached_inputs(
grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end());
std::vector<TensorView*> smem_cached_input_tvs;
for (auto tv : grouped_inputs_outputs[1]) {
if (tv->isFusionInput()) {
auto existing_cache = ir_utils::consumerTvsOf(tv)[0];
if (ir_utils::consumerTvsOf(existing_cache).size() > 1) {
auto new_cache = tv->cacheAfter();
new_cache->setMemoryType(MemoryType::Shared);
group2_and_cached_inputs.emplace(new_cache);
smem_cached_input_tvs.push_back(new_cache);
} else {
existing_cache->setMemoryType(MemoryType::Shared);
group2_and_cached_inputs.emplace(existing_cache);
smem_cached_input_tvs.push_back(existing_cache);
}
}
}

bool use_smem_swizzle = !hasSmallTransposeDimensions(tparams);
// set cached outputs of group 2 to shared memory
for (const auto& [cached_output, output_idx] : cached_outputs) {
auto output = fusion->outputs()[output_idx]->as<TensorView>();
if (group2_and_cached_inputs.count(output) > 0) {
cached_output->setMemoryType(MemoryType::Shared);
// current smem swizzle only works for cached input
use_smem_swizzle = false;
}
}
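// For a non-square tile, smem swizzle chunks can't be created when
// tile_size2 is larger than tile_size1 and group 2 is not vectorized
// (vectorize_factor2 == 1), so disable the swizzle in that case.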
if (tparams->tile_size2 > tparams->tile_size1 &&
tparams->vectorize_factor2 == 1) {
use_smem_swizzle = false;
}

TensorView* reference1 =
domain_map.findReferenceFor(grouped_inputs_outputs[0]);
Expand Down Expand Up @@ -1133,9 +1178,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
// inputs and outputs themselves are disconnected, so we have to borrow the
// entire DAG and use its spanning tree.
{
auto all_tvs_except1 = ir_utils::allTvsExcept(
fusion,
{grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()});
std::unordered_set<TensorView*> except_tvs;
except_tvs.insert(
grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end());
if (use_smem_swizzle) {
except_tvs.insert(
smem_cached_input_tvs.begin(), smem_cached_input_tvs.end());
}
auto all_tvs_except1 = ir_utils::allTvsExcept(fusion, except_tvs);
SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()});
MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector);
TransformPropagator propagator(reference2);
Expand Down Expand Up @@ -1226,8 +1276,14 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
// Propagate transformations, parallelization of the reference1 to the entire
// DAG except group 2 and its corresponding cached outputs.
{
auto all_tvs_except2 =
ir_utils::allTvsExcept(fusion, group2_and_cached_inputs);
std::unordered_set<TensorView*> except_tvs;
except_tvs.insert(
group2_and_cached_inputs.begin(), group2_and_cached_inputs.end());
if (use_smem_swizzle) {
except_tvs.insert(
smem_cached_input_tvs.begin(), smem_cached_input_tvs.end());
}
auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, except_tvs);
SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()});
MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs(
reference1, &selector);
Expand Down Expand Up @@ -1292,6 +1348,30 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
}
}

// schedule smem_cached_input_tvs
if (use_smem_swizzle) {
for (auto tv : smem_cached_input_tvs) {
std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl;
int64_t pos = tv->nDims() - 2;
bool is_group2 = group2_and_cached_inputs.count(tv) > 0;
int64_t tile2_factor =
is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1;
int64_t tile1_factor =
tparams->tile_size1 * tile2_factor / tparams->tile_size2;
// [BIDx, UnSwitch, tile1, tile2]
tv->split(pos + 1, tile2_factor);
tv->split(pos, tile1_factor);
tv->swizzle(SwizzleType::XOR, pos, pos + 2);
tv->merge(pos);
tv->merge(pos);
tv->split(pos, tparams->getThreadsPerBlock());
tv->axis(pos)->parallelize(ParallelType::Unroll);
tv->axis(pos + 1)->parallelize(ParallelType::TIDx);
tv->axis(pos + 2)->parallelize(ParallelType::Vectorize);
std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl;
}
}

////////////////////////////////
// Step 5: Cleanup and inline //
////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/test_rng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ TEST_F(RNGTest, BroadcastingRNGSmem) {

auto outputs =
scheduleAndRun(
fusion, SchedulerType::Transpose, {input0, input1}, false)
fusion, SchedulerType::PointWise, {input0, input1}, false)
.outputs;
auto out = outputs[0].as<at::Tensor>();

Expand Down