Skip to content

Commit

Permalink
[ir] Replace FuncCallExpression with FrontendFuncCallStmt
Browse files Browse the repository at this point in the history
ghstack-source-id: 7ccf3062f461ad642aa03b61879060b86e37df18
Pull Request resolved: taichi-dev#7027
  • Loading branch information
lin-hitonami authored and quadpixels committed May 13, 2023
1 parent 56818fd commit 8ef64dc
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 68 deletions.
9 changes: 4 additions & 5 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,12 @@ def func_call_rvalue(self, key, args):
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args,
real_func_arg=True)
func_call = Expr(
_ti_core.make_func_call_expr(
self.taichi_functions[key.instance_id], non_template_args))
impl.get_runtime().prog.current_ast_builder().insert_expr_stmt(
func_call.ptr)
func_call = impl.get_runtime().prog.current_ast_builder(
).insert_func_call(self.taichi_functions[key.instance_id],
non_template_args)
if self.return_type is None:
return None
func_call = Expr(func_call)
if id(self.return_type) in primitive_types.type_ids:
return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, )))
if isinstance(self.return_type, StructType):
Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->axis);
}

void visit(FuncCallExpression *expr) override {
emit(ExprOpCode::FuncCallExpression);
void visit(FrontendFuncCallStmt *expr) override {
emit(StmtOpCode::FrontendFuncCallStmt);
emit(expr->func);
emit(expr->args.exprs);
}
Expand Down
19 changes: 10 additions & 9 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2723,27 +2723,28 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
}
llvm::Value *result_buffer = nullptr;
auto *ret_type = get_real_func_ret_type(stmt->func);
result_buffer = builder->CreateAlloca(ret_type);
auto *result_buffer_u64 = builder->CreatePointerCast(
result_buffer, llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
if (stmt->ret_type) {
auto *ret_type = tlctx->get_data_type(stmt->ret_type);
result_buffer = builder->CreateAlloca(ret_type);
auto *result_buffer_u64 = builder->CreatePointerCast(
result_buffer,
llvm::PointerType::get(tlctx->get_data_type<uint64>(), 0));
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64);
}
call(llvm_func, new_ctx);
llvm_val[stmt] = result_buffer;
call("recycle_runtime_context", get_runtime(), new_ctx);
}

void TaskCodeGenLLVM::visit(GetElementStmt *stmt) {
auto *real_func = stmt->src->as<FuncCallStmt>()->func;
auto *real_func_ret_type = tlctx->get_data_type(real_func->ret_type);
auto *struct_type = tlctx->get_data_type(stmt->src->ret_type);
std::vector<llvm::Value *> index;
index.reserve(stmt->index.size() + 1);
index.push_back(tlctx->get_constant(0));
for (auto &i : stmt->index) {
index.push_back(tlctx->get_constant(i));
}
auto *gep =
builder->CreateGEP(real_func_ret_type, llvm_val[stmt->src], index);
auto *gep = builder->CreateGEP(struct_type, llvm_val[stmt->src], index);
auto *val = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), gep);
llvm_val[stmt] = val;
}
Expand Down
1 change: 0 additions & 1 deletion taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ PER_EXPRESSION(AtomicOpExpression)
PER_EXPRESSION(SNodeOpExpression)
PER_EXPRESSION(ConstExpression)
PER_EXPRESSION(ExternalTensorShapeAlongAxisExpression)
PER_EXPRESSION(FuncCallExpression)
PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/frontend_statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendFuncDefStmt)
PER_STATEMENT(FrontendReturnStmt)
PER_STATEMENT(FrontendFuncCallStmt)
6 changes: 0 additions & 6 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,6 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit(", ", expr->axis, ')');
}

void visit(FuncCallExpression *expr) override {
emit("func_call(\"", expr->func->func_key.get_full_name(), "\", ");
emit_vector(expr->args.exprs);
emit(')');
}

void visit(MeshPatchIndexExpression *expr) override {
emit("mesh_patch_idx()");
}
Expand Down
43 changes: 17 additions & 26 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,37 +1150,14 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void FuncCallExpression::type_check(CompileConfig *) {
for (auto &arg : args.exprs) {
TI_ASSERT_TYPE_CHECKED(arg);
// no arg type compatibility check for now due to lack of specification
}
ret_type = PrimitiveType::u64;
ret_type.set_is_pointer(true);
}

void FuncCallExpression::flatten(FlattenContext *ctx) {
std::vector<Stmt *> stmt_args;
for (auto &arg : args.exprs) {
stmt_args.push_back(flatten_rvalue(arg, ctx));
}
ctx->push_back<FuncCallStmt>(func, stmt_args);
stmt = ctx->back_stmt();
}

void GetElementExpression::type_check(CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(src);
auto func_call = src.cast<FuncCallExpression>();
TI_ASSERT(func_call);
// The return values are flattened now,
// so the length of stmt->index is 1.
// Will be refactored soon.
TI_ASSERT(index[0] < func_call->func->rets.size());
ret_type = func_call->func->rets[index[0]].dt;

ret_type = src->ret_type->as<StructType>()->get_element_type(index);
}

void GetElementExpression::flatten(FlattenContext *ctx) {
ctx->push_back<GetElementStmt>(src->get_flattened_stmt(), index);
ctx->push_back<GetElementStmt>(flatten_rvalue(src, ctx), index);
stmt = ctx->back_stmt();
}
// Mesh related.
Expand Down Expand Up @@ -1391,6 +1368,20 @@ Expr ASTBuilder::expr_alloca() {
return var;
}

