Merge remote-tracking branch 'origin/main' into translate_repeat_pattern
naoyam committed Dec 25, 2024
2 parents 4c1abda + ee63c98 commit e5fcf14
Showing 6 changed files with 106 additions and 73 deletions.
66 changes: 43 additions & 23 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -508,8 +508,12 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

std::vector<TensorView*> tvs_to_schedule{d, d_smem};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {

bool dc_in_mma_results =
std::find(mma_results_.begin(), mma_results_.end(), dc) !=
mma_results_.end();

if (!dc_in_mma_results) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
@@ -519,14 +523,13 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
// TODO: extend support when mma is not cast to half
NVF_CHECK(
dataTypeSize(dc->dtype()) == 2,
"We support use_smem_epilogue on Hopper only when the output is 16-bit");
auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;

d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
if (store_with_stmatrix) {
// Set LoadStoreOp
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
}
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

@@ -539,23 +542,40 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
// Should not propagate if dc is an mma output, as the mma output has
// already been scheduled.
if (!dc_in_mma_results) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
// [M, N] -> [128(TIDx), N/8 , m(2) , n(2)]
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d_smem->getLoopDomain());
if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
}
d_smem->setLoopDomain(s.as<IterDomain*>());

if (store_with_stmatrix) {
// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
}

d_smem->axis(-1)->parallelize(ParallelType::Vectorize);

// Schedule global memory output; Output from TMA Store
mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
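In short, the epilogue now selects the shared-memory store op from the epilogue dtype instead of rejecting non-16-bit outputs, and it skips re-scheduling dc when dc is itself an mma result. A condensed sketch of the decision logic introduced above (names as in the hunks; surrounding setup such as cacheAfter and propagate_to is omitted, so this is an illustration rather than the literal method body):

// dc (register) -> d_smem (shared) -> d (global)
bool dc_in_mma_results =
    std::find(mma_results_.begin(), mma_results_.end(), dc) !=
    mma_results_.end();
bool store_with_stmatrix = dataTypeSize(dc->dtype()) == 2; // fp16/bf16 only

if (store_with_stmatrix) {
  // 16-bit epilogue: registers are written to shared memory with stmatrix.
  d_smem->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StMatrix);
} // otherwise d_smem keeps the default Set op (plain shared-memory store)
d->definition()->as<LoadStoreOp>()->setOpType(
    LoadStoreOpType::CpAsyncBulkTensorTile); // TMA store to global in both cases

if (!dc_in_mma_results) {
  // Only a dc that is not an mma result still needs the mma-output swizzle
  // and backward propagation; an mma result was already scheduled earlier.
  auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(dc->getLoopDomain());
  dc->setLoopDomain(s.as<IterDomain*>());
  dc->setAllocationDomain(s.as<IterDomain*>(), true);
}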
12 changes: 0 additions & 12 deletions csrc/scheduler/mma_utils.cpp
@@ -1315,17 +1315,6 @@ void scheduleStMatrixForMmaOutput(
dataTypeSize(tv->dtype()) == 2,
"we only support 16-bit types in stmatrix");

// [M, N] -> [128(TIDx), N/8 , 2 , 2]
auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle);
}

tv->setLoopDomain(s.as<IterDomain*>());

if (tile_m == 16 && tile_n == 16) {
// Let [M, N] be [64, 32]
// After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2]
@@ -1344,7 +1333,6 @@ void scheduleStMatrixForMmaOutput(
// [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)]
tv->merge(-2);
}
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}

MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
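The scheduling removed here did not disappear; it moved to the call sites (see hopper_multi_matmul.cpp above and the test updates below). A minimal caller-side sketch of the new contract, assuming tv is the 16-bit shared-memory epilogue tensor and stmatrix_tile_m / stmatrix_tile_n are the stmatrix tile sizes:

// The caller now owns the mma-output swizzle, the TMA-store swizzle, and the
// final vectorization; scheduleStMatrixForMmaOutput only tiles for stmatrix.
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv);
if (swizzle != MmaInputSmemSwizzle::None) {
  // Create the TMA-store allocation domain with swizzle.
  mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle);
}
tv->setLoopDomain(s.as<IterDomain*>());

mma_utils::scheduleStMatrixForMmaOutput(tv, swizzle, stmatrix_tile_m, stmatrix_tile_n);
tv->axis(-1)->parallelize(ParallelType::Vectorize);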
39 changes: 19 additions & 20 deletions csrc/scheduler/utils.h
@@ -108,12 +108,12 @@ inline int64_t safeDiv(const int64_t x, const int64_t y) {
// `to_update` to the positions in the split tensor. Splitting one dimension
// multiple times is supported, and if this is the case, then the order of
// `to_split` matters. All given dimensions are numbers before any split.
NVF_API void splitDims(
void splitDims(
TensorView* tv,
std::vector<std::pair<int64_t, int64_t>> to_split, // (dim, size)
std::vector<int64_t>& to_update);

NVF_API inline void splitDims(
inline void splitDims(
TensorView* tv,
std::vector<std::pair<int64_t, int64_t>> to_split) { // (dim, size)
std::vector<int64_t> unused;
@@ -126,7 +126,7 @@ NVF_API inline void splitDims(
// merge.
// NOTE: merging is done in the order of the entries in `to_merge`, assuming an
// order from inner to outer
NVF_API std::optional<int64_t> mergeDims(
std::optional<int64_t> mergeDims(
TensorView* tv,
std::vector<int64_t> to_merge,
std::vector<int64_t>& to_update);
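
As a usage illustration of the two helpers documented above (the tensor and sizes are hypothetical, not taken from this commit): splitting a dimension shifts every position to its right, which to_update tracks, and mergeDims reports where the merged dimension ends up.

// Hypothetical example: tv starts with loop domain [I0, I1, I2] and we want
// to keep track of I2 (position 2) across the transformations.
std::vector<int64_t> to_update{2};
scheduler_utils::splitDims(tv, /*to_split=*/{{0, 4}}, to_update);
// Dimension 0 was split into two, so positions to its right shift by one:
// to_update is now {3}.

// Merge I1 and I2, listed from inner to outer as the note above requires.
std::optional<int64_t> merged_pos =
    scheduler_utils::mergeDims(tv, /*to_merge=*/{3, 2}, to_update);
// On success, *merged_pos is the position of the merged IterDomain.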
@@ -153,7 +153,7 @@ int64_t mergeNonReduction(TensorView* tv);
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
NVF_API void parallelizeAllLike(
void parallelizeAllLike(
TensorView* reference_tv,
int64_t pos = -1,
std::vector<TensorView*> selected_tvs = {},
@@ -237,7 +237,7 @@ struct PersistentBufferInfo {
// return inputs as being marked persistent if they follow this pattern. It is
// important to note however inputs don't strictly have to be persistent as they
// can simply be read multiple times from GMEM in the same kernel.
NVF_API PersistentBufferInfo persistentBuffers(Fusion* fusion);
PersistentBufferInfo persistentBuffers(Fusion* fusion);

// A persistent tv can be projected to its producers when all the producers are
// persistent tvs and there is no reduction op.
@@ -304,7 +304,7 @@ struct PersistentBufferSizeReturn {
// persistently, only based on buffers that must be persistent, and based on the
// maximum of all minimum size requirement. i.e. if must be persistent, only
// hold persistent dimension.
NVF_API PersistentBufferSizeReturn persistentBufferSize(
PersistentBufferSizeReturn persistentBufferSize(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
const PersistentBufferInfo& persistent_buffers,
@@ -321,7 +321,7 @@ std::pair<bool, bool> canonicalDimReduction(
// Return a list of tensor views that are outputs of reduction operations,
// excluding resharding reduce expressions. If multiple outputs of an expression
// are found, only include one in the list
NVF_API std::vector<TensorView*> getReductionTvs(Fusion* fusion);
std::vector<TensorView*> getReductionTvs(Fusion* fusion);

// Returns a list of TensorViews that are the consumer tv for a view operation.
std::vector<TensorView*> getViewTVs(Fusion* fusion);
@@ -330,15 +330,15 @@ std::vector<TensorView*> getViewTVs(Fusion* fusion);
std::vector<TensorView*> getTVsWithNonReductionRFactor(Fusion* fusion);

// Reset inputs and outputs to global memory, everything else to local.
NVF_API void clearMemorySpace(Fusion* fusion);
void clearMemorySpace(Fusion* fusion);

// Returns cached after tensors of the fusion inputs if unrolled. Otherwise
// return empty vector.
NVF_API std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);

// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
NVF_API std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
Fusion* fusion,
bool unroll);

@@ -473,7 +473,7 @@ struct BroadcastMultipleInformation {
//
// logical_reorder_map is provided to assume reference_tv will be reordered per
// the map
NVF_API BroadcastMultipleInformation getBroadcastMultiples(
BroadcastMultipleInformation getBroadcastMultiples(
TensorView* reference_tv,
DataType index_type,
const std::unordered_map<int64_t, int64_t>& logical_reorder_map = {});
@@ -542,7 +542,7 @@ struct BoundedDirectionalTransformPropagator {
//! Replay transforms from tensorview `from`
//! to the tensorviews that are consumers
//! of boundary tensorviews in `to` and producers of `from`.
NVF_API static void backward(
static void backward(
TensorView* from,
int64_t pos,
std::vector<TensorView*> to,
@@ -601,22 +601,21 @@ struct BoundedDirectionalTransformPropagator {
// If IterDomains are disjoint in the returned set, then they are considered
// "separable".
// Warning: This pass generates the IdGraphs, not intended for use at runtime.
NVF_API DisjointSets<IterDomain*> disjointLogicalSets(Fusion* fusion);
DisjointSets<IterDomain*> disjointLogicalSets(Fusion* fusion);

// Makes sure that there are no group ids left of pos that match right of pos.
// e.g.
// [1, 0, 0] pos 2 would return false
// [1, 0, 0] pos 1 would return true
NVF_API bool breakIsDisjoint(std::vector<int64_t> group_ids, int64_t pos);
bool breakIsDisjoint(std::vector<int64_t> group_ids, int64_t pos);

// Generates an old to new map to reorder tv's domain as the logical order.
// Priority is given to innermost dimensions, for example:
// logical [i0, i1, i2]
// domain [i0*i2, i1]
// will produce the map {{0, 1}, {1, 0}}
// This is somewhat similar to orderTiledConcreteIdAsRoot
NVF_API std::unordered_map<int64_t, int64_t> domainReorderAsLogicalMap(
TensorView* tv);
std::unordered_map<int64_t, int64_t> domainReorderAsLogicalMap(TensorView* tv);

// Generates an old to new map to reorder tv's domain as the logical order.
// This only handles the simple case where allocation is a permutation of
@@ -629,7 +628,7 @@ std::unordered_map<int64_t, int64_t> maybeLogicalReorderAsAllocationMap(
void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map);

//! Check if tv is an output of a fastest-dim reduction
NVF_API bool isFastestDimReduction(TensorView* tv);
bool isFastestDimReduction(TensorView* tv);

// A wrapper for Fusion::rotateLoop that provides a more consistent interface
inline void rotateLoop(
@@ -670,21 +669,21 @@ inline void rotateLoop(
//! tv1, but the data dependency for the resize op is still satisfied
//! by having a copy of tv1, i.e., tv4. Note that the other op using
//! tv1 still uses tv1.
NVF_API void prepareForMemoryTypePromotion(Fusion* fusion);
void prepareForMemoryTypePromotion(Fusion* fusion);

//! If a consumer tensor induces a data dependency between threads,
//! move its producer to a shared memory that is sufficient to satisfy
//! the dependency. For example, if the domain is parallelized
//! with blockIdx, the producer memory type will be changed to
//! Global. A proper RAW sync will be automatically inserted when the
//! fusion is lowered.
NVF_API void promoteProducerMemoryTypes(
void promoteProducerMemoryTypes(
Fusion* fusion,
const std::vector<TensorView*>& input_caches);

//! Get all tensors that are connected to from_tvs without going through
//! any tvs in the cutoff_tv_set.
NVF_API std::unordered_set<TensorView*> getAllTvsFrom(
std::unordered_set<TensorView*> getAllTvsFrom(
const std::vector<TensorView*>& from_tvs,
const std::unordered_set<TensorView*>& cutoff_tv_set);

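The only change in this header is dropping NVF_API from these declarations. NVF_API is presumably nvFuser's symbol-visibility/export macro, so removing it keeps these scheduler utilities internal to the library rather than part of its exported interface; a typical shape for such a macro (illustrative only, not the project's actual definition) is:

// Illustrative only: a common pattern for an export/visibility macro.
#if defined(_WIN32)
#define NVF_API __declspec(dllexport) // dllimport on the consuming side
#else
#define NVF_API __attribute__((visibility("default")))
#endif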
26 changes: 18 additions & 8 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -2660,7 +2660,7 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) {
testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
TEST_F(MatmulSchedulerTest, PreBroadcastMmaBiasNeg) {
// TODO: fix up params or switch to FusionExecutorCache when ready, then
// enable Ampere
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
@@ -2671,12 +2671,20 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
// A - tv0, B - tv1
auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::Half);
auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::Half);
TensorView* tv2 = makeContigConcreteTensor({-1}, DataType::Half);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
// We add these computations to test
// scheduling (with epilogue) when the output of mma is not
// cast to half.
auto tv4 = maybeCastOp(DataType::Float, tv2);
auto tv5 = biasEpilogue(tv3, tv4);
auto tv6 = neg(tv5);

fusion->addOutput(tv2);
fusion->addOutput(tv6);

NVF_CHECK(
1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
@@ -2689,10 +2697,14 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a = at::randn({M, K}, options);
auto b = at::randn({N, K}, options);
auto c = at::randn({M}, options);
auto t0 = a.unsqueeze(1);
auto t1 = b.unsqueeze(0);
auto tref = at::matmul(a.to(at::kFloat), b.to(at::kFloat).t());
std::vector<c10::IValue> inputs{t0, t1};
auto tref =
atBiasEpilogue(
at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()), c.to(at::kFloat))
.neg_();
std::vector<c10::IValue> inputs{t0, t1, c};

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};
@@ -2705,9 +2717,7 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
// TODO: Currently we use stmatrix whenever this is true. We cannot do that
// when the dtype is not 16 bits.
mparams.use_smem_epilogue = false;
mparams.use_smem_epilogue = true;
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
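For reference, the renamed test no longer validates a plain GEMM. Assuming biasEpilogue / atBiasEpilogue broadcast the length-M bias vector across the N columns (consistent with the shapes above), the expected output is

D = -\left( A B^{\top} + c \, \mathbf{1}_N^{\top} \right), \qquad A \in \mathbb{R}^{M \times K},\; B \in \mathbb{R}^{N \times K},\; c \in \mathbb{R}^{M},

with the product accumulated in fp32 and use_smem_epilogue = true exercising the shared-memory epilogue path on Hopper when the mma output is not cast back to half precision.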
12 changes: 8 additions & 4 deletions tests/cpp/test_memory.cpp
@@ -2860,14 +2860,18 @@ TEST_P(StMatrixTest, Regular) {
tv0->split(0, 32);
tv0->axis(1)->parallelize(ParallelType::TIDx);

auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv1->getLoopDomain());
tv1->setLoopDomain(s.as<IterDomain*>());
tv1->setAllocationDomain(s.as<IterDomain*>(), true);
for (auto tv : {tv1, tv2}) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
tv1->setAllocationDomain(tv1->getLoopDomain(), true);

mma_utils::scheduleStMatrixForMmaOutput(
tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n);

tv2->axis(-1)->parallelize(ParallelType::Vectorize);

tv3->merge(0);
tv3->split(0, 32);
tv3->axis(1)->parallelize(ParallelType::TIDx);
24 changes: 18 additions & 6 deletions tests/cpp/test_mma.cpp
@@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared);
EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global);

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
@@ -531,8 +525,26 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3);
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle);
}

tv3->setLoopDomain(s.as<IterDomain*>());
}
mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n);
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle);

