Translate the repetition pattern to expand and reshape #3645
Conversation
!test --diff
!test --diff
!test --diff
auto tv2 = cat({tv0, tv0}, 1);

fusion.addOutput(tv1);
fusion.addOutput(tv2);
Out of curiosity, when we say there's nothing to allow the output IDs to be mapped, is it saying that the IDs of `tv1` and `tv2` are disconnected? Since the iter domains of `tv1` and `tv2` are produced by the same IDs with constants, wouldn't IdModel still be able to identify them as an exact mapping?
First of all, the comment at line 911 is wrong. Both cat ops use the same ID.
Yes, `tv1` and `tv2` are disconnected. No property that we use to build IdModel would detect that the two IDs of the tensors have the same extent. That doesn't mean we can't extend IdModel to detect it; it's just not how it's implemented at this moment. Unless there's a motivating case, I don't think we need to worry too much about it.
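For reference, here is a minimal sketch of the kind of fusion being discussed. It is a reconstruction rather than the actual test: a square input and the usual test helpers (`makeConcreteTensor`, `FusionGuard`) are assumed so that the two concatenated extents happen to be equal.

```cpp
// Sketch only: two cat ops consuming the same input along different dims.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({4, 4}); // assumed shape
fusion.addInput(tv0);

auto tv1 = cat({tv0, tv0}, 0); // [8, 4]
auto tv2 = cat({tv0, tv0}, 1); // [4, 8]

fusion.addOutput(tv1);
fusion.addOutput(tv2);

// tv1's dim-0 ID and tv2's dim-1 ID both end up with extent 8, but equal
// extents are not a property IdModel uses when building its graphs, so the
// two IDs land in different exact groups.
```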
void inspect() {
  const auto exprs = fusion_->exprs();

  for (auto pad : ir_utils::filterByType<PadOp>(exprs)) {
QQ: what's the reason we choose to filter on `PadOp` here, instead of looking for `CatOp` and iterating through all its inputs? It's just a nitpicking question; I think when we iterate through each `PadOp`, we could end up repeatedly removing entries from and adding them back into `repeat_info_map_`.
I think I was initially also thinking about supporting concatenation without `CatOp`. I think I have seen a sequence of ops like a `PadOp` followed by an addition, maybe in some backward kernels. I dropped that since I didn't find that pattern for repetitions.
4c1abda
!test
LGTM
!build
Adds `repeat` as an alias op as well as the `RepeatOp` IR node. The `repeat` op has almost the same semantics as the PyTorch repeat. The main motivation is to fix #3682, which is due to #3645, which introduced a preseg pass that detects and translates a repeat pattern to broadcast, expand and reshape.

The issue in #3682 arises because the translation-based method does not work when a broadcast ID is repeated. I originally just used `TensorDomain::flatten` (https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740), which just merges broadcast IDs. However, for reshape, it should not merge but squeeze them. Merging broadcast IDs triggered an assertion of the transpose scheduler as seen in #3682. `TensorDomain::flatten` needs to be fixed (#3691), but that's a separate issue.

To fix #3682, since repeating broadcast IDs cannot be translated to the broadcast-expand-reshape pattern anyway, I added the new `RepeatOp` node. I initially thought it could be just a `LoadStoreOp` but decided to have a different IR node since, unlike the usual LoadStore case, some of the broadcast IDs of a producer become concrete IDs in the corresponding consumer logical domain. I did actually try using `LoadStoreOp`, but some of the preseg passes complained about the mismatched broadcast pattern.

Repeating non-broadcast IDs is still done with the broadcast-expand-reshape pattern. Only repetition of broadcast IDs gets represented using the `RepeatOp` node.

Fixes #3682
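As an illustration (not taken from the PR itself), here is a minimal sketch assuming the C++ ops API, the usual test helpers (`makeSymbolicTensor`, `FusionGuard`), and a `repeat(TensorView*, std::vector<int64_t>)` signature mirroring PyTorch's `Tensor.repeat`; the exact signature is an assumption here.

```cpp
// Sketch only: repeating a broadcast ID with the assumed repeat alias op.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2); // [i0, i1]
fusion.addInput(tv0);

// Insert a broadcast ID: [i0, i1] -> [i0, b, i1]
auto tv1 = broadcast(tv0, {false, true, false});

// Repeat the broadcast ID twice: [i0, b, i1] -> [i0, 2, i1].
// Since the repeated ID is a broadcast ID, this case is represented with the
// RepeatOp node rather than the broadcast-expand-reshape translation.
auto tv2 = repeat(tv1, {1, 2, 1}); // assumed signature
fusion.addOutput(tv2);
```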
In RoPE, this repeat pattern shows up commonly:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136
This pattern shows up as `PadOp` followed by `CatOp` in nvFuser. This preseg pass translates the pattern to expand and reshape ops. For example, given a pattern like:
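(The original code blocks from the PR description are not reproduced here; the following is an illustrative reconstruction using the C++ ops API, with assumed shapes and `t0` standing in for the input `TensorView*`.)

```cpp
// Illustrative reconstruction: t0 is a [4, 8] tensor repeated twice along
// dim 1, which nvFuser represents as PadOp followed by CatOp.
auto t1 = cat({t0, t0}, /*dim=*/1); // [4, 16]
```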
It will be translated to:
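(Again an illustrative reconstruction with assumed shapes; `IrBuilder::create<Val>` is used here to build the expanded sizes.)

```cpp
// 1) Insert a broadcast ID outside the repeated ID: [4, 8] -> [4, 1, 8]
auto t2 = broadcast(t0, {false, true, false});
// 2) Expand the broadcast ID by the repetition factor: [4, 1, 8] -> [4, 2, 8]
auto t3 = expand(t2, {IrBuilder::create<Val>(4L),
                      IrBuilder::create<Val>(2L),
                      IrBuilder::create<Val>(8L)});
// 3) Merge the expanded ID into the repeated ID via reshape: [4, 2, 8] -> [4, 16]
auto t4 = reshape(t3, {4, 2, 8}, {4, 16});
```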
And all uses of `t1` will be replaced by `t4`. While the pattern can be handled by the resize scheduler, it's currently limited to segments with pointwise ops only, and its scheduling heuristics are not tuned yet. In particular, I experimentally observed a significant perf gain with the Mistral backward function.