From 8654f9ac9e90d079eb02f4a232da04d4bc03c4ec Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 23 Dec 2024 15:59:40 -0800 Subject: [PATCH 1/3] Move scheduling op bindings to a separate function --- csrc/python_frontend/python_bindings.cpp | 8 ++++++-- csrc/python_frontend/python_bindings.h | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 57e7624e772..44cf2a3aa88 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -3569,8 +3569,12 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("scale").none(true) = py::none(), py::return_value_policy::reference); - //! The ScedOperators class is a nested class of FusionDefinition to allow the - //! user to query the class for the list of schedule operators. + bindSchedulingOperators(fusion_def); +} + +void bindSchedulingOperators(py::class_& fusion_def) { + //! The SchedOperators class is a nested class of FusionDefinition to allow + //! the user to query the class for the list of schedule operators. //! //! Example: //! help(FusionDefinition.SchedOperators) diff --git a/csrc/python_frontend/python_bindings.h b/csrc/python_frontend/python_bindings.h index a698619eb4e..2889f681d2c 100644 --- a/csrc/python_frontend/python_bindings.h +++ b/csrc/python_frontend/python_bindings.h @@ -10,10 +10,13 @@ #include #include +#include #include namespace nvfuser::python_frontend { NVF_API void initNvFuserPythonBindings(PyObject* module); +void bindSchedulingOperators(py::class_& fusion_def); + NVF_API void cleanup(); } // namespace nvfuser::python_frontend From 39727b4da6f30d1379277963787a481bc5732f8a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 23 Dec 2024 16:30:30 -0800 Subject: [PATCH 2/3] Move bindSchedulingOperators to a separate file --- CMakeLists.txt | 1 + csrc/python_frontend/fusion_definition.h | 3 +- csrc/python_frontend/python_bindings.cpp | 516 +------------------- csrc/python_frontend/sched_op_bindings.cpp | 517 +++++++++++++++++++++ 4 files changed, 529 insertions(+), 508 deletions(-) create mode 100644 csrc/python_frontend/sched_op_bindings.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 9cfe58b540e..1e3fa7a91ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -442,6 +442,7 @@ if(BUILD_PYTHON) list(APPEND NVFUSER_PYTHON_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/sched_op_bindings.cpp ) add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) diff --git a/csrc/python_frontend/fusion_definition.h b/csrc/python_frontend/fusion_definition.h index 6157704f86b..28fc4b8b484 100644 --- a/csrc/python_frontend/fusion_definition.h +++ b/csrc/python_frontend/fusion_definition.h @@ -6,14 +6,15 @@ */ // clang-format on #pragma once + #include #include #include +#include #include #include #include -#include namespace nvfuser::python_frontend { diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 44cf2a3aa88..e1f6d9b41d7 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -5,10 +5,18 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include +#include +#include #include #include +#include + +#include +#include +#include + #include #include #include @@ -28,15 +36,7 @@ #include #include #include -#include #include -#include
-#include -#include - -#include -#include -#include namespace nvfuser::python_frontend { @@ -3572,504 +3572,6 @@ void initNvFuserPythonBindings(PyObject* module) { bindSchedulingOperators(fusion_def); } -void bindSchedulingOperators(py::class_& fusion_def) { - //! The SchedOperators class is a nested class of FusionDefinition to allow - //! the user to query the class for the list of schedule operators. - //! - //! Example: - //! help(FusionDefinition.SchedOperators) - //! - //! Additional operators are expected to be defined below as needed. - py::class_ nvf_sched( - fusion_def, "SchedOperators"); - nvf_sched.def(py::init()); - nvf_sched.def( - "to_string", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - // NOTE: For debugging purposes, print the state of TensorView - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return tv->toString(); - }, - py::arg("tensor")); - nvf_sched.def( - "user_schedule_ir", - [](FusionDefinition::SchedOperators& self) { - return self.fusion_definition->userScheduleIr(); - }, - py::return_value_policy::reference); - //! experimental API for multidevice support - nvf_sched.def( - "_set_device_mesh", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const DeviceMesh& mesh) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->setDeviceMesh(mesh); - }, - py::arg("tensor"), - py::arg("mesh")); - nvf_sched.def( - "parallelize", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int axis, - const ParallelType& parallel_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto tv = fd->getFusionState(tensor.index)->template as(); - tv->axis(axis)->parallelize(parallel_type); - }, - py::arg("tensor"), - py::arg("axis"), - py::arg("parallel_type")); - nvf_sched.def( - "merge", - [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) { - FUSER_PERF_SCOPE("SchedOperators.merge"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->merge(dim); - }, - py::arg("arg"), - py::arg("dim")); - auto reduction_factor_func = [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::vector& dims) -> Tensor { - FUSER_PERF_SCOPE("SchedOperators.reduction_factor"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(arg.index)->template as(); - TensorView* output_tv = input_tv->rFactor(dims); - return fd->addTensor(output_tv); - }; - nvf_sched.def( - "reduction_factor", - reduction_factor_func, - py::arg("arg"), - py::arg("dims")); - nvf_sched.def( - "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims")); - nvf_sched.def( - "reorder", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - const std::unordered_map& old2new) { - FUSER_PERF_SCOPE("SchedOperators.reorder"); - NVF_CHECK( - self.validUse(), - 
"Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->reorder(old2new); - }, - py::arg("arg"), - py::arg("old2new")); - nvf_sched.def( - "split", - [](FusionDefinition::SchedOperators& self, - Tensor arg, - int64_t dim, - int64_t factor, - bool inner_split) { - FUSER_PERF_SCOPE("SchedOperators.split"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto input_tv = - fd->getFusionState(arg.index)->template as(); - input_tv->split(dim, factor, inner_split); - }, - py::arg("arg"), - py::arg("dim"), - py::arg("factor"), - py::arg("inner_split") = true); - nvf_sched.def( - "set_allocation_as_loop", - [](FusionDefinition::SchedOperators& self, Tensor arg) { - FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - auto* tv = fd->getFusionState(arg.index)->template as(); - tv->setAllocationDomain(tv->getLoopDomain(), true); - }, - py::arg("arg")); - nvf_sched.def( - "cache_after", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const LoadStoreOpType& op_type, - const CacheOp& cache_op) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set, - py::arg("cache_op") = CacheOp::Unspecified); - nvf_sched.def( - "cache_before", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const LoadStoreOpType& op_type) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheBefore(op_type); - return fd->addTensor(output_tv); - }, - py::arg("tensor"), - py::arg("op_type") = LoadStoreOpType::Set); - nvf_sched.def( - "cache_fork", - [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* input_tv = - fd->getFusionState(tensor.index)->template as(); - TensorView* output_tv = input_tv->cacheFork(); - return fd->addTensor(output_tv); - }, - py::arg("tensor")); - nvf_sched.def( - "set_memory_type", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const MemoryType& memory_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - tv->setMemoryType(memory_type); - }, - py::arg("tensor"), - py::arg("memory_type")); - nvf_sched.def( - "transform_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - 
TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - TransformPropagator propagator(reference_tv); - if (selected_tensors.empty()) { - // Propagate scheduler transformations on reference TensorView to the - // rest of the fusion. - MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); - } else { - // Propagate scheduler transformations on reference TensorView to the - // subset of the fusion. - std::unordered_set selected_tv_set; - selected_tv_set.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tv_set, selected_tv_set.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - SetSelector selector( - {selected_tv_set.begin(), selected_tv_set.end()}); - MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) - .traverse(&propagator); - } - }, - py::arg("tensor"), - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "parallelize_like", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - const std::vector& selected_tensors, - const std::unordered_set& selected_parallel_types, - bool propagate_padding) { - // Propagate the parallelization from the selected dimensions of the - // reference tensor to their corresponding dimensions in all selected - // tensors in the DAG. - // - // 1. Position `pos` means selecting all the dimensions - // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions. - // 2. `selected_tvs` are selected tensors in the DAG. Empty - // `selected_tvs` means selecting all tensors in the fusion of - // `reference_tv`. - // 3. `selected_parallel_types` are the selected parallel types. Empty - // `selected_parallel_types` means selecting all parallel types. - - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::scheduler_utils::parallelizeAllLike( - reference_tv, - pos, - selected_tvs, - selected_parallel_types, - propagate_padding); - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("selected_tensors") = std::vector(), - py::arg("selected_parallel_types") = std::unordered_set(), - py::arg("propagate_padding") = true); - nvf_sched.def( - "inline_most", - [](FusionDefinition::SchedOperators& self, - const std::vector& selected_tensors) { - // Inline to the right most allowed position for the selected tensors in - // the current fusion. 
- - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - - if (selected_tensors.empty()) { - nvfuser::inlineMost(); - } else { - std::vector selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::back_inserter(selected_tvs), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - nvfuser::inlineMost(selected_tvs); - } - }, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def( - "inline_at", - [](FusionDefinition::SchedOperators& self, - Tensor tensor, - int64_t pos, - bool best_effort, - const std::vector& selected_tensors) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - FusionDefinition* fd = self.fusion_definition; - TensorView* reference_tv = - fd->getFusionState(tensor.index)->template as(); - - if (selected_tensors.empty()) { - // Inline to the position corresponding to the reference position in - // the reference tensor for all tensors in the current fusion. - nvfuser::inlineAllAt(reference_tv, pos, best_effort); - } else { - // Inline to the position corresponding to the reference position in - // the reference tensor for selected tensors in the current fusion. - std::unordered_set selected_tvs; - selected_tvs.reserve(selected_tensors.size()); - std::transform( - selected_tensors.begin(), - selected_tensors.end(), - std::inserter(selected_tvs, selected_tvs.end()), - [&fd](const Tensor& t) { - return fd->getFusionState(t.index)->template as(); - }); - - nvfuser::inlineSelectedAt( - selected_tvs, reference_tv, pos, best_effort); - } - }, - py::arg("tensor"), - py::arg("pos") = -1, - py::arg("best_effort") = false, - py::arg("selected_tensors") = std::vector()); - nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Return all Tensors in FusionDefinition - return self.fusion_definition->tensors(); - }); - nvf_sched.def( - "is_reduction", - [](FusionDefinition::SchedOperators& self, Tensor tensor) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - // Determine if tensor is a result from a reduction operation. 
- FusionDefinition* fd = self.fusion_definition; - TensorView* tv = - fd->getFusionState(tensor.index)->template as(); - return ( - !tv->isFusionInput() && - std::any_of( - tv->getMaybeRootDomain().begin(), - tv->getMaybeRootDomain().end(), - [](IterDomain* id) { return id->isReduction(); }) && - !isResharding(tv->definition())); - }, - py::arg("tensor")); - nvf_sched.def( - "can_schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - return self.fusion_definition->userSchedule()->canScheduleDebug( - scheduler_type); - }, - py::arg("scheduler_type")); - nvf_sched.def( - "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - - std::vector valid_scheduler_types; - valid_scheduler_types.reserve(all_heuristics_in_priority_order.size()); - std::copy_if( - all_heuristics_in_priority_order.begin(), - all_heuristics_in_priority_order.end(), - std::back_inserter(valid_scheduler_types), - [sched = self.fusion_definition->userSchedule()]( - SchedulerType scheduler_type) { - return sched->canSchedule(scheduler_type); - }); - return valid_scheduler_types; - }); - nvf_sched.def( - "schedule", - [](FusionDefinition::SchedOperators& self, - const SchedulerType& scheduler_type) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - auto&& [can_schedule, error_msg] = - sched->canScheduleDebug(scheduler_type); - NVF_CHECK(can_schedule, error_msg); - sched->scheduleWithType(scheduler_type); - }, - py::arg("heuristic")); - nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - sched->schedule(); - }); - nvf_sched.def( - "compute_pointwise_heuristics", - [](FusionDefinition::SchedOperators& self) -> PointwiseParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::PointWise); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_reduction_heuristics", - [](FusionDefinition::SchedOperators& self) -> ReductionParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Reduction); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "compute_matmul_heuristics", - [](FusionDefinition::SchedOperators& self) -> MatmulParams& { - NVF_CHECK( - self.validUse(), - "Attempting to use a SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - HeuristicParams* parameters = - sched->computeHeuristics(SchedulerType::Matmul); - return *parameters->as(); - }, - py::return_value_policy::reference); - nvf_sched.def( - "schedule_hyperparameters", - [](FusionDefinition::SchedOperators& self) - -> scheduler_utils::SchedulerHyperParameters& { - NVF_CHECK( - self.validUse(), - "Attempting to use a 
SchedOperators Op prior to definition!"); - UserSchedule* sched = self.fusion_definition->userSchedule(); - auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< - HeuristicCompileTime::SchedulerHyperParameters>( - sched->data_cache.get(), []() { - return std::make_unique< - scheduler_utils::SchedulerHyperParameters>( - /*vectorize_factor=*/1, - /*unroll_factor=*/1, - /*threads_per_block_min=*/1, - /*threads_per_block_max=*/1); - }); - return scheduler_hyperparameters_entry.get(); - }, - py::return_value_policy::reference); -} - void cleanup() { Communicator::getInstance().cleanup(); } diff --git a/csrc/python_frontend/sched_op_bindings.cpp b/csrc/python_frontend/sched_op_bindings.cpp new file mode 100644 index 00000000000..078a330a604 --- /dev/null +++ b/csrc/python_frontend/sched_op_bindings.cpp @@ -0,0 +1,517 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::python_frontend { + +void bindSchedulingOperators(py::class_& fusion_def) { + //! The SchedOperators class is a nested class of FusionDefinition to allow + //! the user to query the class for the list of schedule operators. + //! + //! Example: + //! help(FusionDefinition.SchedOperators) + //! + //! Additional operators are expected to be defined below as needed. + py::class_ nvf_sched( + fusion_def, "SchedOperators"); + nvf_sched.def(py::init()); + nvf_sched.def( + "to_string", + [](FusionDefinition::SchedOperators& self, Tensor tensor) { + // NOTE: For debugging purposes, print the state of TensorView + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Determine if tensor is a result from a reduction operation. + FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + return tv->toString(); + }, + py::arg("tensor")); + nvf_sched.def( + "user_schedule_ir", + [](FusionDefinition::SchedOperators& self) { + return self.fusion_definition->userScheduleIr(); + }, + py::return_value_policy::reference); + //! 
experimental API for multidevice support + nvf_sched.def( + "_set_device_mesh", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const DeviceMesh& mesh) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto tv = fd->getFusionState(tensor.index)->template as(); + tv->setDeviceMesh(mesh); + }, + py::arg("tensor"), + py::arg("mesh")); + nvf_sched.def( + "parallelize", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int axis, + const ParallelType& parallel_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto tv = fd->getFusionState(tensor.index)->template as(); + tv->axis(axis)->parallelize(parallel_type); + }, + py::arg("tensor"), + py::arg("axis"), + py::arg("parallel_type")); + nvf_sched.def( + "merge", + [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) { + FUSER_PERF_SCOPE("SchedOperators.merge"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->merge(dim); + }, + py::arg("arg"), + py::arg("dim")); + auto reduction_factor_func = [](FusionDefinition::SchedOperators& self, + Tensor arg, + const std::vector& dims) -> Tensor { + FUSER_PERF_SCOPE("SchedOperators.reduction_factor"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(arg.index)->template as(); + TensorView* output_tv = input_tv->rFactor(dims); + return fd->addTensor(output_tv); + }; + nvf_sched.def( + "reduction_factor", + reduction_factor_func, + py::arg("arg"), + py::arg("dims")); + nvf_sched.def( + "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims")); + nvf_sched.def( + "reorder", + [](FusionDefinition::SchedOperators& self, + Tensor arg, + const std::unordered_map& old2new) { + FUSER_PERF_SCOPE("SchedOperators.reorder"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->reorder(old2new); + }, + py::arg("arg"), + py::arg("old2new")); + nvf_sched.def( + "split", + [](FusionDefinition::SchedOperators& self, + Tensor arg, + int64_t dim, + int64_t factor, + bool inner_split) { + FUSER_PERF_SCOPE("SchedOperators.split"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto input_tv = + fd->getFusionState(arg.index)->template as(); + input_tv->split(dim, factor, inner_split); + }, + py::arg("arg"), + py::arg("dim"), + py::arg("factor"), + py::arg("inner_split") = true); + nvf_sched.def( + "set_allocation_as_loop", + [](FusionDefinition::SchedOperators& self, Tensor arg) { + FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop"); + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + auto* tv = fd->getFusionState(arg.index)->template as(); + tv->setAllocationDomain(tv->getLoopDomain(), true); + }, + py::arg("arg")); + nvf_sched.def( + "cache_after", + [](FusionDefinition::SchedOperators& self, + 
Tensor tensor, + const LoadStoreOpType& op_type, + const CacheOp& cache_op) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op); + return fd->addTensor(output_tv); + }, + py::arg("tensor"), + py::arg("op_type") = LoadStoreOpType::Set, + py::arg("cache_op") = CacheOp::Unspecified); + nvf_sched.def( + "cache_before", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const LoadStoreOpType& op_type) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheBefore(op_type); + return fd->addTensor(output_tv); + }, + py::arg("tensor"), + py::arg("op_type") = LoadStoreOpType::Set); + nvf_sched.def( + "cache_fork", + [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* input_tv = + fd->getFusionState(tensor.index)->template as(); + TensorView* output_tv = input_tv->cacheFork(); + return fd->addTensor(output_tv); + }, + py::arg("tensor")); + nvf_sched.def( + "set_memory_type", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const MemoryType& memory_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + tv->setMemoryType(memory_type); + }, + py::arg("tensor"), + py::arg("memory_type")); + nvf_sched.def( + "transform_like", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + const std::vector& selected_tensors) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + TransformPropagator propagator(reference_tv); + if (selected_tensors.empty()) { + // Propagate scheduler transformations on reference TensorView to the + // rest of the fusion. + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); + } else { + // Propagate scheduler transformations on reference TensorView to the + // subset of the fusion. 
+ std::unordered_set selected_tv_set; + selected_tv_set.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::inserter(selected_tv_set, selected_tv_set.end()), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + SetSelector selector( + {selected_tv_set.begin(), selected_tv_set.end()}); + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) + .traverse(&propagator); + } + }, + py::arg("tensor"), + py::arg("selected_tensors") = std::vector()); + nvf_sched.def( + "parallelize_like", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int64_t pos, + const std::vector& selected_tensors, + const std::unordered_set& selected_parallel_types, + bool propagate_padding) { + // Propagate the parallelization from the selected dimensions of the + // reference tensor to their corresponding dimensions in all selected + // tensors in the DAG. + // + // 1. Position `pos` means selecting all the dimensions + // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions. + // 2. `selected_tvs` are selected tensors in the DAG. Empty + // `selected_tvs` means selecting all tensors in the fusion of + // `reference_tv`. + // 3. `selected_parallel_types` are the selected parallel types. Empty + // `selected_parallel_types` means selecting all parallel types. + + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + std::vector selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::back_inserter(selected_tvs), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + + nvfuser::scheduler_utils::parallelizeAllLike( + reference_tv, + pos, + selected_tvs, + selected_parallel_types, + propagate_padding); + }, + py::arg("tensor"), + py::arg("pos") = -1, + py::arg("selected_tensors") = std::vector(), + py::arg("selected_parallel_types") = std::unordered_set(), + py::arg("propagate_padding") = true); + nvf_sched.def( + "inline_most", + [](FusionDefinition::SchedOperators& self, + const std::vector& selected_tensors) { + // Inline to the right most allowed position for the selected tensors in + // the current fusion. 
+ + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + + if (selected_tensors.empty()) { + nvfuser::inlineMost(); + } else { + std::vector selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::back_inserter(selected_tvs), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + nvfuser::inlineMost(selected_tvs); + } + }, + py::arg("selected_tensors") = std::vector()); + nvf_sched.def( + "inline_at", + [](FusionDefinition::SchedOperators& self, + Tensor tensor, + int64_t pos, + bool best_effort, + const std::vector& selected_tensors) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + FusionDefinition* fd = self.fusion_definition; + TensorView* reference_tv = + fd->getFusionState(tensor.index)->template as(); + + if (selected_tensors.empty()) { + // Inline to the position corresponding to the reference position in + // the reference tensor for all tensors in the current fusion. + nvfuser::inlineAllAt(reference_tv, pos, best_effort); + } else { + // Inline to the position corresponding to the reference position in + // the reference tensor for selected tensors in the current fusion. + std::unordered_set selected_tvs; + selected_tvs.reserve(selected_tensors.size()); + std::transform( + selected_tensors.begin(), + selected_tensors.end(), + std::inserter(selected_tvs, selected_tvs.end()), + [&fd](const Tensor& t) { + return fd->getFusionState(t.index)->template as(); + }); + + nvfuser::inlineSelectedAt( + selected_tvs, reference_tv, pos, best_effort); + } + }, + py::arg("tensor"), + py::arg("pos") = -1, + py::arg("best_effort") = false, + py::arg("selected_tensors") = std::vector()); + nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Return all Tensors in FusionDefinition + return self.fusion_definition->tensors(); + }); + nvf_sched.def( + "is_reduction", + [](FusionDefinition::SchedOperators& self, Tensor tensor) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + // Determine if tensor is a result from a reduction operation. 
+ FusionDefinition* fd = self.fusion_definition; + TensorView* tv = + fd->getFusionState(tensor.index)->template as(); + return ( + !tv->isFusionInput() && + std::any_of( + tv->getMaybeRootDomain().begin(), + tv->getMaybeRootDomain().end(), + [](IterDomain* id) { return id->isReduction(); }) && + !isResharding(tv->definition())); + }, + py::arg("tensor")); + nvf_sched.def( + "can_schedule", + [](FusionDefinition::SchedOperators& self, + const SchedulerType& scheduler_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + return self.fusion_definition->userSchedule()->canScheduleDebug( + scheduler_type); + }, + py::arg("scheduler_type")); + nvf_sched.def( + "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + + std::vector valid_scheduler_types; + valid_scheduler_types.reserve(all_heuristics_in_priority_order.size()); + std::copy_if( + all_heuristics_in_priority_order.begin(), + all_heuristics_in_priority_order.end(), + std::back_inserter(valid_scheduler_types), + [sched = self.fusion_definition->userSchedule()]( + SchedulerType scheduler_type) { + return sched->canSchedule(scheduler_type); + }); + return valid_scheduler_types; + }); + nvf_sched.def( + "schedule", + [](FusionDefinition::SchedOperators& self, + const SchedulerType& scheduler_type) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto&& [can_schedule, error_msg] = + sched->canScheduleDebug(scheduler_type); + NVF_CHECK(can_schedule, error_msg); + sched->scheduleWithType(scheduler_type); + }, + py::arg("heuristic")); + nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + sched->schedule(); + }); + nvf_sched.def( + "compute_pointwise_heuristics", + [](FusionDefinition::SchedOperators& self) -> PointwiseParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::PointWise); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "compute_reduction_heuristics", + [](FusionDefinition::SchedOperators& self) -> ReductionParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::Reduction); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "compute_matmul_heuristics", + [](FusionDefinition::SchedOperators& self) -> MatmulParams& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + HeuristicParams* parameters = + sched->computeHeuristics(SchedulerType::Matmul); + return *parameters->as(); + }, + py::return_value_policy::reference); + nvf_sched.def( + "schedule_hyperparameters", + [](FusionDefinition::SchedOperators& self) + -> scheduler_utils::SchedulerHyperParameters& { + NVF_CHECK( + self.validUse(), + "Attempting to use a 
SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>( + sched->data_cache.get(), []() { + return std::make_unique< + scheduler_utils::SchedulerHyperParameters>( + /*vectorize_factor=*/1, + /*unroll_factor=*/1, + /*threads_per_block_min=*/1, + /*threads_per_block_max=*/1); + }); + return scheduler_hyperparameters_entry.get(); + }, + py::return_value_policy::reference); +} + +} // namespace nvfuser::python_frontend From aa371a4bd09e735f42da0f93c69e37219db68683 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 23 Dec 2024 17:26:47 -0800 Subject: [PATCH 3/3] Rename sched_op_bindings.cpp to schedule_bindings.cpp and bindSchedulingOperators to bindSchedule --- CMakeLists.txt | 2 +- csrc/python_frontend/python_bindings.cpp | 2 +- csrc/python_frontend/python_bindings.h | 2 +- .../{sched_op_bindings.cpp => schedule_bindings.cpp} | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename csrc/python_frontend/{sched_op_bindings.cpp => schedule_bindings.cpp} (99%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e3fa7a91ac..2c673b17332 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -442,7 +442,7 @@ if(BUILD_PYTHON) list(APPEND NVFUSER_PYTHON_SRCS ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/sched_op_bindings.cpp + ${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp ) add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index e1f6d9b41d7..ea061b094f1 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -3569,7 +3569,7 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("scale").none(true) = py::none(), py::return_value_policy::reference); - bindSchedulingOperators(fusion_def); + bindSchedule(fusion_def); } void cleanup() { diff --git a/csrc/python_frontend/python_bindings.h b/csrc/python_frontend/python_bindings.h index 2889f681d2c..bd8f0347530 100644 --- a/csrc/python_frontend/python_bindings.h +++ b/csrc/python_frontend/python_bindings.h @@ -16,7 +16,7 @@ namespace nvfuser::python_frontend { NVF_API void initNvFuserPythonBindings(PyObject* module); -void bindSchedulingOperators(py::class_& fusion_def); +void bindSchedule(py::class_& fusion_def); NVF_API void cleanup(); } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/sched_op_bindings.cpp b/csrc/python_frontend/schedule_bindings.cpp similarity index 99% rename from csrc/python_frontend/sched_op_bindings.cpp rename to csrc/python_frontend/schedule_bindings.cpp index 078a330a604..b77982711cb 100644 --- a/csrc/python_frontend/sched_op_bindings.cpp +++ b/csrc/python_frontend/schedule_bindings.cpp @@ -16,7 +16,7 @@ namespace nvfuser::python_frontend { -void bindSchedulingOperators(py::class_& fusion_def) { +void bindSchedule(py::class_& fusion_def) { //! The SchedOperators class is a nested class of FusionDefinition to allow //! the user to query the class for the list of schedule operators. //!
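
Usage note (for reviewers; not part of the patches): the series only moves and renames code, so the Python-facing scheduling API is unchanged. Below is a minimal sketch of driving these operators from the nvfuser Python package. The class name, tensor shapes, and split factor are illustrative assumptions, as is the user-scheduling hook (a schedule() method on a FusionDefinition subclass):

    # Sketch only: names marked "illustrative" are assumptions, not taken
    # from the patches above.
    import torch
    from nvfuser import FusionDefinition, DataType, ParallelType

    class ScaledSquare(FusionDefinition):  # illustrative class name
        def definition(self):
            self.t0 = self.define_tensor(
                shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float)
            self.t1 = self.ops.mul(self.t0, self.t0)
            self.add_output(self.t1)

        def schedule(self):
            # Defining schedule() opts out of the automatic scheduler;
            # self.sched is the SchedOperators object bound above.
            self.sched.merge(self.t1, 0)          # [i0, i1] -> [i0*i1]
            self.sched.split(self.t1, 0, 128)     # illustrative factor
            self.sched.parallelize(self.t1, 0, ParallelType.grid_x)
            self.sched.parallelize(self.t1, 1, ParallelType.block_x)
            self.sched.transform_like(self.t1)    # propagate merge/split
            self.sched.parallelize_like(self.t1)  # propagate BIDx/TIDx
            self.sched.inline_most()

    fd = ScaledSquare()
    (out,) = fd.execute([torch.randn(4, 1024, device="cuda")])

If that hook behaves as assumed, execute() runs the user schedule instead of nvFuser's heuristics, and deleting the schedule() method falls back to automatic scheduling.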