Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into wjy/comm
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 30, 2024
2 parents 7ccfccd + c154e90 commit 39f2809
Show file tree
Hide file tree
Showing 18 changed files with 630 additions and 152 deletions.
45 changes: 35 additions & 10 deletions csrc/id_model/id_model_index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@ void IdGraphIndexCompute::handle(Split* split) {
auto inner_extent = split->inner()->extent();

if (is_forward) {
auto in_idx = getIndex(split->in());
auto outer_idx = SimplifyingIrBuilder::divExpr(in_idx, inner_extent);
Val* inner_idx = SimplifyingIrBuilder::modExpr(in_idx, inner_extent);
setIndex(split->outer(), outer_idx);
setIndex(split->inner(), inner_idx);
// When propagating Split forward, if one of the outputs is mapped
// with the input (because of the almost-exact mapping), don't
// update the index and just set 0 as the index of the other
// output. This is necessary when the other output is a broadcast
// ID, which is ignored for predication. See
// IndexingTest.AlmostExactIndexingUpdate for a concrete example.
if (traversal_graph_.disjointValSets().strictAreMapped(
split->in(), split->inner())) {
setIndex(split->outer(), split->fusion()->zeroVal());
} else if (traversal_graph_.disjointValSets().strictAreMapped(
split->in(), split->outer())) {
setIndex(split->inner(), split->fusion()->zeroVal());
} else {
auto in_idx = getIndex(split->in());
auto outer_idx = SimplifyingIrBuilder::divExpr(in_idx, inner_extent);
Val* inner_idx = SimplifyingIrBuilder::modExpr(in_idx, inner_extent);
setIndex(split->outer(), outer_idx);
setIndex(split->inner(), inner_idx);
}
} else {
auto outer_idx = getIndex(split->outer());
auto inner_idx = getIndex(split->inner());
Expand All @@ -43,11 +57,22 @@ void IdGraphIndexCompute::handle(Merge* merge) {
SimplifyingIrBuilder::mulExpr(outer_idx, inner_ext), inner_idx);
setIndex(merge->out(), out_idx);
} else {
auto out_idx = getIndex(merge->out());
auto outer_idx = SimplifyingIrBuilder::divExpr(out_idx, inner_ext);
setIndex(merge->outer(), outer_idx);
Val* inner_idx = SimplifyingIrBuilder::modExpr(out_idx, inner_ext);
setIndex(merge->inner(), inner_idx);
// Similar to the forward propagation of Split, when propagating Merge
// backward, if one of the inputs is mapped with the output, don't update
// the index and just set 0 as the index of the other input.
if (traversal_graph_.disjointValSets().strictAreMapped(
merge->out(), merge->inner())) {
setIndex(merge->outer(), merge->fusion()->zeroVal());
} else if (traversal_graph_.disjointValSets().strictAreMapped(
merge->out(), merge->outer())) {
setIndex(merge->inner(), merge->fusion()->zeroVal());
} else {
auto out_idx = getIndex(merge->out());
auto outer_idx = SimplifyingIrBuilder::divExpr(out_idx, inner_ext);
setIndex(merge->outer(), outer_idx);
Val* inner_idx = SimplifyingIrBuilder::modExpr(out_idx, inner_ext);
setIndex(merge->inner(), inner_idx);
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion csrc/id_model/id_model_index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ class IdGraphIndexCompute : public OptOutDispatch {
}

// Records idx as the index of the group that id belongs to (as given
// by toGroup), replacing any index already recorded for that group.
void setIndex(IterDomain* id, Val* idx) {
  // Overwriting is intentional. When the graph is cyclic due to,
  // e.g., resize, a group may be assigned an index more than once
  // while traversing the indexing path, and the most recently
  // assigned index should be used (see also PR #3454).
  index_map_[toGroup(id)] = idx;
}

const ValGroup& toGroup(IterDomain* id) const {
Expand Down
50 changes: 25 additions & 25 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,9 @@ void TensorIndexer::buildLoopIndexMap() {
}
}

Val* TensorIndexer::getLoopIndex(IterDomain* loop_id) const {
Val* TensorIndexer::getLoopIndex(
IterDomain* loop_id,
const std::vector<ForLoop*>& for_loops) const {
// loop_id must be a loop domain.
const auto& loop_group =
id_model_.idGraph(IdMappingMode::LOOP).toGroup(loop_id);
Expand All @@ -792,6 +794,13 @@ Val* TensorIndexer::getLoopIndex(IterDomain* loop_id) const {
loop_id->toString());

Val* loop_index = loop_index_map_it->second;

// War for circular buffering
if (auto circular_buffer_loop_index =
getLoopIndexOfCircularBufferLoop(loop_id, for_loops, id_model_)) {
loop_index = circular_buffer_loop_index;
}

return loop_index;
}

Expand All @@ -803,16 +812,16 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
// For a given list of the loop domains, assign its corresponding
// index Val.
for (IterDomain* loop_id : loop_domains) {
Val* loop_index = getLoopIndex(loop_id);
Val* initial_index = getLoopIndex(loop_id, for_loops);
const auto& almost_exact_group = traversalGraph().toGroup(loop_id);

if (initial_index_map.find(almost_exact_group) != initial_index_map.end()) {
// Initial index already set. This can happen as this is an
// almost exact group. It should be just size-1 domain.
NVF_ERROR(
loop_index->isZeroInt(),
initial_index->isZeroInt(),
"Unexpected initial index: ",
loop_index->toInlineString());
initial_index->toInlineString());
auto existing_index = initial_index_map.at(almost_exact_group);
NVF_ERROR(
existing_index->isZeroInt(),
Expand All @@ -821,13 +830,7 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
continue;
}

// War for circular buffering
if (auto circular_buffer_loop_index =
getLoopIndexOfCircularBufferLoop(loop_id, for_loops, id_model_)) {
loop_index = circular_buffer_loop_index;
}

initial_index_map.emplace(almost_exact_group, loop_index);
initial_index_map.emplace(almost_exact_group, initial_index);
}

return initial_index_map;
Expand All @@ -836,12 +839,14 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
std::vector<Val*> TensorIndexer::getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
auto info = computeIndex(expr, index_groups, for_loops);
auto info = computeIndex(expr, index_ids, for_loops);
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, info.loop_domains, for_loops, info.index_map);

const auto index_groups = traversalGraph().toGroups(index_ids);

std::vector<Val*> result;
result.reserve(index_groups.size());
for (const auto& g : index_groups) {
Expand Down Expand Up @@ -916,13 +921,13 @@ std::vector<IterDomain*> TensorIndexer::getLoopDomains(const Expr* expr) const {

IndexingInfo TensorIndexer::computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
const auto loop_domains = getLoopDomains(expr);
const auto loop_domains = getLoopIds(expr, id_model_);

const ValGroups loop_groups = traversalGraph().toGroups(loop_domains);
const ExprPath<ExprGroup> traversal_path = IndexingTraversal::getExprsBetween(
expr, traversalGraph(), loop_groups, index_groups);
expr, traversalGraph(), loop_domains, index_ids);

const std::unordered_map<ValGroup, Val*> initial_index_map =
getInitialIndexMap(loop_domains, for_loops);
Expand Down Expand Up @@ -978,11 +983,7 @@ std::unordered_map<Val*, Val*> TensorIndexer::getIndexReplacementMap(
std::unordered_map<Val*, Val*> replacement_map;

for (const auto loop_id : loop_domains) {
const ValGroup& loop_group = traversalGraph().toGroup(loop_id);
auto index_it = index_map.find(loop_group);
NVF_ERROR(index_it != index_map.end());
Val* cur_index = index_it->second;
NVF_ERROR(cur_index != nullptr);
Val* cur_index = getLoopIndex(loop_id, for_loops);

Val* replacement_index = nullptr;
// Replace the index of a vectorized/bulk domain with zero. Note that
Expand Down Expand Up @@ -1049,8 +1050,8 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
const std::vector<IterDomain*>& predicate_domains =
getPredicateDomains(tv, expr);

const IndexingInfo& index_info = computeIndex(
expr, traversalGraph().toGroups(predicate_domains), for_loops);
const IndexingInfo& index_info =
computeIndex(expr, predicate_domains, for_loops);

const auto& index_map = index_info.index_map;

Expand Down Expand Up @@ -1282,8 +1283,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
bool as_consumer,
const IndexingAllocationInfo& alloc_info,
const std::vector<ForLoop*>& for_loops) const {
const auto& index_groups = traversalGraph().toGroups(alloc_info.domains);
auto index_info = computeIndex(expr, index_groups, for_loops);
auto index_info = computeIndex(expr, alloc_info.domains, for_loops);
const auto& index_map = index_info.index_map;
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, index_info.loop_domains, for_loops, index_map);
Expand Down
9 changes: 5 additions & 4 deletions csrc/id_model/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ class TensorIndexer {
const Expr* expr,
const std::vector<ForLoop*>& loops) const;

// Get the index of a loop domain. Intended to be used only for testing.
Val* getLoopIndex(IterDomain* loop_id) const;
// Get the index of a loop domain.
Val* getLoopIndex(IterDomain* loop_id, const std::vector<ForLoop*>& for_loops)
const;

// Get the index of the given ID groups
std::vector<Val*> getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& loops) const;

// Get the contig indices of the given ID groups with their strides
Expand Down Expand Up @@ -137,7 +138,7 @@ class TensorIndexer {
// getIndexFor.
IndexingInfo computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const;

// Propagate the loop indices of a given list of loop domains to the
Expand Down
129 changes: 127 additions & 2 deletions csrc/id_model/indexing_traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <id_model/id_model.h>
#include <id_model/indexing_traversal.h>
#include <ir/utils.h>

Expand All @@ -14,8 +15,9 @@ IndexingTraversal::IndexingTraversal(
const Expr* expr,
const ValGraph& graph,
std::vector<NodeType> from_groups,
std::vector<NodeType> to_groups)
: ValGraphBFS(graph, from_groups, to_groups) {
std::vector<NodeType> to_groups,
bool require_all_to_visited)
: ValGraphBFS(graph, from_groups, to_groups, require_all_to_visited) {
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);
// Remember the resize exprs appearing in the consumer
Expand Down Expand Up @@ -44,4 +46,127 @@ IndexingTraversal::IndexingTraversal(
}
}

// Workaround (WAR) for indexing through resize: tries to find a path
// from from_ids to to_ids in an almost-exact graph built only from
// the definition of the consumer tensor of expr, rather than in the
// given (global) graph.
//
// Returns std::nullopt when the workaround does not apply: either no
// ID of the local model has a Resize as its first use, or the path
// found in the local graph contains no Resize. The caller is then
// expected to use the default traversal on the global graph.
//
// Errors out if expr has no tensor output, or if an ID of from_ids or
// to_ids cannot be mapped to any group of the local graph.
std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
    getExprsBetweenForResize(
        const Expr* expr,
        const ValGraph& graph,
        const std::vector<IterDomain*>& from_ids,
        const std::vector<IterDomain*>& to_ids) {
  auto consumer_tv = ir_utils::getTvOutput(expr);
  NVF_ERROR(consumer_tv != nullptr);

  IdModel local_model(
      std::vector<Expr*>{consumer_tv->definition()},
      /*additional_tvs=*/{},
      /*build_graphs=*/false);

  // If there's no resize in the producer and consumer tensors of this
  // expr, it should not need this WAR.
  if (std::none_of(
          local_model.idUses().begin(),
          local_model.idUses().end(),
          [](const auto& kv) {
            const VectorOfUniqueEntries<Expr*>& exprs = kv.second;
            return !exprs.empty() && exprs.at(0)->isA<Resize>();
          })) {
    return std::nullopt;
  }

  const auto& local_graph = local_model.buildAlmostExactGraph();

  // Translates IDs to groups of the local graph. The given IDs are
  // not necessarily IDs of the producer or the consumer of this expr:
  // loop IDs are representative domains of loop groups, and an
  // allocation ID may be a loop promotion ID. For such an ID, pick
  // the first local group whose front ID is a member of the ID's
  // group in the global graph.
  const auto to_local_groups =
      [&](const std::vector<IterDomain*>& ids) -> ValGroups {
    ValGroups local_groups;
    for (IterDomain* id : ids) {
      if (local_graph.hasGroup(id)) {
        local_groups.pushBack(local_graph.toGroup(id));
        continue;
      }
      bool found = false;
      const auto& global_group = graph.toGroup(id);
      for (const auto& vg : local_graph.disjointValSets().disjointSets()) {
        if (global_group->has(vg->front())) {
          local_groups.pushBack(vg);
          found = true;
          break;
        }
      }
      // If not found, it should mean the ID is promoted to some IDs
      // of further consumer tensors. This WAR does not work then. We
      // could simply fall back to the default ValGraph-based path,
      // but that might hit the resize indexing issue (#3455). For
      // now, this is considered an error.
      NVF_ERROR(found, "Indexing path for resize not found: ", id->toString());
    }
    return local_groups;
  };

  const ValGroups from_groups = to_local_groups(from_ids);
  const ValGroups to_groups = to_local_groups(to_ids);

  IndexingTraversal traversal(
      expr,
      local_graph,
      {from_groups.vector().begin(), from_groups.vector().end()},
      {to_groups.vector().begin(), to_groups.vector().end()},
      /*require_all_to_visited=*/true);
  traversal.traverse();
  // Only the path itself is needed here; the second element of the
  // returned pair is not inspected.
  auto path = traversal.getShortestExprPath().first;

  // Use the local path only when it actually goes through a resize.
  for (const auto& [g, d] : path) {
    if (g->front()->isA<Resize>()) {
      return path;
    }
  }

  // If resize doesn't appear, the default path should work fine.
  return std::nullopt;
}

// Returns the expression path from from_domains to to_domains used
// for indexing expr. The resize-specific workaround is tried first;
// when it yields no path, a traversal over the given graph is used.
IndexingTraversal::ExprPath IndexingTraversal::getExprsBetween(
    const Expr* expr,
    const ValGraph& graph,
    const std::vector<IterDomain*>& from_domains,
    const std::vector<IterDomain*>& to_domains) {
  // A path found by the war for resize indexing takes precedence.
  const std::optional<ExprPath> resize_path =
      getExprsBetweenForResize(expr, graph, from_domains, to_domains);
  if (resize_path.has_value()) {
    return *resize_path;
  }

  // Default: traverse the given graph between the groups of the
  // given domains.
  const auto source_groups = graph.toGroups(from_domains);
  const auto target_groups = graph.toGroups(to_domains);

  std::vector<NodeType> source_nodes(
      source_groups.vector().begin(), source_groups.vector().end());
  std::vector<NodeType> target_nodes(
      target_groups.vector().begin(), target_groups.vector().end());

  IndexingTraversal traversal(expr, graph, source_nodes, target_nodes);
  traversal.traverse();
  return traversal.getShortestExprPath().first;
}

} // namespace nvfuser
Loading

0 comments on commit 39f2809

Please sign in to comment.