Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring Fusion Executor, pulling out compiled kernel #3468

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6bdd7f5
Redraft pulling compiled kernel out of kernel executor.
csarofeen Nov 24, 2024
a7fca2e
Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_…
csarofeen Nov 30, 2024
dbb0554
Cleanup and preparation to clean up executor_utils.h
csarofeen Nov 30, 2024
b7c9e7c
Move compilation logic out of executor_utils into compiled_kernel.
csarofeen Nov 30, 2024
b61ac3a
Remove compilation profiling from compiled kernel as it's still calle…
csarofeen Nov 30, 2024
6177db2
cleanup
csarofeen Nov 30, 2024
0e3cdd9
Fix input binding in executor.
csarofeen Dec 1, 2024
3961181
Kernel executor doesn't instantiate compiled kernel unless compiled, …
csarofeen Dec 2, 2024
f2b00bb
Fix type consistency in st matrix testing.
csarofeen Dec 3, 2024
56dda45
Fix build.
csarofeen Dec 3, 2024
5999480
Need to be consistent with types for fusion.manage.
csarofeen Dec 7, 2024
5920d34
Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_…
csarofeen Dec 7, 2024
a7ad429
Repair serialization.
csarofeen Dec 8, 2024
d89d155
Fix check that disables parameter cache, the check was valid before l…
csarofeen Dec 15, 2024
2933509
Merge branch 'main' into compiled_kernel_2
csarofeen Dec 15, 2024
b222303
Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_…
csarofeen Dec 18, 2024
239d652
Merge.
csarofeen Dec 18, 2024
8137228
Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_…
csarofeen Dec 23, 2024
a9733b0
Merge conflicts.
csarofeen Dec 23, 2024
da1452c
Fix lowering hooks, rename compileFusion to compile.
csarofeen Dec 23, 2024
e1fcd7f
Fix param cache check with validation.
csarofeen Dec 25, 2024
b66e6c5
Remove refactor validation.
csarofeen Dec 25, 2024
7dc8d09
Merge branch 'main' into compiled_kernel_2
csarofeen Dec 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
${NVFUSER_SRCS_DIR}/rng.cpp
${NVFUSER_SRCS_DIR}/runtime/allocations.cpp
${NVFUSER_SRCS_DIR}/runtime/compiled_kernel.cpp
${NVFUSER_SRCS_DIR}/runtime/executor.cpp
${NVFUSER_SRCS_DIR}/runtime/executor_dispatch.cpp
${NVFUSER_SRCS_DIR}/runtime/executor_kernel_arg.cpp
Expand Down
9 changes: 6 additions & 3 deletions benchmarks/cpp/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ static void SingleMatmulBase(
KernelExecutor ke;
ke.compile(fusion, args, launch_constraints, cparams);
NVF_CHECK(
getBankConflictInfo(ke.kernel(), launch_constraints).empty(),
getBankConflictInfo(ke.compiledKernel()->kernel(), launch_constraints)
.empty(),
"Shared memory bank conflict not removed.");

std::vector<c10::IValue> aten_inputs({inputs.first, inputs.second});
Expand Down Expand Up @@ -358,7 +359,7 @@ static void SingleMatmulPartitionedK(
auto lparams = LaunchParams();
ke.compile(fusion, args, lparams, cparams);
NVF_CHECK(
getBankConflictInfo(ke.kernel(), lparams).empty(),
getBankConflictInfo(ke.compiledKernel()->kernel(), lparams).empty(),
"Shared memory bank conflict not removed.");

// Warm up run
Expand Down Expand Up @@ -471,7 +472,9 @@ static void NvFuserScheduler_MatmulSplitKReduction(
fusion, args, heuristic_params->lparams, heuristic_params->cparams);

NVF_CHECK(
getBankConflictInfo(ke.kernel(), heuristic_params->lparams).empty(),
getBankConflictInfo(
ke.compiledKernel()->kernel(), heuristic_params->lparams)
.empty(),
"Shared memory bank conflict not removed.");

// Warm up run
Expand Down
1 change: 1 addition & 0 deletions csrc/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <cuda_runtime.h>
#include <driver_api.h>
#include <exceptions.h>
#include <nvrtc.h>

#define NVFUSER_NVRTC_SAFE_CALL(x) \
do { \
Expand Down
2 changes: 1 addition & 1 deletion csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum class DebugDumpOption {
FusionIrPresched, //!< Dump the segmented Fusion IR before it is scheduled
// TODO(wujingyue): name the following FusionIrSched
FusionIr, //!< Dump the Fusion IR before lowering. This is the Fusion IR fed
//!< to `KernelExecutor::compileFusion`.
//!< to `KernelExecutor::compile`.
FusionIrGraph, //!< Dump a GraphViz graph of the Fusion IR
FusionIrMath, //!< Dump just the compute (math) part of the above `FusionIr`
//!< for conciseness
Expand Down
16 changes: 9 additions & 7 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
outputs = user_sched.executor->run(inputs);
} else {
// Automatic scheduler was used for UserSchedule.
// Pass launch and compile params to compileFusion and runFusion.
// Pass launch and compile params to compile and run.
if (!user_sched.executor->isCompiled()) {
user_sched.executor->compile(
user_sched.scheduled_fusion.get(),
Expand Down Expand Up @@ -487,10 +487,11 @@ std::string FusionDefinition::lastCudaCode(

if (!override_user_schedule && (user_exec != nullptr)) {
if (intrinsic_code) {
result = user_exec->getStructuredCode(
user_exec->kernelString(), user_exec->kernel()->indexType());
result = user_exec->compiledKernel()->getStructuredCode(
user_exec->compiledKernel()->kernelString(),
user_exec->compiledKernel()->kernel()->indexType());
} else {
result = user_exec->kernelString();
result = user_exec->compiledKernel()->kernelString();
}
} else {
result = scheds->auto_gen_schedules->getMostRecentCode(intrinsic_code);
Expand All @@ -516,10 +517,11 @@ std::string FusionDefinition::cudaCodeFor(
scheds, user_sched_id.value(), device);
auto user_exec = user_sched.executor.get();
if (intrinsic_code) {
return user_exec->getStructuredCode(
user_exec->kernelString(), user_exec->kernel()->indexType());
return user_exec->compiledKernel()->getStructuredCode(
user_exec->compiledKernel()->kernelString(),
user_exec->compiledKernel()->kernel()->indexType());
} else {
return user_exec->kernelString();
return user_exec->compiledKernel()->kernelString();
}
}
}
Expand Down
Loading
Loading