Introduce repeat and RepeatOp #3687
Conversation
Almost the same semantics as PyTorch repeat. Previously, repeat was only partially supported as a translation from a concat-based repeat pattern, which had a bug when repeating broadcast IDs. This PR fixes the issue by handling broadcast IDs separately using a new IR node, RepeatOp, which represents repetition of broadcast IDs.
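For context, the translation mentioned above relies on the equivalence between repeat and a broadcast-expand-reshape sequence. Below is a minimal standalone sketch written against the PyTorch C++ API (not nvFuser itself); the shapes and repeat factor are illustrative assumptions, not taken from this PR:

```cpp
// Sketch of the equivalence the preseg translation relies on: repeating a
// non-broadcast dimension is the same as unsqueeze (broadcast) + expand +
// reshape. Shapes here are made up for illustration.
#include <torch/torch.h>

int main() {
  auto t = torch::arange(6).reshape({2, 3}); // [[0,1,2],[3,4,5]]

  // PyTorch repeat: tile dim 1 twice -> shape [2, 6]
  auto repeated = t.repeat({1, 2});

  // Equivalent broadcast-expand-reshape pattern
  auto translated = t.unsqueeze(1)         // [2, 1, 3]
                        .expand({2, 2, 3}) // [2, 2, 3]
                        .reshape({2, 6});  // [2, 6]

  TORCH_CHECK(torch::equal(repeated, translated));
  return 0;
}
```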
!test
std::string RepeatOp::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Tensor op can not be printed inline");
Question for my understanding: is it correct that only IterDomain ops can be printed inline according to our convention?
Looks like that's the case, but I just copied this from ExpandOp::toInlineString.
sizes.reserve(out()->getLogicalDomain().size());
const auto c2p =
    PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
Skip reduction?
Since this is the output, it should not have a reduction ID. I added an assertion in the constructor.
@@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
   mapPointwiseLikeOp(op);
 }

 void handle(RepeatOp* op) override {
   mapPointwiseLikeOp(op);
Is it supposed to be mapped as pointwise? The input and the output don't even have the same extent.
!build
!build
Adds `repeat` as an alias op as well as the `RepeatOp` IR node. The `repeat` op has almost the same semantics as the PyTorch repeat.

The main motivation is to fix #3682, which is due to #3645, which introduced a preseg pass that detects a repeat pattern and translates it to broadcast, expand and reshape. The issue in #3682 arises because that translation-based method does not work when a broadcast ID is repeated. I originally just used `TensorDomain::flatten` (https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740), which merges broadcast IDs. However, for reshape, broadcast IDs should not be merged but squeezed. Merging broadcast IDs triggered an assertion of the transpose scheduler, as seen in #3682. `TensorDomain::flatten` needs to be fixed (#3691), but that's a separate issue. For fixing #3682, since repeating broadcast IDs cannot be translated to the broadcast-expand-reshape pattern anyway, I added the new `RepeatOp` node. I initially thought it could just be a `LoadStoreOp`, but decided to have a different IR node since, unlike the usual LoadStore case, some of the broadcast IDs of a producer become concrete IDs in the corresponding consumer logical domain. I did actually try using `LoadStoreOp`, but some of the preseg passes complained about the mismatched broadcast pattern.

Repeating non-broadcast IDs is still done by the broadcast-expand-reshape pattern. Only repetition of broadcast IDs is represented using the `RepeatOp` node.

Fixes #3682
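To illustrate the broadcast case that `RepeatOp` now covers, here is a small hedged sketch (again using the PyTorch C++ API with made-up shapes, not nvFuser code): repeating a size-1 (broadcast) dimension turns it into a concrete dimension in the output, which is why it does not fit the broadcast-expand-reshape translation or a plain LoadStore-style copy.

```cpp
// Sketch: a size-1 dimension, which nvFuser would treat as a broadcast ID,
// becomes a concrete dimension of extent 4 once repeated.
#include <torch/torch.h>

int main() {
  auto t = torch::arange(3, torch::kFloat).reshape({3, 1});

  // Repeat the size-1 dimension four times: output shape is [3, 4].
  auto repeated = t.repeat({1, 4});
  TORCH_CHECK(repeated.size(0) == 3 && repeated.size(1) == 4);

  // Each output column is a copy of the original single column.
  TORCH_CHECK(torch::equal(repeated.select(1, 3), t.squeeze(1)));
  return 0;
}
```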