
Commit

Merge branch 'main' into jjsjann123/rope_benchmark
jjsjann123 authored Dec 19, 2024
2 parents 0852707 + 962f002 commit 274968b
Showing 17 changed files with 417 additions and 168 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -586,6 +586,7 @@ list(APPEND JIT_TEST_SRCS
   ${NVFUSER_ROOT}/tests/cpp/test_tensor_factories.cpp
   ${NVFUSER_ROOT}/tests/cpp/test_unary.cpp
   ${NVFUSER_ROOT}/tests/cpp/test_utils.cpp
+  ${NVFUSER_ROOT}/tests/cpp/test_vectorization_analysis.cpp
 )
 
 if(BUILD_TEST)
1 change: 0 additions & 1 deletion benchmarks/python/core.py
@@ -153,7 +153,6 @@ def torchprofile_timer(self) -> float:
         # Clear the internal profiler object to avoid accumulating function events and then restart the profiler
         # See PR: https://github.com/pytorch/pytorch/pull/125510
         self.prof.profiler = None
-        self.prof.start()
 
         return self.current_time

6 changes: 2 additions & 4 deletions csrc/polymorphic_value.cpp
@@ -7,6 +7,7 @@
 // clang-format on
 #include <polymorphic_value.h>
 #include <type.h>
+#include <utils.h>
 
 #include <string>
@@ -44,10 +45,7 @@ namespace PolymorphicValue_functions {
 std::string toString(const PolymorphicValue& v) {
   std::stringstream ss;
   if (v.is<at::Tensor>()) {
-    const auto& t = v.as<at::Tensor>();
-    ss << "Tensor(sizes=" << t.sizes() << ", "
-       << "stride=" << t.strides() << ", dtype=" << t.dtype()
-       << ", device=" << t.device() << ", data_ptr=" << t.data_ptr() << ")";
+    ss << debug_str(v.as<at::Tensor>());
   } else if (v.is<std::monostate>()) {
     ss << "std::monostate";
   } else if (v.is<StructHandle>()) {
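Note: this file and executor_kernel_arg.cpp below now route tensor printing through a shared debug_str(const at::Tensor&) helper, presumably declared in the newly included utils.h. A minimal sketch of such a helper, reconstructed from the inline code deleted above (the real definition is authoritative and may print more fields):

#include <ATen/ATen.h>
#include <sstream>
#include <string>

// Sketch only: mirrors the fields the removed inline code printed.
std::string debug_str(const at::Tensor& t) {
  std::stringstream ss;
  ss << "Tensor(sizes=" << t.sizes() << ", stride=" << t.strides()
     << ", dtype=" << t.dtype() << ", device=" << t.device()
     << ", data_ptr=" << t.data_ptr() << ")";
  return ss.str();
}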
6 changes: 5 additions & 1 deletion csrc/runtime/executor_kernel_arg.cpp
@@ -99,7 +99,11 @@ void KernelArgumentHolder::erase(const PolymorphicValue* arg_to_delete) {
 std::string KernelArgumentHolder::toString() const {
   std::stringstream ss;
   for (const auto& arg : arguments_) {
-    ss << *arg << "\n";
+    if (arg->is<at::Tensor>()) {
+      ss << debug_str(arg->as<at::Tensor>()) << "\n";
+    } else {
+      ss << *arg << "\n";
+    }
   }
   return ss.str();
 }
2 changes: 1 addition & 1 deletion csrc/runtime/fusion_cache_utils.cpp
@@ -42,7 +42,7 @@ ArgumentManager::ArgumentManager(
 }
 
 const std::unordered_map<Val*, const PolymorphicValue*>& ArgumentManager::
-    getTensorMap() {
+    getTensorMap() const {
   return tensor_map_;
 }
 const PolymorphicValue* ArgumentManager::checkTensorMap(Val* v) {
13 changes: 12 additions & 1 deletion csrc/runtime/fusion_cache_utils.h
@@ -91,7 +91,7 @@ class ArgumentManager {
       const RuntimeWorkSpace& runtime_workspace,
       const std::vector<Val*>& fusion_inputs);
 
-  const std::unordered_map<Val*, const PolymorphicValue*>& getTensorMap();
+  const std::unordered_map<Val*, const PolymorphicValue*>& getTensorMap() const;
 
   const PolymorphicValue* checkTensorMap(Val* v);
 
@@ -104,6 +104,17 @@
       const T& group_runtime_outputs,
       const int64_t group_id);
 
+  std::string toString() const {
+    std::stringstream ss;
+    ss << "ArgumentManager {";
+    for (auto entry : tensor_map_) {
+      ss << " " << entry.first->toString() << " : "
+         << PolymorphicValue_functions::toString(*entry.second) << std::endl;
+    }
+    ss << "}" << std::endl;
+    return ss.str();
+  }
+
  private:
   KernelArgumentHolder& fusion_args_;
   // map from val to args
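A usage sketch for the new dump; the call site below is hypothetical (the helper name and use of std::cout are assumptions, not part of this commit):

#include <iostream>
#include <runtime/fusion_cache_utils.h>

// Hypothetical debugging hook: given an ArgumentManager tracking a
// segmented fusion's Val -> argument bindings, print the whole map.
void dumpArgs(const nvfuser::ArgumentManager& args) {
  std::cout << args.toString();
}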
3 changes: 0 additions & 3 deletions csrc/scheduler/ampere_multi_matmul.cpp
@@ -992,9 +992,6 @@ void AmpereMultipleMatmulScheduler::schedulePrologues() {
     std::vector<TensorView*>& mma_inputs,
     MmaOperand operand_type) {
   NVF_ERROR(smem_stores.size() == smem_loads.size());
-  // TODO: we should not assume that each operand is used in only a single
-  // mma op
-  NVF_ERROR(mma_results_.size() >= smem_loads.size());
   // We will save abs_ and bbs_ here for later use
   // TODO: save all register prologue tensors instead to a new vector called
   // prologue_register_tensors_
40 changes: 29 additions & 11 deletions csrc/scheduler/matmul_utils.cpp
@@ -98,7 +98,8 @@ void limitCircularBufferingSmemOperands(
 inline bool initCoreHeuristics(
     MatmulParams* mparams,
     const ProblemShape& problem_shape,
-    const mma_utils::TensorRolesMap& tensor_roles) {
+    const mma_utils::TensorRolesMap& tensor_roles,
+    const size_t num_problems) {
   const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
   GemmTile warp_tile = {-1, -1, -1};
   GemmTile cta_tile = {-1, -1, -1};
@@ -113,7 +114,7 @@ inline bool initCoreHeuristics(
   //  - start with [4, 4, 2] shape, later it should depend on problem
   //    shape and have bigger impact on CTA tile shape
 
-  const DimType m_ratio = 4;
+  const DimType m_ratio = 4 / (DimType)num_problems;
   const DimType n_ratio = 4;
   const DimType k_ratio = 2;
 
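A worked example of the new scaling, assuming the warp tile is the instruction tile multiplied by these ratios and that two matmul patterns share one main loop (num_problems = 2):

// For an Ampere 16x8x16 MMA instruction tile:
//   m_ratio = 4 / 2 = 2  ->  warp_tile.m = 16 * 2 = 32   (was 64)
//   n_ratio = 4          ->  warp_tile.n =  8 * 4 = 32
//   k_ratio = 2          ->  warp_tile.k = 16 * 2 = 32
// i.e. each fused matmul gets half the M extent, so the pair together
// occupies roughly the footprint a single matmul used before.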
@@ -264,10 +265,11 @@ std::string isMatmulFusionDefinitionSupported(
        {MatmulTensorRole::OPERAND_A, MatmulTensorRole::OPERAND_B}) {
     auto entry = tensor_roles.find(role);
     if (entry != tensor_roles.end()) {
-      if (1 == entry->second.size()) {
+      if (isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          1 == entry->second.size()) {
         tvs_with_roles.insert(entry->second.begin(), entry->second.end());
       } else {
-        return "There is other than one fusion input that can be MMA operand";
+        return "There is more than one fusion input that can be MMA operand (enable fuse_multiple_matmuls)";
       }
     } else {
       return "No candidate in fusion inputs for MMA operand";
@@ -370,10 +372,16 @@ class VectorizationCalculator {
   MatmulParams::SupportedVectorization compute() {
     const std::vector<int64_t> a_vecs =
         operandVectorizations(MatmulTensorRole::OPERAND_A);
-    NVF_ERROR(a_vecs.size() == 1, "Expected exactly one A operand");
+    NVF_ERROR(
+        isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+            a_vecs.size() == 1,
+        "Expected exactly one A operand");
     const std::vector<int64_t> b_vecs =
         operandVectorizations(MatmulTensorRole::OPERAND_B);
-    NVF_ERROR(b_vecs.size() == 1, "Expected exactly one B operand");
+    NVF_ERROR(
+        isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+            b_vecs.size() == 1,
+        "Expected exactly one B operand");
     return {a_vecs[0], b_vecs[0], epilogueVectorization()};
   }
 
@@ -703,8 +711,10 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics(
       mma_utils::findMatmulPatterns(fusion);
   NVF_ERROR(!patterns.empty(), "No matmul patterns were found");
   NVF_ERROR(
-      patterns.size() == 1,
-      "Only a single matmul pattern can currently be fused");
+      isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          patterns.size() == 1,
+      "Only a single matmul pattern can currently be fused ",
+      "unless the fuse_multiple_matmuls option is enabled");
   mma_utils::MatmulPattern& pattern = patterns.front();
 
   // IdModel is used to analyze problem shape & layout
@@ -750,14 +760,21 @@
         problem_shape[(size_t)MatmulDimRole::Batch],
         inner_dims,
         tensor_roles);
+    // TODO: more sophisticated handling of multiple matmuls when using plugin
+    mparams->tile_sizes.cta_tile.m /= (int64_t)patterns.size();
   } else {
     TORCH_WARN_ONCE(
         "Scheduling a matmul without heuristic plugin. "
         "Specify plugin location like this: "
         "NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so");
     // Populate heuristic details
-    auto status =
-        initCoreHeuristics(mparams.get(), problem_shape, tensor_roles);
+    auto status = initCoreHeuristics(
+        mparams.get(),
+        problem_shape,
+        tensor_roles,
+        // TODO: this assumes all patterns will lie in the same main loop, which
+        // might be false
+        /*num_problems=*/patterns.size());
     NVF_ERROR(status, "Initialization of core part of heuristics failed.");
   }

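Companion arithmetic for the plugin branch above, under the same two-pattern assumption (the tile numbers are hypothetical):

// If the heuristic plugin proposes a 128x128x32 CTA tile and two
// matmul patterns were found:
//   cta_tile.m = 128 / 2 = 64
// splitting the CTA's M extent evenly across the fused matmuls,
// mirroring the m_ratio scaling on the non-plugin path.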
@@ -857,7 +874,8 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
     }
   }
 
-  if (patterns.size() > 1) {
+  if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls) &&
+      patterns.size() > 1) {
     return "Only a single matmul pattern can currently be fused";
   }

15 changes: 10 additions & 5 deletions csrc/scheduler/mma_utils.cpp
@@ -20,6 +20,7 @@
 #include <scheduler/utils.h>
 #include <val_graph.h>
 #include <variant>
+#include "options.h"
 
 namespace nvfuser {
 
@@ -187,7 +188,8 @@ TensorView* getOperandTv(
   NVF_ERROR(it != tensor_roles.end(), "Could not find any tensors with role");
   const std::vector<TensorView*>& operands = it->second;
   NVF_ERROR(
-      operands.size() == 1,
+      isOptionEnabled(EnableOption::FuseMultipleMatmuls) ||
+          operands.size() == 1,
       "Exactly one operand is expected in each A and B role");
   return operands.front();
 }
@@ -1347,10 +1349,13 @@ void scheduleStMatrixForMmaOutput(
 MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
   const std::vector<MatmulPattern> patterns = findMatmulPatterns(fusion);
   if (patterns.size() != 1) {
-    std::stringstream ss;
-    ss << "Invalid number of MmaOp instances in fusion, expected 1, got "
-       << patterns.size();
-    return ss.str();
+    if (!isOptionEnabled(EnableOption::FuseMultipleMatmuls)) {
+      std::stringstream ss;
+      ss << "Invalid number of MmaOp instances in fusion, expected 1, got "
+         << patterns.size();
+      return ss.str();
+    }
+    TORCH_WARN("TODO: Update getOperandInnerDims for multiple patterns");
   }
   const MatmulPattern& pattern = patterns[0];
   IdModel id_model(fusion);
12 changes: 6 additions & 6 deletions csrc/scheduler/transpose.cpp
@@ -68,12 +68,6 @@ bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) {
     }
   }
 
-  if (!hasAtLeastTwoValidGroups(fusion)) {
-    scheduler_debug_utils::canScheduleRejectReason(
-        schedulerType(), "cannot find two mismatching inner most dimensions");
-    return false;
-  }
-
   if (ir_utils::hasAnyReductionOps(fusion)) {
     scheduler_debug_utils::canScheduleRejectReason(
         schedulerType(), "no support for reduction ops");
@@ -87,6 +81,12 @@
     return false;
   }
 
+  if (!hasAtLeastTwoValidGroups(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(), "cannot find two mismatching inner most dimensions");
+    return false;
+  }
+
   return true;
 }

(7 more changed files not shown)
