(Yet another) indexing war for resize #3454

Merged: 15 commits, Nov 28, 2024
5 changes: 4 additions & 1 deletion csrc/id_model/id_model_index_compute.h
@@ -62,7 +62,10 @@ class IdGraphIndexCompute : public OptOutDispatch {
}

void setIndex(IterDomain* id, Val* idx) {
index_map_.emplace(toGroup(id), idx);
// May overwrite index. When the graph is cyclic due to, e.g.,
// resize, the index obtained by traversing most recently through the
Collaborator: "by traversing most recently"? So the index map will reflect whichever Expr was most recently processed?
Collaborator (author): Yes. It's a good question as this will have some related change in a follow-up PR 😃

// indexing path should be used (see also PR #3454)
index_map_[toGroup(id)] = idx;
}

const ValGroup& toGroup(IterDomain* id) const {
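
For reference, the behavioral difference this hunk depends on can be shown with a minimal, self-contained C++ sketch (not part of this diff; a plain std::unordered_map with string keys stands in for the ValGroup-keyed index_map_): emplace is a no-op when the key already exists, while operator[] assigns unconditionally, so the most recently computed index wins.

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> index_map;

  // emplace inserts only if the key is absent; a second emplace with
  // the same key silently keeps the old value.
  index_map.emplace("group0", 1);
  index_map.emplace("group0", 2);
  assert(index_map.at("group0") == 1);

  // operator[] assigns unconditionally, so a later traversal step can
  // overwrite an index computed earlier.
  index_map["group0"] = 2;
  assert(index_map.at("group0") == 2);

  return 0;
}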
19 changes: 10 additions & 9 deletions csrc/id_model/indexing.cpp
@@ -836,12 +836,14 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
std::vector<Val*> TensorIndexer::getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
auto info = computeIndex(expr, index_groups, for_loops);
auto info = computeIndex(expr, index_ids, for_loops);
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, info.loop_domains, for_loops, info.index_map);

const auto index_groups = traversalGraph().toGroups(index_ids);

std::vector<Val*> result;
result.reserve(index_groups.size());
for (const auto& g : index_groups) {
@@ -916,13 +918,13 @@ std::vector<IterDomain*> TensorIndexer::getLoopDomains(const Expr* expr) const {

IndexingInfo TensorIndexer::computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
const auto loop_domains = getLoopDomains(expr);
const auto loop_domains = getLoopIds(expr, id_model_);

const ValGroups loop_groups = traversalGraph().toGroups(loop_domains);
const ExprPath<ExprGroup> traversal_path = IndexingTraversal::getExprsBetween(
expr, traversalGraph(), loop_groups, index_groups);
expr, traversalGraph(), loop_domains, index_ids);

const std::unordered_map<ValGroup, Val*> initial_index_map =
getInitialIndexMap(loop_domains, for_loops);
@@ -1049,8 +1051,8 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
const std::vector<IterDomain*>& predicate_domains =
getPredicateDomains(tv, expr);

const IndexingInfo& index_info = computeIndex(
expr, traversalGraph().toGroups(predicate_domains), for_loops);
const IndexingInfo& index_info =
computeIndex(expr, predicate_domains, for_loops);

const auto& index_map = index_info.index_map;

@@ -1282,8 +1284,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
bool as_consumer,
const IndexingAllocationInfo& alloc_info,
const std::vector<ForLoop*>& for_loops) const {
const auto& index_groups = traversalGraph().toGroups(alloc_info.domains);
auto index_info = computeIndex(expr, index_groups, for_loops);
auto index_info = computeIndex(expr, alloc_info.domains, for_loops);
const auto& index_map = index_info.index_map;
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, index_info.loop_domains, for_loops, index_map);
4 changes: 2 additions & 2 deletions csrc/id_model/indexing.h
@@ -76,7 +76,7 @@ class TensorIndexer {
std::vector<Val*> getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& loops) const;

// Get the contig indices of the given ID groups with their strides
@@ -137,7 +137,7 @@
// getIndexFor.
IndexingInfo computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const;

// Propagate the loop indices of a given list of loop domains to the
129 changes: 127 additions & 2 deletions csrc/id_model/indexing_traversal.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <id_model/id_model.h>
#include <id_model/indexing_traversal.h>
#include <ir/utils.h>

@@ -14,8 +15,9 @@ IndexingTraversal::IndexingTraversal(
const Expr* expr,
const ValGraph& graph,
std::vector<NodeType> from_groups,
std::vector<NodeType> to_groups)
: ValGraphBFS(graph, from_groups, to_groups) {
std::vector<NodeType> to_groups,
bool require_all_to_visited)
: ValGraphBFS(graph, from_groups, to_groups, require_all_to_visited) {
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);
// Remember the resize exprs appearing in the consumer
@@ -44,4 +46,127 @@
}
}

std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
getExprsBetweenForResize(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_ids,
const std::vector<IterDomain*>& to_ids) {
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);

IdModel local_model(
std::vector<Expr*>{consumer_tv->definition()},
/*additional_tvs=*/{},
/*build_graphs=*/false);

// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
if (std::none_of(
local_model.idUses().begin(),
local_model.idUses().end(),
[](const auto& kv) {
const VectorOfUniqueEntries<Expr*>& exprs = kv.second;
return !exprs.empty() && exprs.at(0)->isA<Resize>();
})) {
return std::nullopt;
}

const auto& local_graph = local_model.buildAlmostExactGraph();

// from_ids are loop domains, which are representative
// domains of loop groups and not necessarily domains of the
// producer or the consumer of this expr. In that case, find an ID
// out of the global group that is mapped in the local graph.
ValGroups from_groups;
for (const auto i : c10::irange(from_ids.size())) {
auto from_id = from_ids.at(i);
if (local_graph.hasGroup(from_id)) {
from_groups.pushBack(local_graph.toGroup(from_id));
continue;
}
bool found = false;
const auto& global_group = graph.toGroup(from_id);
for (const auto& vg : local_graph.disjointValSets().disjointSets()) {
if (global_group->has(vg->front())) {
from_groups.pushBack(vg);
found = true;
break;
}
}
// If not found, it should mean it's promoted to some IDs of
// further consumer tensors. This WAR does not work then. We could
// simply fall back to the default ValGraph-based path, but that
// might hit the resize indexing issue (#3455). For now, this is
// considered an error.
NVF_ERROR(
found, "Indexing path for resize not found: ", from_id->toString());
}

// Similarly, to_ids may not be IDs found in any of the producer and
// consumer tensors of this expr. For example, if it's an allocation
// ID, it may be a loop promotion ID.
ValGroups to_groups;
for (auto to_id : to_ids) {
if (local_graph.hasGroup(to_id)) {
to_groups.pushBack(local_graph.toGroup(to_id));
continue;
}
// to_id is not found in the producer or consumer tensors of the
// expr. Look for a mapped ID in the ID group of the global graph.
bool found = false;
const auto& global_group = graph.toGroup(to_id);
for (const auto& vg : local_graph.disjointValSets().disjointSets()) {
if (global_group->has(vg->front())) {
to_groups.pushBack(vg);
found = true;
break;
}
}
NVF_ERROR(found, "Indexing path for resize not found: ", to_id->toString());
}

IndexingTraversal traversal(
expr,
local_graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()},
/*require_all_to_visited=*/true);
traversal.traverse();
auto [path, all_visited] = traversal.getShortestExprPath();

for (const auto& [g, d] : path) {
if (g->front()->isA<Resize>()) {
return path;
}
}

// If resize doesn't appear, the default path should work fine.
return std::nullopt;
}
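
The two translation loops above share one pattern: when an ID has no group in the local almost-exact graph, scan the local disjoint sets for one whose representative element is contained in the ID's global group. A standalone sketch of that lookup (illustration only; plain STL containers stand in for ValGroup and the disjoint sets):

#include <optional>
#include <set>
#include <vector>

// Stand-in for a ValGroup: a set of ID handles (ints here).
using Group = std::set<int>;

// Given an ID's group in the global graph, find the local disjoint
// set whose first element the global group contains. Returns nullopt
// when nothing maps, which the code above treats as an error.
std::optional<Group> translateToLocal(
    const Group& global_group,
    const std::vector<Group>& local_disjoint_sets) {
  for (const Group& local : local_disjoint_sets) {
    // *local.begin() plays the role of vg->front().
    if (!local.empty() && global_group.count(*local.begin()) > 0) {
      return local;
    }
  }
  return std::nullopt;
}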

IndexingTraversal::ExprPath IndexingTraversal::getExprsBetween(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains) {
// Take the path if found by the WAR for resize indexing
if (auto path =
getExprsBetweenForResize(expr, graph, from_domains, to_domains);
path.has_value()) {
return *path;
}

auto from_groups = graph.toGroups(from_domains);
auto to_groups = graph.toGroups(to_domains);

IndexingTraversal traversal(
expr,
graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()});
traversal.traverse();
return traversal.getShortestExprPath().first;
}

} // namespace nvfuser
21 changes: 10 additions & 11 deletions csrc/id_model/indexing_traversal.h
@@ -25,23 +25,22 @@ class IndexingTraversal : public ValGraphBFS {
const Expr* expr,
const ValGraph& graph,
std::vector<NodeType> from_groups,
std::vector<NodeType> to_groups);
std::vector<NodeType> to_groups,
bool require_all_to_visited = true);

~IndexingTraversal() override = default;

static ExprPath getExprsBetween(
const Expr* expr,
const ValGraph& graph,
const ValGroups& from_groups,
const ValGroups& to_groups) {
IndexingTraversal traversal(
expr,
graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()});
traversal.traverse();
return traversal.getShortestExprPath().first;
}
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains);

static std::optional<ExprPath> getExprsBetweenForResize(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains);

using ValGraphBFS::isVisited;

22 changes: 22 additions & 0 deletions csrc/id_model/utils.h
@@ -10,6 +10,7 @@
#include <expr_simplifier.h>
#include <id_model/id_model.h>
#include <id_model/to_string.h>
#include <ir/utils.h>
#include <options.h>
#include <utils.h>

@@ -106,6 +107,27 @@ inline IterDomain* getLoopPromotion(
return loop_promotion_map_it->second;
}

// Get the loop domains of a given expr. Currently, they're always
// the loop domains of a consumer tensor, but in the future this
// function may return the loop domains of a producer for
// producer-based indexing.
inline std::vector<IterDomain*> getLoopIds(
const Expr* expr,
const IdModel& id_model) {
// Assume consumer-based indexing. Needs to be revisited for ops
// like scatter.
NVF_ERROR(!expr->outputs().empty());
auto output_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(output_tv != nullptr);
auto loop_ids = output_tv->getLoopDomain();

for (auto& loop_id : loop_ids) {
loop_id = getLoopPromotion(loop_id, id_model);
}

return loop_ids;
}

inline ParallelType getParallelType(const ValGroup& loop_group) {
ParallelType common_pt = ParallelType::Serial;
for (const auto val : *loop_group) {
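
As a rough illustration of getLoopIds, here is a standalone sketch of its promote-in-place step, with strings standing in for IterDomain pointers and the loop-promotion map reduced to a plain map (the fall-back-to-self behavior for IDs missing from the map is an assumption of this sketch, not a claim about getLoopPromotion):

#include <string>
#include <unordered_map>
#include <vector>

// Replace each loop ID with its promotion, mirroring the final loop
// of getLoopIds.
std::vector<std::string> promoteLoopIds(
    std::vector<std::string> loop_ids,
    const std::unordered_map<std::string, std::string>& promotion_map) {
  for (std::string& id : loop_ids) {
    // Sketch-only assumption: an ID absent from the map is treated as
    // already being its own promotion and is left unchanged.
    auto it = promotion_map.find(id);
    if (it != promotion_map.end()) {
      id = it->second;
    }
  }
  return loop_ids;
}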
28 changes: 21 additions & 7 deletions csrc/index_compute.cpp
@@ -1619,12 +1619,10 @@ std::vector<Val*> Index::getConsumerPerDimLogicalIndex(
if (!ir_utils::hasRootToLoopLinearTransformations(consumer_tv) ||
GpuLower::current()->idModelOptions().consumerIndex()) {
const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
ValGroups logical_indices =
indexer.traversalGraph().toGroups(consumer_tv->getLogicalDomain());
return indexer.getIndexFor(
consumer_tv->definition(),
/*as_consumer=*/true,
logical_indices,
consumer_tv->getLogicalDomain(),
loops);
} else {
auto guard = ir_utils::allocateToLogicalDomainGuard(consumer_tv, false);
@@ -1644,12 +1642,10 @@ std::vector<Val*> Index::getProducerPerDimLogicalIndex(
if (!ir_utils::hasRootToLoopLinearTransformations(producer_tv) ||
GpuLower::current()->idModelOptions().producerIndex()) {
const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
ValGroups logical_indices =
indexer.traversalGraph().toGroups(producer_tv->getLogicalDomain());
return indexer.getIndexFor(
consumer_tv->definition(),
/*as_consumer=*/false,
logical_indices,
producer_tv->getLogicalDomain(),
loops);
} else {
auto guard = ir_utils::allocateToLogicalDomainGuard(producer_tv, false);
@@ -2655,9 +2651,27 @@ std::pair<Val*, Val*> Index::getCpAsyncBulkGmemIndex(

ValGroups groups_to_index = tma_info.getTMADomain();

// TensorIndexer needs IterDomain instead of ValGroup to work around
// the resize indexing issue
std::vector<IterDomain*> ids_to_index;
ids_to_index.reserve(groups_to_index.size());
const auto tma_all_ids = is_load ? consumer_tv->domain()->allIDs()
: producer_tv->domain()->allIDs();
for (const auto& group : groups_to_index) {
auto it = std::find_if(
tma_all_ids.begin(), tma_all_ids.end(), [&](IterDomain* gmem_id) {
return group->has(gmem_id);
});
if (it != tma_all_ids.end()) {
ids_to_index.push_back(*it);
} else {
ids_to_index.push_back(group->front()->as<IterDomain>());
}
}

const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
auto indices_inner_to_outer =
indexer.getIndexFor(ldst, !is_load, groups_to_index, loops);
indexer.getIndexFor(ldst, !is_load, ids_to_index, loops);

int64_t dim = (int64_t)tma_info.dims().size();
auto coordinate = IrBuilder::arrayExpr(indices_inner_to_outer);
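
The new loop above amounts to picking, for each ValGroup, a representative IterDomain that preferably belongs to the indexed tensor's own IDs. A standalone sketch of that selection (illustration only; strings stand in for IterDomain pointers, and groups are assumed non-empty, as ValGroups are):

#include <algorithm>
#include <string>
#include <vector>

// For each group, prefer a member that also appears in tensor_ids
// (the analogue of the consumer/producer allIDs() list); otherwise
// fall back to the group's first element, like group->front().
std::vector<std::string> pickIdsToIndex(
    const std::vector<std::vector<std::string>>& groups_to_index,
    const std::vector<std::string>& tensor_ids) {
  std::vector<std::string> ids_to_index;
  ids_to_index.reserve(groups_to_index.size());
  for (const auto& group : groups_to_index) {
    auto it = std::find_if(
        group.begin(), group.end(), [&](const std::string& id) {
          return std::find(tensor_ids.begin(), tensor_ids.end(), id) !=
              tensor_ids.end();
        });
    ids_to_index.push_back(it != group.end() ? *it : group.front());
  }
  return ids_to_index;
}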