Translate the repetition pattern to expand and reshape #3645

Merged: 7 commits from translate_repeat_pattern into main on Dec 31, 2024

Conversation

@naoyam (Collaborator) commented Dec 25, 2024

In RoPE, this repeat pattern shows up commonly:

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136

torch.cat((freqs, freqs), dim=-1)

In nvFuser, this pattern appears as a PadOp followed by a CatOp.

This preseg pass translates this pattern to expand and reshape ops. For example, given a pattern like:

t0 = [i0];
t1 = cat({t0, t0}, -1);

It will be translated to:

t0 = [i0]
t2 = broadcast(t0, {true, false});
t3 = expand(t2, {2, i0});
t4 = reshape(t3, {2 * i0});

And all uses of t1 will be replaced by t4.
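
For intuition, here is a minimal PyTorch sketch (not the preseg pass itself) showing that the translated broadcast-expand-reshape sequence produces the same values as the original concatenation; the concrete size 16 stands in for i0 and is illustrative only:

import torch

i0 = 16
t0 = torch.randn(i0)

# Original pattern: t1 = cat({t0, t0}, -1)
t1 = torch.cat((t0, t0), dim=-1)  # shape [2 * i0]

# Translated pattern:
t2 = t0.unsqueeze(0)     # broadcast(t0, {true, false}) -> [1, i0]
t3 = t2.expand(2, i0)    # expand(t2, {2, i0})          -> [2, i0]
t4 = t3.reshape(2 * i0)  # reshape(t3, {2 * i0})        -> [2 * i0]

assert torch.equal(t1, t4)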

While this pattern can be handled by the resize scheduler, that scheduler is currently limited to segments with only pointwise ops, and its scheduling heuristics are not tuned yet. With this translation, I experimentally observed a significant perf gain in the Mistral backward function.

@naoyam (Collaborator, Author) commented Dec 25, 2024

!test --diff

@naoyam added the rope label Dec 25, 2024

@naoyam (Collaborator, Author) commented Dec 25, 2024

!test --diff

@naoyam (Collaborator, Author) commented Dec 25, 2024

!test --diff

@naoyam marked this pull request as ready for review December 25, 2024 04:20
@naoyam requested a review from jjsjann123 December 25, 2024 04:20
@naoyam changed the title from "Translate the repetition pattern with expand and reshape" to "Translate the repetition pattern to expand and reshape" Dec 25, 2024
auto tv2 = cat({tv0, tv0}, 1);

fusion.addOutput(tv1);
fusion.addOutput(tv2);
@jjsjann123 (Collaborator) commented on the diff above:

Out of curiosity, when we say there's nothing to allow the output IDs to be mapped, is it saying that the IDs of tv1 and tv2 are disconnected?

Since the iter domains of tv1 and tv2 are produced from the same IDs with constants, wouldn't IdModel still be able to identify them as exactly mapped?

@naoyam (Collaborator, Author) replied:

First of all, the comment at line 911 is wrong. Both cat ops use the same ID.

Yes, tv1 and tv2 are disconnected. None of the properties we use to build IdModel would detect that the two IDs of the tensors have the same extent. That doesn't mean we can't extend IdModel to detect it; it's just not how it's implemented at the moment. Unless there's a motivating case, I don't think we need to worry too much about it.

void inspect() {
  const auto exprs = fusion_->exprs();

  for (auto pad : ir_utils::filterByType<PadOp>(exprs)) {
@jjsjann123 (Collaborator) commented on the diff above:

QQ: what's the reason we chose to filter on PadOp here, instead of looking for CatOp and iterating through all its inputs?

It's just a nitpicking question; I think when we iterate through each PadOp, we could repeatedly remove entries and add things back into repeat_info_map_.

@naoyam (Collaborator, Author) replied:

I was initially also thinking about supporting concatenation without CatOp. I think I have seen a sequence of ops like a PadOp followed by an addition, maybe in some backward kernels.

I dropped that since I didn't find that pattern used for repetitions.
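
For illustration only, here is a minimal PyTorch sketch of the pad-plus-addition form of concatenation mentioned above (made-up shapes; this is not what the preseg pass matches):

import torch
import torch.nn.functional as F

a = torch.randn(4, 3)
b = torch.randn(4, 5)

# Concatenation along the last dimension ...
cat_out = torch.cat((a, b), dim=-1)  # shape [4, 8]

# ... expressed as two zero-pads followed by an addition:
# pad a on the right by b's width, pad b on the left by a's width.
pad_add = F.pad(a, (0, b.size(-1))) + F.pad(b, (a.size(-1), 0))

assert torch.equal(cat_out, pad_add)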
Commit 4c1abda

csrc/preseg_passes/translate_repeat_to_expand.cpp: outdated review thread, resolved
@naoyam (Collaborator, Author) commented Dec 30, 2024

!test

@jjsjann123 (Collaborator) left a review comment:

LGTM

@naoyam (Collaborator, Author) commented Dec 31, 2024

!build

@naoyam merged commit 6466834 into main Dec 31, 2024
15 of 16 checks passed
@naoyam deleted the translate_repeat_pattern branch December 31, 2024 05:22
@naoyam added a commit that referenced this pull request on Jan 10, 2025:
Adds `repeat` as an alias op as well as the `RepeatOp` IR node. The
`repeat` op has almost the same semantics as the PyTorch repeat.
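
For reference, a minimal PyTorch sketch of the torch.Tensor.repeat semantics referred to above, with illustrative shapes only:

import torch

x = torch.randn(2, 3)

# repeat tiles the tensor along each dimension by the given factors.
y = x.repeat(2, 3)  # shape [4, 9]: 2 copies along dim 0, 3 along dim 1

# Repeating a size-1 (broadcast-like) dimension:
b = torch.randn(1, 3)
z = b.repeat(4, 1)  # shape [4, 3]

assert y.shape == (4, 9) and z.shape == (4, 3)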

The main motivation is to fix #3682, which is due to #3645, which introduced a preseg pass that detects a repeat pattern and translates it to broadcast, expand, and reshape. The issue in #3682 arises because the translation-based method does not work when a broadcast ID is repeated. I originally just used `TensorDomain::flatten` (https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L3674-L3740), which simply merges broadcast IDs. However, for reshape, it should not merge them but squeeze them. Merging broadcast IDs triggered an assertion in the transpose scheduler, as seen in #3682.

`TensorDomain::flatten` needs to be fixed (#3691), but that's a separate issue. To fix #3682, since repeating broadcast IDs cannot be translated to the broadcast-expand-reshape pattern anyway, I added the new `RepeatOp` node. I initially thought it could be just a `LoadStoreOp`, but decided to add a distinct IR node since, unlike the usual LoadStore case, some of the broadcast IDs of a producer become concrete IDs in the corresponding consumer logical domain. I did actually try using `LoadStoreOp`, but some of the preseg passes complained about the mismatched broadcast pattern.

Repeating non-broadcast IDs is still done with the broadcast-expand-reshape pattern; only repetition of broadcast IDs is represented with the `RepeatOp` node.

Fixes #3682