diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index e933f8de190f1..91e8b00246e29 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -268,13 +268,12 @@ def func_call_rvalue(self, key, args): non_template_args.append(args[i]) non_template_args = impl.make_expr_group(non_template_args, real_func_arg=True) - func_call = Expr( - _ti_core.make_func_call_expr( - self.taichi_functions[key.instance_id], non_template_args)) - impl.get_runtime().prog.current_ast_builder().insert_expr_stmt( - func_call.ptr) + func_call = impl.get_runtime().prog.current_ast_builder( + ).insert_func_call(self.taichi_functions[key.instance_id], + non_template_args) if self.return_type is None: return None + func_call = Expr(func_call) if id(self.return_type) in primitive_types.type_ids: return Expr(_ti_core.make_get_element_expr(func_call.ptr, (0, ))) if isinstance(self.return_type, StructType): diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index ba44b15efc404..157b0c812b917 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -230,8 +230,8 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->axis); } - void visit(FuncCallExpression *expr) override { - emit(ExprOpCode::FuncCallExpression); + void visit(FrontendFuncCallStmt *expr) override { + emit(StmtOpCode::FrontendFuncCallStmt); emit(expr->func); emit(expr->args.exprs); } diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index aa1b9a6d3eb74..b0d4ebd89dfde 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2723,27 +2723,28 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) { llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val); } llvm::Value *result_buffer = nullptr; - auto *ret_type = get_real_func_ret_type(stmt->func); - result_buffer = builder->CreateAlloca(ret_type); - auto *result_buffer_u64 = builder->CreatePointerCast( - result_buffer, llvm::PointerType::get(tlctx->get_data_type(), 0)); - call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64); + if (stmt->ret_type) { + auto *ret_type = tlctx->get_data_type(stmt->ret_type); + result_buffer = builder->CreateAlloca(ret_type); + auto *result_buffer_u64 = builder->CreatePointerCast( + result_buffer, + llvm::PointerType::get(tlctx->get_data_type(), 0)); + call("RuntimeContext_set_result_buffer", new_ctx, result_buffer_u64); + } call(llvm_func, new_ctx); llvm_val[stmt] = result_buffer; call("recycle_runtime_context", get_runtime(), new_ctx); } void TaskCodeGenLLVM::visit(GetElementStmt *stmt) { - auto *real_func = stmt->src->as()->func; - auto *real_func_ret_type = tlctx->get_data_type(real_func->ret_type); + auto *struct_type = tlctx->get_data_type(stmt->src->ret_type); std::vector index; index.reserve(stmt->index.size() + 1); index.push_back(tlctx->get_constant(0)); for (auto &i : stmt->index) { index.push_back(tlctx->get_constant(i)); } - auto *gep = - builder->CreateGEP(real_func_ret_type, llvm_val[stmt->src], index); + auto *gep = builder->CreateGEP(struct_type, llvm_val[stmt->src], index); auto *val = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), gep); llvm_val[stmt] = val; } diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 9675b08ed0a04..12ed0e9cdbbfa 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -16,7 +16,6 @@ PER_EXPRESSION(AtomicOpExpression) PER_EXPRESSION(SNodeOpExpression) PER_EXPRESSION(ConstExpression) PER_EXPRESSION(ExternalTensorShapeAlongAxisExpression) -PER_EXPRESSION(FuncCallExpression) PER_EXPRESSION(MeshPatchIndexExpression) PER_EXPRESSION(MeshRelationAccessExpression) PER_EXPRESSION(MeshIndexConversionExpression) diff --git a/taichi/inc/frontend_statements.inc.h b/taichi/inc/frontend_statements.inc.h index 4ab27e8e4be40..e788075d61f36 100644 --- a/taichi/inc/frontend_statements.inc.h +++ b/taichi/inc/frontend_statements.inc.h @@ -12,3 +12,4 @@ PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear PER_STATEMENT(FrontendAssertStmt) PER_STATEMENT(FrontendFuncDefStmt) PER_STATEMENT(FrontendReturnStmt) +PER_STATEMENT(FrontendFuncCallStmt) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index cb4d9afb30434..78f0fb3438876 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -209,12 +209,6 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { emit(", ", expr->axis, ')'); } - void visit(FuncCallExpression *expr) override { - emit("func_call(\"", expr->func->func_key.get_full_name(), "\", "); - emit_vector(expr->args.exprs); - emit(')'); - } - void visit(MeshPatchIndexExpression *expr) override { emit("mesh_patch_idx()"); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index b440a316bc172..2a2d96fcb2ed2 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1150,37 +1150,14 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } -void FuncCallExpression::type_check(CompileConfig *) { - for (auto &arg : args.exprs) { - TI_ASSERT_TYPE_CHECKED(arg); - // no arg type compatibility check for now due to lack of specification - } - ret_type = PrimitiveType::u64; - ret_type.set_is_pointer(true); -} - -void FuncCallExpression::flatten(FlattenContext *ctx) { - std::vector stmt_args; - for (auto &arg : args.exprs) { - stmt_args.push_back(flatten_rvalue(arg, ctx)); - } - ctx->push_back(func, stmt_args); - stmt = ctx->back_stmt(); -} - void GetElementExpression::type_check(CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(src); - auto func_call = src.cast(); - TI_ASSERT(func_call); - // The return values are flattened now, - // so the length of stmt->index is 1. - // Will be refactored soon. - TI_ASSERT(index[0] < func_call->func->rets.size()); - ret_type = func_call->func->rets[index[0]].dt; + + ret_type = src->ret_type->as()->get_element_type(index); } void GetElementExpression::flatten(FlattenContext *ctx) { - ctx->push_back(src->get_flattened_stmt(), index); + ctx->push_back(flatten_rvalue(src, ctx), index); stmt = ctx->back_stmt(); } // Mesh related. @@ -1391,6 +1368,20 @@ Expr ASTBuilder::expr_alloca() { return var; } +std::optional ASTBuilder::insert_func_call(Function *func, + const ExprGroup &args) { + if (func->ret_type) { + auto var = Expr(std::make_shared(get_next_id())); + this->insert(std::make_unique( + func, args, std::static_pointer_cast(var.expr)->id)); + var.expr->ret_type = func->ret_type; + return var; + } else { + this->insert(std::make_unique(func, args)); + return std::nullopt; + } +} + Expr ASTBuilder::make_matrix_expr(const std::vector &shape, const DataType &dt, const std::vector &elements) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 9e03107090bbe..0e3f657110c14 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -774,20 +774,25 @@ class ExternalTensorShapeAlongAxisExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; -class FuncCallExpression : public Expression { +class FrontendFuncCallStmt : public Stmt { public: + std::optional ident; Function *func; ExprGroup args; - void type_check(CompileConfig *config) override; - - FuncCallExpression(Function *func, const ExprGroup &args) - : func(func), args(args) { + explicit FrontendFuncCallStmt( + Function *func, + const ExprGroup &args, + const std::optional &id = std::nullopt) + : ident(id), func(func), args(args) { + TI_ASSERT(id.has_value() == !func->rets.empty()); } - void flatten(FlattenContext *ctx) override; + bool is_container_statement() const override { + return false; + } - TI_DEFINE_ACCEPT_FOR_EXPRESSION + TI_DEFINE_ACCEPT }; class GetElementExpression : public Expression { @@ -962,6 +967,7 @@ class ASTBuilder { mesh::ConvType &conv_type); void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); + std::optional insert_func_call(Function *func, const ExprGroup &args); void create_assert_stmt(const Expr &cond, const std::string &msg, const std::vector &args); diff --git a/taichi/program/callable.cpp b/taichi/program/callable.cpp index 83692b2008f95..6bcb6b63fd8c1 100644 --- a/taichi/program/callable.cpp +++ b/taichi/program/callable.cpp @@ -32,6 +32,9 @@ int Callable::insert_texture_arg(const DataType &dt) { } void Callable::finalize_rets() { + if (rets.empty()) { + return; + } std::vector types; types.reserve(rets.size()); for (const auto &ret : rets) { diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index cf4b314b6c87f..f0edda99a6a33 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -316,6 +316,7 @@ void export_lang(py::module &m) { .def("expand_exprs", &ASTBuilder::expand_exprs) .def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion) .def("expr_subscript", &ASTBuilder::expr_subscript) + .def("insert_func_call", &ASTBuilder::insert_func_call) .def("sifakis_svd_f32", sifakis_svd_export) .def("sifakis_svd_f64", sifakis_svd_export) .def("expr_var", &ASTBuilder::make_var) @@ -823,9 +824,6 @@ void export_lang(py::module &m) { with_runtime_context); }); - m.def("make_func_call_expr", - Expr::make); - m.def("make_get_element_expr", Expr::make>); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 88531f6572f6d..5db42c5319540 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -294,6 +294,18 @@ class IRPrinter : public IRVisitor { } } + void visit(FrontendFuncCallStmt *stmt) override { + std::string args; + for (int i = 0; i < stmt->args.exprs.size(); i++) { + if (i) { + args += ", "; + } + args += expr_to_string(stmt->args.exprs[i]); + } + print("{}${} = call \"{}\", args = ({}), ret = {}", stmt->type_hint(), + stmt->id, stmt->func->get_name(), args, stmt->ident->name()); + } + void visit(FuncCallStmt *stmt) override { std::vector args; for (const auto &arg : stmt->args) { diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 3d91508ee38ee..0f1563011f35a 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -82,6 +82,23 @@ class LowerAST : public IRVisitor { } } + void visit(FrontendFuncCallStmt *stmt) override { + Block *block = stmt->parent; + std::vector args; + args.reserve(stmt->args.exprs.size()); + auto fctx = make_flatten_ctx(); + for (const auto &arg : stmt->args.exprs) { + args.push_back(flatten_rvalue(arg, &fctx)); + } + auto lowered = fctx.push_back(stmt->func, args); + stmt->parent->replace_with(stmt, std::move(fctx.stmts)); + if (const auto &ident = stmt->ident) { + TI_ASSERT(block->local_var_to_stmt.find(ident.value()) == + block->local_var_to_stmt.end()); + block->local_var_to_stmt.insert(std::make_pair(ident.value(), lowered)); + } + } + void visit(FrontendIfStmt *stmt) override { auto fctx = make_flatten_ctx(); auto condition_stmt = flatten_rvalue(stmt->condition, &fctx); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 50fdcc8fcf725..386124efab009 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -407,18 +407,18 @@ class TypeCheck : public IRVisitor { void visit(FuncCallStmt *stmt) override { auto *func = stmt->func; TI_ASSERT(func); - stmt->ret_type = PrimitiveType::u64; - stmt->ret_type.set_is_pointer(true); + stmt->ret_type = func->ret_type; + } + + void visit(FrontendFuncCallStmt *stmt) override { + auto *func = stmt->func; + TI_ASSERT(func); + stmt->ret_type = func->ret_type; } void visit(GetElementStmt *stmt) override { - TI_ASSERT(stmt->src->is()); - auto *func = stmt->src->as()->func; - // The return values are flattened now, - // so the length of stmt->index is 1. - // Will be refactored soon. - TI_ASSERT(stmt->index[0] < func->rets.size()); - stmt->ret_type = func->rets[stmt->index[0]].dt; + stmt->ret_type = + stmt->src->ret_type->as()->get_element_type(stmt->index); } void visit(ArgLoadStmt *stmt) override {