Introduce repeat and RepeatOp #3687

Merged · 4 commits into main · Jan 10, 2025

Conversation

@naoyam (Collaborator) commented Jan 9, 2025

Adds repeat as an alias op, as well as the RepeatOp IR node. The repeat op has almost the same semantics as PyTorch's repeat.
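For reference, PyTorch's repeat tiles a tensor along each dimension. A minimal sketch of the intended usage, assuming a C++ signature analogous to `torch.Tensor.repeat` (the exact signature and the `makeSymbolicTensor` test helper are assumptions of this sketch, not confirmed by this thread):

```cpp
// Sketch only: assumes an alias-op signature along the lines of
//   TensorView* repeat(TensorView* inp, const std::vector<int64_t>& repeat_times);
// mirroring torch.Tensor.repeat.
TensorView* tv0 = makeSymbolicTensor(2); // logical domain [i0, i1]
fusion.addInput(tv0);
TensorView* tv1 = repeat(tv0, {2, 3});   // [2*i0, 3*i1], like x.repeat(2, 3)
fusion.addOutput(tv1);
```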

The main motivation is to fix #3682, which is due to #3645, which introduced a preseg pass that detects a repeat pattern and translates it to broadcast, expand and reshape. The issue in #3682 arises because the translation-based approach does not work when a broadcast ID is repeated. I originally just used TensorDomain::flatten (https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740), which simply merges broadcast IDs. However, for reshape, broadcast IDs should not be merged but squeezed. Merging broadcast IDs triggered an assertion in the transpose scheduler, as seen in #3682.

TensorDomain::flatten needs to be fixed (#3691), but that's a separate issue. To fix #3682, since repeating broadcast IDs cannot be translated to the broadcast-expand-reshape pattern anyway, I added the new RepeatOp node. I initially thought it could just be a LoadStoreOp but decided to add a distinct IR node since, unlike the usual LoadStore case, some of the broadcast IDs of a producer become concrete IDs in the corresponding consumer logical domain. I did actually try using LoadStoreOp, but some of the preseg passes complained about the mismatched broadcast pattern.

Repeating non-broadcast IDs is still done via the broadcast-expand-reshape pattern (see the sketch below). Only repetition of broadcast IDs is represented with the RepeatOp node.
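To make the translation concrete: repeating a non-broadcast ID i0 twice means broadcasting to [b, i0], expanding the broadcast to [2, i0], then flattening to [2*i0]. A minimal sketch of that pattern (op names follow nvFuser's public C++ alias ops; the exact signatures here are assumptions):

```cpp
// Broadcast-expand-reshape translation of repeat for a non-broadcast ID.
// Sketch only; exact signatures are assumptions.
TensorView* tv0 = makeSymbolicTensor(1);          // [i0]
TensorView* tv1 = broadcast(tv0, {true, false});  // [b, i0]
TensorView* tv2 = expand(
    tv1,
    {IrBuilder::create<Val>(2L), tv0->axis(0)->extent()}); // [2, i0]
TensorView* tv3 = flatten(tv2, 0, 1);             // [2*i0]
```

When i0 itself is a broadcast ID, the final merge step is exactly what goes wrong, hence the dedicated RepeatOp path for that case.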

Fixes #3682

Almost the same semantics as PyTorch repeat.

Previously, repeat was only partially supported as a translation from a repeat
pattern using concat, which had a bug when repeating broadcast IDs. This
PR fixes the issue by handling broadcast separately using a new IR node,
RepeatOp, which represents repetition of broadcast IDs.
@naoyam (Collaborator, Author) commented Jan 9, 2025

!test

@naoyam marked this pull request as ready for review January 9, 2025 20:00
@naoyam requested a review from wujingyue January 9, 2025 20:01
csrc/ops/alias.cpp (review thread resolved; outdated)
}

std::string RepeatOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
Collaborator:

Question for my understanding: is it correct that only IterDomain ops can be printed inline according to our convention?

Collaborator (Author):

Looks like that's the case, but I just copied this from ExpandOp::toInlineString.

csrc/ir/nodes.cpp (review thread resolved; outdated)
sizes.reserve(out()->getLogicalDomain().size());
const auto c2p =
PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
Collaborator:

Skip reduction?

Collaborator (Author):

Since this is the output, it should not have a reduction ID. I added an assertion in the constructor.
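For illustration, a minimal sketch of such a constructor-side check, assuming IterDomain::isReduction and the NVF_ERROR macro (the actual assertion in the merged code may differ):

```cpp
// Sketch: the output's logical domain should contain no reduction IDs.
for (IterDomain* id : out()->getLogicalDomain()) {
  NVF_ERROR(
      !id->isReduction(),
      "RepeatOp output must not have reduction IterDomains");
}
```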

csrc/ir/nodes.cpp (three review threads resolved; outdated)
@@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
mapPointwiseLikeOp(op);
}

void handle(RepeatOp* op) override {
mapPointwiseLikeOp(op);
@wujingyue (Collaborator) commented Jan 9, 2025:

Is it supposed to be mapped as pointwise? The input and the output don't even have the same extent.

tests/cpp/test_gpu3.cpp (review thread resolved; outdated)
csrc/ops/alias.cpp (review thread resolved)
csrc/preseg_passes/translate_repeat_to_expand.cpp (review thread resolved; outdated)
tests/cpp/test_preseg_passes.cpp (review thread resolved)
tests/cpp/test_preseg_passes.cpp (review thread resolved; outdated)
@naoyam (Collaborator, Author) commented Jan 10, 2025

!build

@naoyam (Collaborator, Author) commented Jan 10, 2025

!build

@naoyam merged commit 78db5e1 into main on Jan 10, 2025
15 of 16 checks passed
@naoyam deleted the repeat branch January 10, 2025 01:26