std::optional<Expr> ASTBuilder::insert_func_call(Function *func,
const ExprGroup &args) {
if (func->ret_type) {
auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
this->insert(std::make_unique<FrontendFuncCallStmt>(
func, args, std::static_pointer_cast<IdExpression>(var.expr)->id));
var.expr->ret_type = func->ret_type;
return var;
} else {
this->insert(std::make_unique<FrontendFuncCallStmt>(func, args));
return std::nullopt;
}
}

Expr ASTBuilder::make_matrix_expr(const std::vector<int> &shape,
const DataType &dt,
const std::vector<Expr> &elements) {
Expand Down
20 changes: 13 additions & 7 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,20 +774,25 @@ class ExternalTensorShapeAlongAxisExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class FuncCallExpression : public Expression {
class FrontendFuncCallStmt : public Stmt {
public:
std::optional<Identifier> ident;
Function *func;
ExprGroup args;

void type_check(CompileConfig *config) override;

FuncCallExpression(Function *func, const ExprGroup &args)
: func(func), args(args) {
explicit FrontendFuncCallStmt(
Function *func,
const ExprGroup &args,
const std::optional<Identifier> &id = std::nullopt)
: ident(id), func(func), args(args) {
TI_ASSERT(id.has_value() == !func->rets.empty());
}

void flatten(FlattenContext *ctx) override;
bool is_container_statement() const override {
return false;
}

TI_DEFINE_ACCEPT_FOR_EXPRESSION
TI_DEFINE_ACCEPT
};

class GetElementExpression : public Expression {
Expand Down Expand Up @@ -962,6 +967,7 @@ class ASTBuilder {
mesh::ConvType &conv_type);

void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb);
std::optional<Expr> insert_func_call(Function *func, const ExprGroup &args);
void create_assert_stmt(const Expr &cond,
const std::string &msg,
const std::vector<Expr> &args);
Expand Down
3 changes: 3 additions & 0 deletions taichi/program/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ int Callable::insert_texture_arg(const DataType &dt) {
}

void Callable::finalize_rets() {
if (rets.empty()) {
return;
}
std::vector<const Type *> types;
types.reserve(rets.size());
for (const auto &ret : rets) {
Expand Down
4 changes: 1 addition & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ void export_lang(py::module &m) {
.def("expand_exprs", &ASTBuilder::expand_exprs)
.def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion)
.def("expr_subscript", &ASTBuilder::expr_subscript)
.def("insert_func_call", &ASTBuilder::insert_func_call)
.def("sifakis_svd_f32", sifakis_svd_export<float32, int32>)
.def("sifakis_svd_f64", sifakis_svd_export<float64, int64>)
.def("expr_var", &ASTBuilder::make_var)
Expand Down Expand Up @@ -823,9 +824,6 @@ void export_lang(py::module &m) {
with_runtime_context);
});

m.def("make_func_call_expr",
Expr::make<FuncCallExpression, Function *, const ExprGroup &>);

m.def("make_get_element_expr",
Expr::make<GetElementExpression, const Expr &, std::vector<int>>);

Expand Down
12 changes: 12 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,18 @@ class IRPrinter : public IRVisitor {
}
}

void visit(FrontendFuncCallStmt *stmt) override {
std::string args;
for (int i = 0; i < stmt->args.exprs.size(); i++) {
if (i) {
args += ", ";
}
args += expr_to_string(stmt->args.exprs[i]);
}
print("{}${} = call \"{}\", args = ({}), ret = {}", stmt->type_hint(),
stmt->id, stmt->func->get_name(), args, stmt->ident->name());
}

void visit(FuncCallStmt *stmt) override {
std::vector<std::string> args;
for (const auto &arg : stmt->args) {
Expand Down
17 changes: 17 additions & 0 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ class LowerAST : public IRVisitor {
}
}

void visit(FrontendFuncCallStmt *stmt) override {
Block *block = stmt->parent;
std::vector<Stmt *> args;
args.reserve(stmt->args.exprs.size());
auto fctx = make_flatten_ctx();
for (const auto &arg : stmt->args.exprs) {
args.push_back(flatten_rvalue(arg, &fctx));
}
auto lowered = fctx.push_back<FuncCallStmt>(stmt->func, args);
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
if (const auto &ident = stmt->ident) {
TI_ASSERT(block->local_var_to_stmt.find(ident.value()) ==
block->local_var_to_stmt.end());
block->local_var_to_stmt.insert(std::make_pair(ident.value(), lowered));
}
}

void visit(FrontendIfStmt *stmt) override {
auto fctx = make_flatten_ctx();
auto condition_stmt = flatten_rvalue(stmt->condition, &fctx);
Expand Down
18 changes: 9 additions & 9 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,18 @@ class TypeCheck : public IRVisitor {
void visit(FuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
stmt->ret_type = PrimitiveType::u64;
stmt->ret_type.set_is_pointer(true);
stmt->ret_type = func->ret_type;
}

void visit(FrontendFuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
stmt->ret_type = func->ret_type;
}

void visit(GetElementStmt *stmt) override {
TI_ASSERT(stmt->src->is<FuncCallStmt>());
auto *func = stmt->src->as<FuncCallStmt>()->func;
// The return values are flattened now,
// so the length of stmt->index is 1.
// Will be refactored soon.
TI_ASSERT(stmt->index[0] < func->rets.size());
stmt->ret_type = func->rets[stmt->index[0]].dt;
stmt->ret_type =
stmt->src->ret_type->as<StructType>()->get_element_type(stmt->index);
}

void visit(ArgLoadStmt *stmt) override {
Expand Down

0 comments on commit 8ef64dc

Please sign in to comment.