Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 139 additions & 2 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <utility>

#include "device_lower/analysis/circular_buffer.h"
#include "device_lower/analysis/trivial_broadcast.h"
#include "device_lower/lower2device.h"
#include "device_lower/utils.h"
#include "disjoint_set.h"
Expand All @@ -25,7 +24,6 @@
#include "iter_visitor.h"
#include "logical_domain_map.h"
#include "transform_iter.h"
#include "val_graph_visitor.h"

namespace nvfuser {

Expand Down Expand Up @@ -1426,4 +1424,143 @@ ValGraph buildPermissiveResizeGraph(const ValGraph& permissive_graph) {
return resize_graph;
}

// https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
ValGraph mapAlmostExactSplits(const ValGraph& graph) {
auto new_graph = graph;

// vg: I0
auto get_l1r2_splits =
[&new_graph](
const ValGroup& vg) -> std::vector<std::pair<ExprGroup, ExprGroup>> {
std::vector<std::pair<ExprGroup, ExprGroup>> l1_r2_splits;

if (!new_graph.hasUses(vg)) {
return {};
}

for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
if (split_of_vg == nullptr) {
continue;
}

// mn
const ValGroup& inner_group = new_graph.toGroup(split_of_vg->inner());

if (!new_graph.hasUses(inner_group)) {
return {};
}

for (const ExprGroup& use_of_inner_group :
new_graph.getUses(inner_group)) {
auto split_of_inner_group =
dynamic_cast<Split*>(use_of_inner_group->front());
if (split_of_inner_group == nullptr) {
continue;
}

// This split needs to be divisible
auto extent = split_of_inner_group->in()->extent();
auto factor = split_of_inner_group->factor();
if (extent->isConstScalar() && factor->isConstScalar() &&
(extent->evaluate().as<int64_t>() %
factor->evaluate().as<int64_t>() !=
0)) {
continue;
}

l1_r2_splits.emplace_back(use_of_vg, use_of_inner_group);

std::cerr << "L1R2 found: " << split_of_vg->toString()
<< split_of_inner_group->toString();
}
}

return l1_r2_splits;
};

auto get_matching_l2r1_splits =
[&new_graph](
const ValGroup& vg, const std::pair<ExprGroup, ExprGroup>& l1_r2)
-> std::optional<std::pair<ExprGroup, ExprGroup>> {
auto m = l1_r2.second->front()->as<Split>()->outer()->extent();
auto n = l1_r2.second->front()->as<Split>()->inner()->extent();

for (const ExprGroup& use_of_vg : new_graph.getUses(vg)) {
auto split_of_vg = dynamic_cast<Split*>(use_of_vg->front());
if (split_of_vg == nullptr) {
continue;
}

if (!split_of_vg->inner()->extent()->sameAs(n)) {
continue;
}

// I0/n
const ValGroup& outer_group = new_graph.toGroup(split_of_vg->outer());

if (!new_graph.hasUses(outer_group)) {
return {};
}

for (const ExprGroup& use_of_outer_group :
new_graph.getUses(outer_group)) {
auto split_of_outer_group =
dynamic_cast<Split*>(use_of_outer_group->front());
if (split_of_outer_group == nullptr) {
continue;
}

if (!split_of_outer_group->inner()->extent()->sameAs(m)) {
continue;
}

std::cerr << "Matching L2R1 found: " << split_of_vg->toString()
<< split_of_outer_group->toString();
return std::make_pair(use_of_vg, use_of_outer_group);
}
}

return std::nullopt;
};

std::vector<std::pair<ValGroup, ValGroup>> groups_to_map;

for (const ValGroup& vg : new_graph.disjointValSets().disjointSets()) {
const auto all_l1r2_splits = get_l1r2_splits(vg);
for (const auto& l1r2 : all_l1r2_splits) {
std::cerr << "L1R2: " << l1r2.first->front()->toString()
<< l1r2.second->front()->toString();
auto l2r1 = get_matching_l2r1_splits(vg, l1r2);
if (!l2r1.has_value()) {
continue;
}

std::cerr << "Found\n";

auto l1r2_first_outputs = new_graph.outputGroups(l1r2.first);
auto l1r2_second_outputs = new_graph.outputGroups(l1r2.second);

auto l2r1_first_outputs = new_graph.outputGroups(l2r1->first);
auto l2r1_second_outputs = new_graph.outputGroups(l2r1->second);

groups_to_map.emplace_back(
l1r2_first_outputs.at(0), l2r1_second_outputs.at(0));
groups_to_map.emplace_back(
l1r2_second_outputs.at(0), l2r1_second_outputs.at(1));
groups_to_map.emplace_back(
l1r2_second_outputs.at(1), l2r1_first_outputs.at(1));
}
}

for (const auto& [vg1, vg2] : groups_to_map) {
std::cerr << "Mapping " << nvfuser::toString(vg1) << ", "
<< vg1->front()->toString() << " and " << nvfuser::toString(vg2)
<< ", " << vg2->front()->toString() << "\n";
new_graph.mapVals(vg1->front(), vg2->front());
}

return new_graph;
}

} // namespace nvfuser
2 changes: 2 additions & 0 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,6 @@ std::unordered_map<ValGroup, IterDomain*> updateValGroupIdMap(
// This adds additional mappings for resize operations.
ValGraph buildPermissiveResizeGraph(const ValGraph& permissive_graph);

ValGraph mapAlmostExactSplits(const ValGraph& graph);

} // namespace nvfuser
26 changes: 17 additions & 9 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,19 @@ class LoopDomainScheduler {
public:
LoopDomainScheduler(
std::vector<IterDomain*> ref_loop_dom,
bool update_loop_domain_only = false)
bool update_loop_domain_only = false,
const ValGraph* scheduling_graph = nullptr)
: ref_loop_dom_(std::move(ref_loop_dom)),
update_loop_domain_only_(update_loop_domain_only) {
update_loop_domain_only_(update_loop_domain_only),
graph_(scheduling_graph) {
NVF_ERROR(!ref_loop_dom_.empty());

Fusion* fusion = ref_loop_dom_.front()->fusion();
id_model_ = std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
id_model_->buildExactGraph();
if (graph_ == nullptr) {
Fusion* fusion = ref_loop_dom_.front()->fusion();
id_model_ = std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
id_model_->buildExactGraph();
graph_ = &(id_model_->idGraph(IdMappingMode::EXACT));
}

ref_id_groups_ = graph().toGroups(ref_loop_dom_);

Expand All @@ -203,8 +208,9 @@ class LoopDomainScheduler {
void schedule(TensorView* tv) const;

private:
ValGraph& graph() const {
return id_model_->idGraph(IdMappingMode::EXACT);
const ValGraph& graph() const {
NVF_ERROR(graph_ != nullptr);
return *graph_;
}

ValGraphBFS::ExprPath getReplayPath(TensorView* tv) const;
Expand Down Expand Up @@ -248,6 +254,7 @@ class LoopDomainScheduler {
// updates it to make it look like the given reference loop domain
bool update_loop_domain_only_ = false;
std::unique_ptr<IdModel> id_model_;
const ValGraph* graph_ = nullptr;
ValGroups ref_id_groups_;
ValGroups all_ancestors_of_ref_;
};
Expand Down Expand Up @@ -477,12 +484,13 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
void scheduleLoopDomainsLike(
const std::vector<TensorView*>& tvs,
const std::vector<IterDomain*>& ref_loop_dom,
bool update_loop_domain_only) {
bool update_loop_domain_only,
const ValGraph* graph) {
if (tvs.empty()) {
return;
}

LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only);
LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only, graph);

for (auto tv : tvs) {
// Loop domain of fusion inputs should have no meaning,
Expand Down
3 changes: 2 additions & 1 deletion csrc/scheduler/tools/loop_domain_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace scheduler_tools {
void scheduleLoopDomainsLike(
const std::vector<TensorView*>& tvs,
const std::vector<IterDomain*>& ref_loop_dom,
bool update_loop_domain_only = false);
bool update_loop_domain_only = false,
const ValGraph* graph = nullptr);

// Replay a series of transform exprs on the loop domain of each of the given
// tensors. If the replay direction is specified, the exprs are replayed
Expand Down
Loading
Loading