expand RemoveBcastSqueeze to handle unary operations between broadcast/squeeze ops #3643

Closed
wants to merge 22 commits into from
106 changes: 92 additions & 14 deletions csrc/preseg_passes/remove_bcast_squeeze.cpp
@@ -11,6 +11,7 @@
#include <multidevice/utils.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <ops/utils.h>
#include <options.h>
#include <preseg_passes/remove_bcast_squeeze.h>

@@ -155,6 +156,26 @@ std::vector<bool> nonPreservedDims(const AxisOps& ops) {
return flags;
}

TensorView* replayAxisOp(
AxisOp simple_op_type,
const AxisOps& axis_ops,
TensorView* tv) {
switch (simple_op_type) {
case AxisOp::PRESERVE:
// This is equivalent to a set Op
return tv;
break;
case AxisOp::SQUEEZE:
return squeeze(tv, nonPreservedDims(axis_ops), true);
break;
case AxisOp::BROADCAST:
return broadcast(tv, nonPreservedDims(axis_ops));
break;
}
NVF_ERROR(false, "unrecognized AxisOp type in replayAxisOp");
return nullptr;
}

//! Given descriptors of two sequences of broadcast+squeeze ops, return a
//! descriptor of their composition
AxisOps composeOps(const AxisOps& prev, const AxisOps& next) {
@@ -318,13 +339,79 @@ TensorView* maybeDoReplacement(TensorView* orig) {
if (!isReplaceableExpr(second)) {
return orig;
}
AxisOps second_ops = exprToAxisOps(second);
Collaborator:

I'm having a hard time understanding what this function (maybeDoReplacement) is doing. What is the parameter assumed to be? What is supposed to be returned?

Collaborator (author):

I think maybeDoReplacement is trying to merge tv->first->second->orig into a tv->merged->new_out when both first and second are replaceable exprs.

i.e. when we have tv->broadcast->squeeze, we might be able to just cancel the two and end up returning tv directly.

The function returns new_out after the replay. The logic here is that:
if the returned pointer is different from orig, the caller treats that as a replacement having happened and runs the same loop again with new_out;
if the returned pointer is the same as orig, the merge failed, so it skips second, moves on, and pushes the inputs of second onto the stack as the new candidates for orig. A sketch of that loop follows below.
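For illustration, here is a minimal sketch of such a driver loop, written from the description above. The function name removeBcastSqueezeSketch and the exact control flow are assumptions of this sketch, not the PR's code; it only shows how a stack-based DFS could consume maybeDoReplacement's return value.

// Hedged sketch (name and control flow assumed): a stack-based DFS over the
// fusion outputs that revisits a tensor whenever maybeDoReplacement reports a
// replacement, and otherwise walks further toward the producers.
void removeBcastSqueezeSketch(Fusion* fusion) {
  FusionGuard fg(fusion);
  std::vector<TensorView*> stack;
  for (Val* out : fusion->outputs()) {
    if (auto* tv = dynamic_cast<TensorView*>(out)) {
      stack.push_back(tv);
    }
  }
  while (!stack.empty()) {
    TensorView* orig = stack.back();
    stack.pop_back();
    TensorView* new_out = maybeDoReplacement(orig);
    if (new_out != orig) {
      // A replacement (merge or swap) happened; revisit the new candidate.
      stack.push_back(new_out);
    } else if (Expr* def = orig->definition()) {
      // No replacement; push the inputs of the defining expr as new candidates.
      for (Val* in : def->inputs()) {
        if (auto* in_tv = dynamic_cast<TensorView*>(in)) {
          stack.push_back(in_tv);
        }
      }
    }
  }
}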

Collaborator (author):

So the added logic here is: when we try to swap tv->first->second->orig into tv->replayed_second->replayed_first, we return replayed_second->output(0).

Even though we are not merging two consecutive replaceable operations, returning replayed_second->output(0) instead of orig keeps replayed_second as the candidate for the next iteration, which effectively stops the unary-op first from preventing us from merging neighboring replaceable operations.


Expr* first = second->input(0)->definition();
if (!isReplaceableExpr(first)) {
// when second is an axis op but first is not, we try to swap first and
// second. This allows us to opportunistically place two axis ops next to
// each other.
// e.g.
// T1 = broadcast(T0)
// T2 = relu(T1)
// T3 = squeeze(T2)
// In the iteration where squeeze is `second` and relu is `first`, if we
// swap the two operations, we'll end up with
// T1 = broadcast(T0)
// replayed_T2 = replayed_squeeze(T1)
// replayed_T3 = replayed_relu(replayed_T2)
// The following iteration will have an opportunity to merge the broadcast
// and the replayed_squeeze together.
if (auto uop = dynamic_cast<UnaryOp*>(first)) {
// replace [unary-op -> second] with:
// [second -> unary-op]
// skip if we need to preserve the output from unary-op.
if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) {
return orig;
}

// make sure we preserve the allocation domain on second->output(0)
Collaborator:

Why does the allocation domain matter?

Collaborator (author):

Answered in the example below. I think I can add another comment here as well.

// compute the alloc_domain permutation of second's output.
auto second_out_tv = second->output(0)->as<TensorView>();
Collaborator:

Is this second_out_tv always the same as orig?

Collaborator (author):

Yes. I now realize that I could have just used orig instead.

std::optional<std::vector<int64_t>> second_out_allocation_permutation =
ir_utils::computePermutation(
second_out_tv->getLogicalDomain(),
second_out_tv->getMaybeAllocationDomain());
// We only support a simple permutation; any complex transformation is not
// allowed
if (!second_out_allocation_permutation.has_value()) {
return orig;
}

TensorView* uop_in_tv = uop->in()->as<TensorView>();

// replay second on unary-op input
std::optional<AxisOp> second_op_type_opt =
getSimplifiedOpType(second_ops);
TensorView* replayed_second_out =
replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv);

// replay uop on the replayed second's output
Val* replayed_uop_out = ops::newValLike(
replayed_second_out, uop->out()->getDataType().value());

// restore allocation domain on replayed_uop_out
auto replayed_uop_out_tv = replayed_uop_out->as<TensorView>();
replayed_uop_out_tv->setAllocationDomain(
ir_utils::applyPermutation(
replayed_uop_out_tv->getLogicalDomain(),
second_out_allocation_permutation.value()),
true);

IrBuilder::create<UnaryOp>(
uop->getUnaryOpType(), replayed_uop_out, replayed_second_out);

// replace uses of second output with replayed unary-op out
ir_utils::replaceValInAllExprInputsAndFusionOutputs(
second->output(0), replayed_uop_out);

// return replayed_second_out to indicate replacement.
return replayed_second_out;
}
// return orig to indicate no replacement.
return orig;
}

AxisOps first_ops = exprToAxisOps(first);
AxisOps second_ops = exprToAxisOps(second);

AxisOps simplified_ops = composeOps(first_ops, second_ops);
std::optional<AxisOp> simple_op_type_opt =
getSimplifiedOpType(simplified_ops);
@@ -337,18 +424,8 @@ TensorView* maybeDoReplacement(TensorView* orig) {
replacement = first->output(0)->as<TensorView>();
} else {
TensorView* input_tv = first->input(0)->as<TensorView>();
switch (simple_op_type_opt.value()) {
case AxisOp::PRESERVE:
// This is equivalent to a set Op
replacement = input_tv;
break;
case AxisOp::SQUEEZE:
replacement = squeeze(input_tv, nonPreservedDims(simplified_ops));
break;
case AxisOp::BROADCAST:
replacement = broadcast(input_tv, nonPreservedDims(simplified_ops));
break;
}
replacement =
replayAxisOp(simple_op_type_opt.value(), simplified_ops, input_tv);
}
NVF_ERROR(replacement != orig, "Expected non-trivial replacement");

@@ -406,6 +483,7 @@ TensorView* maybeDoReplacement(TensorView* orig) {

// Remove broadcast-squeeze and squeeze-broadcast patterns
void removeBcastSqueeze(Fusion* fusion) {
FusionGuard fg(fusion);
// Iterate from outputs toward producers using a depth-first search for
// replaceable patterns
std::vector<TensorView*> stack;
151 changes: 151 additions & 0 deletions tests/cpp/test_preseg_passes.cpp
@@ -14,6 +14,7 @@
#include <ops/all_ops.h>
#include <preseg_passes/optimization_pass.h>
#include <preseg_passes/pre_segmenter.h>
#include <preseg_passes/remove_bcast_squeeze.h>
#include <preseg_passes/translate_repeat_to_expand.h>
#include <tests/cpp/utils.h>
#include <tests/cpp/validator.h>
@@ -982,4 +983,154 @@ TEST_F(PresegTest, TranslateRepeatToExpand5) {
EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze0) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, false, false, false});
auto tv2 = relu(tv1);
auto tv3 = squeeze(tv2, {0});
// specify output permutation;
std::vector<IterDomain*> tv3_nhwc = {
tv3->axis(0), tv3->axis(2), tv3->axis(3), tv3->axis(1)};
tv3->setAllocationDomain(tv3_nhwc, true);
fusion.addOutput(tv3);
Collaborator (author):

This is the reason why we care about the allocation domain.

i.e. we have tv1->relu->tv2->squeeze->tv3, where tv3 has an allocation domain that is a permutation.
When we replace it with tv1->replayed_squeeze->tv4->replayed_relu->tv5, we need to ensure that tv5 has the same allocation domain as tv3; otherwise we would change the semantics and return an output with the wrong memory format. (A minimal sketch of this handling follows at the end of this thread.)

Collaborator:

I'm not saying we should ignore the allocation domain. I just don't see why having an allocation domain can interfere with this translation. Why not just keep using tv3? Or, it should also be possible to reproduce the same allocation domain with tv5.

Collaborator (author):

I mistook what you meant in your earlier question!

Why not just keep using tv3?

By "keep using tv3", do you mean that I can have it replayed as tv1->replayed_squeeze->tv4->replayed_relu->tv3? I didn't realize that I could just reuse tv3 here without needing to create a clone of it. Let me try that...

Or, it should also be possible to reproduce the same allocation domain with tv5.

Yes. I was just trying to keep it simple. If we want to support general transformations, I think I can just do the same replay I did in #3644 https://github.com/NVIDIA/Fuser/pull/3644/files#diff-abe2e10add90523ff6b18e1dc50da46762420e1011078ba47ab52140dc213b6fR80-R85.
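To make the allocation-domain point above concrete, here is a minimal sketch under stated assumptions: the helper name copyAllocationPermutation and its parameter names are hypothetical, and it simply mirrors the computePermutation/applyPermutation calls already shown in the diff rather than defining new behavior.

// Hedged sketch (helper name and parameters are hypothetical): carry the
// logical->allocation permutation of the original output (e.g. tv3 with an
// NHWC allocation domain) over to the replayed output (e.g. tv5).
void copyAllocationPermutation(
    TensorView* original_out,
    TensorView* replayed_out) {
  std::optional<std::vector<int64_t>> perm = ir_utils::computePermutation(
      original_out->getLogicalDomain(),
      original_out->getMaybeAllocationDomain());
  if (!perm.has_value()) {
    // Only a plain permutation is supported; bail out on anything more complex.
    return;
  }
  // For an NCHW logical domain with a channels-last allocation domain, perm is
  // {0, 2, 3, 1}; applying it to the replayed logical domain restores NHWC.
  replayed_out->setAllocationDomain(
      ir_utils::applyPermutation(
          replayed_out->getLogicalDomain(), perm.value()),
      /*new_contiguity=*/true);
}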


{
// Make sure squeeze/broadcast no longer exists
Fusion fusion_copy = fusion;
OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
auto new_exprs = fusion_copy.exprs();
EXPECT_EQ(
std::find_if(
new_exprs.begin(),
new_exprs.end(),
[](Expr* new_expr) {
return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
}),
new_exprs.end());
}

auto options = at::TensorOptions().device(at::kCUDA, 0);
auto t0 = at::randn({2, 3, 4, 5}, options);
std::vector<c10::IValue> inputs = {t0};
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
// validate output permutation is preserved
ASSERT_TRUE(outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
Collaborator (author):

Without the allocation domain update, this check would fail, and the optimization pass would be changing the user's intended behavior.

testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze1) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({1, 3, 4, 5});
fusion.addInput(tv0);
auto tv1 = reshape(tv0, {1, 3, 4, 5}, {1, 3, 4 * 5});
// the reshape gives tv1 rfactor products in its IDs.
auto tv2 = relu(tv1);
auto tv3 = broadcast(tv2, {true, false, false, false});
fusion.addOutput(tv3);

{
// broadcast shouldn't be removed
Fusion fusion_copy = fusion;
OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
auto new_exprs = fusion_copy.exprs();
EXPECT_NE(
std::find_if(
new_exprs.begin(),
new_exprs.end(),
[](Expr* new_expr) { return new_expr->isA<BroadcastOp>(); }),
new_exprs.end());
}

auto options = at::TensorOptions().device(at::kCUDA, 0);
auto t0 = at::randn({1, 3, 4, 5}, options);
std::vector<c10::IValue> inputs = {t0};
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze2) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, false, false, false});
auto tv2 = relu(tv1);
// tv2 is also an output, so the remove broadcast squeeze pass will not replay
// the broadcast
fusion.addOutput(tv2);
auto tv3 = squeeze(tv2, {0});
fusion.addOutput(tv3);

{
// Make sure squeeze/broadcast is not removed from fusion.
Fusion fusion_copy = fusion;
OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
auto new_exprs = fusion_copy.exprs();
EXPECT_NE(
std::find_if(
new_exprs.begin(),
new_exprs.end(),
[](Expr* new_expr) {
return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
}),
new_exprs.end());
}

auto options = at::TensorOptions().device(at::kCUDA, 0);
auto t0 = at::randn({2, 3, 4, 5}, options);
std::vector<c10::IValue> inputs = {t0};
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze3) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {true, false, false, false, false});
// tv2 is permuted; we currently do not support swapping permute with axis
// ops.
auto tv2 = permute(tv1, {{0, 4}});
auto tv3 = squeeze(tv2, {4});
fusion.addOutput(tv3);

{
// Make sure squeeze/broadcast is not removed from fusion.
Fusion fusion_copy = fusion;
OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
auto new_exprs = fusion_copy.exprs();
EXPECT_NE(
std::find_if(
new_exprs.begin(),
new_exprs.end(),
[](Expr* new_expr) {
return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
}),
new_exprs.end());
}

auto options = at::TensorOptions().device(at::kCUDA, 0);
auto t0 = at::randn({2, 3, 4, 5}, options);
std::vector<c10::IValue> inputs = {t0};
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

} // namespace nvfuser::preseg_passes