Skip to content

Commit be07fac

Browse files
authored
[PASS] not vectorize if_then_else (#2389)
1 parent a12c556 commit be07fac

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

src/codegen/llvm/codegen_llvm.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
654654
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
655655
return builder_->CreateIsNull(MakeValue(op->args[0]));
656656
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
657+
CHECK_EQ(op->args[0].type().lanes(), 1)
658+
<< "if_then_else can only take scalar condition";
657659
using llvm::BasicBlock;
658660
BasicBlock* then_block = BasicBlock::Create(
659661
*ctx_, "if_then", function_);

src/pass/vectorize_loop.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ class Vectorizer : public IRMutator {
8383
// user mutate from parent.
8484
using IRMutator::Mutate;
8585

86+
Stmt Mutate(Stmt stmt) final {
87+
CHECK(!need_scalarize_);
88+
89+
Stmt ret = IRMutator::Mutate(stmt);
90+
if (need_scalarize_) {
91+
need_scalarize_ = false;
92+
return Scalarize(stmt);
93+
} else {
94+
return ret;
95+
}
96+
}
97+
98+
8699
Expr Mutate_(const Add* op, const Expr &e) final {
87100
return AddSubVec(op, e);
88101
}
@@ -200,10 +213,37 @@ class Vectorizer : public IRMutator {
200213
return e;
201214
}
202215
}
216+
// IfThenElse expr
217+
Expr MutateIfThenElseExpr_(const Call *op, const Expr& e) {
218+
Expr cond = this->Mutate(op->args[0]);
219+
if (cond.type().is_vector()) {
220+
need_scalarize_ = true;
221+
return e;
222+
}
223+
Expr t = this->Mutate(op->args[1]);
224+
Expr f = this->Mutate(op->args[2]);
225+
if (cond.same_as(op->args[0]) &&
226+
t.same_as(op->args[1]) &&
227+
f.same_as(op->args[2])) {
228+
return e;
229+
} else {
230+
int lanes = std::max(t.type().lanes(), f.type().lanes());
231+
t = BroadcastTo(t, lanes);
232+
f = BroadcastTo(f, lanes);
233+
return Call::make(
234+
op->type.with_lanes(lanes), op->name,
235+
{cond, t, f}, op->call_type, op->func, op->value_index);
236+
}
237+
}
203238
// Call
204239
Expr Mutate_(const Call* op, const Expr& e) final {
240+
if (op->name == intrinsic::tvm_if_then_else) {
241+
return MutateIfThenElseExpr_(op, e);
242+
}
205243
int lane = 0;
206244
Array<Expr> new_args = MutateArray(op->args, &lane);
245+
246+
// normal code path.
207247
if (op->args.same_as(new_args)) {
208248
return e;
209249
} else {
@@ -367,6 +407,8 @@ class Vectorizer : public IRMutator {
367407
int var_lanes_;
368408
// ramp representing the var.
369409
Expr ramp_;
410+
// flag to mark requirment of scalarization.
411+
bool need_scalarize_{false};
370412
// The lets
371413
std::unordered_map<const Variable*, Expr> lets_;
372414
// mutate array, with given lane requirement

tests/python/unittest/test_pass_vectorize.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,36 @@ def test_vectorize_with_if():
5353
assert stmt.then_case.value.dtype == "float32x4"
5454
assert isinstance(stmt.else_case, tvm.stmt.For)
5555

56+
def test_vectorize_if_then_else():
57+
n = tvm.var('n')
58+
x = tvm.var('x')
59+
ib = tvm.ir_builder.create()
60+
A = ib.pointer("float32", name="A")
61+
with ib.for_range(0, 4, for_type="vectorize") as i:
62+
A[i] = tvm.call_intrin("float32", "tvm_if_then_else",
63+
i > 0,
64+
A[i] + 1, A[i])
65+
stmt = ib.get()
66+
stmt = tvm.ir_pass.VectorizeLoop(stmt)
67+
assert isinstance(stmt, tvm.stmt.For)
68+
69+
70+
ib = tvm.ir_builder.create()
71+
A = ib.pointer("float32", name="A")
72+
with ib.for_range(0, n) as k:
73+
with ib.for_range(0, 4, for_type="vectorize") as i:
74+
A[k * 4 + i] = tvm.call_intrin("float32", "tvm_if_then_else",
75+
k > 0,
76+
A[k * 4 + i], 0)
77+
stmt = ib.get()
78+
assert isinstance(stmt.body, tvm.stmt.For)
79+
stmt = tvm.ir_pass.VectorizeLoop(stmt)
80+
assert not isinstance(stmt.body, tvm.stmt.For)
81+
assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)
82+
83+
5684
if __name__ == "__main__":
5785
test_vectorize_vector()
5886
test_vectorize_with_if()
5987
test_vectorize_loop()
88+
test_vectorize_if_then_else()

0 commit comments

Comments
 (0)