Test horizontal matmul fusion in Llama2FFN test (#3610)
This removes some barriers to horizontal matmul fusion and updates the
Llama2FFN test, which is currently Ampere-only.

Note that most of the horizontal fusion code hasn't been exercised much,
so we might continue hitting small snags as we start using it more. My
intention with this PR is to start exercising it automatically by
extending the test. Likewise, we will need changes to the canSchedule
checks and default heuristics to ensure sane behavior when doing
horizontal fusion, so there will likely be more PRs of this flavor soon.
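
For context, horizontal fusion remains opt-in. A minimal sketch of enabling it inside a test, using the same EnableOptionsGuard mechanism as the updated test below (surrounding fusion setup omitted):

    // Minimal sketch: opt in to matmul scheduling and to horizontal fusion of
    // multiple matmul patterns; the guard restores the previous options when
    // it goes out of scope.
    EnableOptionsGuard options_guard;
    EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul);
    EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);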

---------

Co-authored-by: Ryan Spring <[email protected]>
jacobhinkle and rdspring1 authored Dec 19, 2024
1 parent 6fc977a commit 962f002
Showing 4 changed files with 64 additions and 31 deletions.
csrc/scheduler/ampere_multi_matmul.cpp (3 changes: 0 additions & 3 deletions)

@@ -992,9 +992,6 @@ void AmpereMultipleMatmulScheduler::schedulePrologues() {
     std::vector<TensorView*>& mma_inputs,
     MmaOperand operand_type) {
   NVF_ERROR(smem_stores.size() == smem_loads.size());
-  // TODO: we should not assume that each operand is used in only a single
-  // mma op
-  NVF_ERROR(mma_results_.size() >= smem_loads.size());
   // We will save abs_ and bbs_ here for later use
   // TODO: save all register prologue tensors instead to a new vector called
   // prologue_register_tensors_
csrc/scheduler/matmul_utils.cpp (40 changes: 29 additions & 11 deletions)

@@ -98,7 +98,8 @@ void limitCircularBufferingSmemOperands(
 inline bool initCoreHeuristics(
     MatmulParams* mparams,
     const ProblemShape& problem_shape,
-    const mma_utils::TensorRolesMap& tensor_roles) {
+    const mma_utils::TensorRolesMap& tensor_roles,
+    const size_t num_problems) {
   const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
   GemmTile warp_tile = {-1, -1, -1};
   GemmTile cta_tile = {-1, -1, -1};
@@ -113,7 +114,7 @@ inline bool initCoreHeuristics(
   // - start with [4, 4, 2] shape, later it should depend on problem
   //   shape and have bigger impact on CTA tile shape

-  const DimType m_ratio = 4;
+  const DimType m_ratio = 4 / (DimType)num_problems;
   const DimType n_ratio = 4;
   const DimType k_ratio = 2;
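
To see what the m_ratio change above does, consider two horizontally fused matmuls. A worked sketch of the arithmetic, under the assumption that the warp tile scales with these ratios:

    // Hedged example, not the actual heuristic code: with two fused problems
    // the default M ratio halves, so each problem gets half the M extent and
    // the combined tile footprint stays close to the single-matmul case.
    const int64_t num_problems = 2;
    const int64_t m_ratio = 4 / num_problems; // 2 instead of the default 4
    // e.g. a 16-row MMA instruction would then give a 32-row warp tile per
    // problem rather than 64 rows.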
@@ -264,10 +265,11 @@ std::string isMatmulFusionDefinitionSupported(
        {MatmulTensorRole::OPERAND_A, MatmulTensorRole::OPERAND_B}) {
     auto entry = tensor_roles.find(role);
     if (entry != tensor_roles.end()) {
-      if (1 == entry->second.size()) {
+      if (isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          1 == entry->second.size()) {
         tvs_with_roles.insert(entry->second.begin(), entry->second.end());
       } else {
-        return "There is other than one fusion input that can be MMA operand";
+        return "There is more than one fusion input that can be MMA operand (enable fuse_multiple_matmuls)";
       }
     } else {
       return "No candidate in fusion inputs for MMA operand";
@@ -370,10 +372,16 @@ class VectorizationCalculator {
   MatmulParams::SupportedVectorization compute() {
     const std::vector<int64_t> a_vecs =
         operandVectorizations(MatmulTensorRole::OPERAND_A);
-    NVF_ERROR(a_vecs.size() == 1, "Expected exactly one A operand");
+    NVF_ERROR(
+        isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+            a_vecs.size() == 1,
+        "Expected exactly one A operand");
     const std::vector<int64_t> b_vecs =
         operandVectorizations(MatmulTensorRole::OPERAND_B);
-    NVF_ERROR(b_vecs.size() == 1, "Expected exactly one B operand");
+    NVF_ERROR(
+        isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+            b_vecs.size() == 1,
+        "Expected exactly one B operand");
     return {a_vecs[0], b_vecs[0], epilogueVectorization()};
   }

@@ -703,8 +711,10 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics(
       mma_utils::findMatmulPatterns(fusion);
   NVF_ERROR(!patterns.empty(), "No matmul patterns were found");
   NVF_ERROR(
-      patterns.size() == 1,
-      "Only a single matmul pattern can currently be fused");
+      isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          patterns.size() == 1,
+      "Only a single matmul pattern can currently be fused ",
+      "unless the fuse_multiple_matmuls option is enabled");
   mma_utils::MatmulPattern& pattern = patterns.front();

   // IdModel is used to analyze problem shape & layout
@@ -750,14 +760,21 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics(
         problem_shape[(size_t)MatmulDimRole::Batch],
         inner_dims,
         tensor_roles);
+    // TODO: more sophisticated handling of multiple matmuls when using plugin
+    mparams->tile_sizes.cta_tile.m /= (int64_t)patterns.size();
   } else {
     TORCH_WARN_ONCE(
         "Scheduling a matmul without heuristic plugin. "
         "Specify plugin location like this: "
         "NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so");
     // Populate heuristic details
-    auto status =
-        initCoreHeuristics(mparams.get(), problem_shape, tensor_roles);
+    auto status = initCoreHeuristics(
+        mparams.get(),
+        problem_shape,
+        tensor_roles,
+        // TODO: this assumes all patterns will lie in the same main loop, which
+        // might be false
+        /*num_problems=*/patterns.size());
     NVF_ERROR(status, "Initialization of core part of heuristics failed.");
   }

@@ -857,7 +874,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
     }
   }

-  if (patterns.size() > 1) {
+  if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls) &&
+      patterns.size() > 1) {
     return "Only a single matmul pattern can currently be fused";
   }
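
The plugin path above applies the same per-problem shrinking directly to the CTA tile. A worked example, assuming the plugin picked a 128-row M tile (the starting size is an assumption, not something this diff fixes):

    // Hedged example of the cta_tile.m adjustment in the plugin path above.
    int64_t cta_m = 128;            // M tile chosen by the heuristic plugin
    const int64_t num_patterns = 2; // two horizontally fused matmul patterns
    cta_m /= num_patterns;          // each pattern now covers a 64-row slice of M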
csrc/scheduler/mma_utils.cpp (15 changes: 10 additions & 5 deletions)

@@ -20,6 +20,7 @@
 #include <scheduler/utils.h>
 #include <val_graph.h>
 #include <variant>
+#include "options.h"

 namespace nvfuser {

@@ -187,7 +188,8 @@ TensorView* getOperandTv(
   NVF_ERROR(it != tensor_roles.end(), "Could not find any tensors with role");
   const std::vector<TensorView*>& operands = it->second;
   NVF_ERROR(
-      operands.size() == 1,
+      isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          operands.size() == 1,
       "Exactly one operand is expected in each A and B role");
   return operands.front();
 }
@@ -1347,10 +1349,13 @@ void scheduleStMatrixForMmaOutput(
 MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
   const std::vector<MatmulPattern> patterns = findMatmulPatterns(fusion);
   if (patterns.size() != 1) {
-    std::stringstream ss;
-    ss << "Invalid number of MmaOp instances in fusion, expected 1, got "
-       << patterns.size();
-    return ss.str();
+    if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls)) {
+      std::stringstream ss;
+      ss << "Invalid number of MmaOp instances in fusion, expected 1, got "
+         << patterns.size();
+      return ss.str();
+    }
+    TORCH_WARN("TODO: Update getOperandInnerDims for multiple patterns");
   }
   const MatmulPattern& pattern = patterns[0];
   IdModel id_model(fusion);
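
Note the gap flagged by the new TODO above: with fuse_multiple_matmuls enabled and several patterns present, getOperandInnerDims only warns and then analyzes patterns[0]. A comment-level sketch of the implied assumption:

    // Assumed current behavior when multiple patterns are found (per the TODO):
    //   - TORCH_WARN fires, but no error string is returned;
    //   - inner dimensions are derived from patterns[0] only, which implicitly
    //     assumes every fused pattern shares the same operand layout.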
tests/cpp/test_matmul_scheduler.cpp (37 changes: 25 additions & 12 deletions)

@@ -2720,17 +2720,23 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
   NVF_CHECK(outputs[0].allclose(tref, 0.001, 0.001));
 }

-class MatmulFusionTest : public MatmulSchedulerTest,
-                         public ::testing::WithParamInterface<bool> {
+class MatmulFusionTest
+    : public MatmulSchedulerTest,
+      public ::testing::WithParamInterface<std::pair<bool, bool>> {
  protected:
   void SetUp() override {
     if (fusion_enabled) {
       EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul);
     }
+    if (horizontal_fusion_enabled) {
+      EnableOptionsGuard::getCurOptions().set(
+          EnableOption::FuseMultipleMatmuls);
+    }
   }

   EnableOptionsGuard eog_;
-  bool fusion_enabled = GetParam();
+  bool fusion_enabled = GetParam().first;
+  bool horizontal_fusion_enabled = GetParam().second;
 };

 // Test that we can segment a Fusion containing two matmuls
@@ -2788,21 +2794,28 @@ TEST_P(MatmulFusionTest, Llama2FFN) {
   const FusionKernelRuntime* runtime =
       executor_cache.getMostRecentKernelRuntime();

-  EXPECT_TRUE(runtime->isSegmented());
+  size_t expected_kernels =
+      fusion_enabled ? (horizontal_fusion_enabled ? 1 : 2) : 3;

-  if (fusion_enabled) {
-    EXPECT_EQ(runtime->fusionSegments()->groups().size(), 2);
-  } else {
-    EXPECT_EQ(runtime->fusionSegments()->groups().size(), 3);
-  }
+  EXPECT_EQ(runtime->fusionSegments()->groups().size(), expected_kernels);
 }

 INSTANTIATE_TEST_SUITE_P(
     ,
     MatmulFusionTest,
-    ::testing::Bool(),
-    [](const testing::TestParamInfo<bool>& info) {
-      return info.param ? "fuse" : "dontfuse";
+    ::testing::ValuesIn(std::vector<std::pair<bool, bool>>{
+        {false, false},
+        {true, false},
+        {true, true}}),
+    [](const testing::TestParamInfo<std::pair<bool, bool>>& info) {
+      bool fuse = info.param.first;
+      bool horiz_fuse = info.param.second;
+      if (horiz_fuse) {
+        NVF_ERROR(
+            fuse, "Horizontal fusion enabled but overall fusion disabled");
+      }
+      return fuse ? (horiz_fuse ? "fuse_horizontal" : "fuse_single")
+                  : "dontfuse";
     });

 // This test can be used to check that an external plugin has been loaded. It
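
Taken together, the new parameterization and the kernel-count check encode a small test matrix. A condensed view, derived from the expected_kernels expression and the name generator above:

    // Test matrix for MatmulFusionTest.Llama2FFN:
    //   param           generated name    expected kernel count
    //   {false, false}  dontfuse          3
    //   {true,  false}  fuse_single       2
    //   {true,  true}   fuse_horizontal   1

The {false, true} combination is intentionally excluded: the name generator asserts that horizontal fusion implies overall fusion being enabled.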
