diff --git a/taichi/codegen/amdgpu/codegen_amdgpu.cpp b/taichi/codegen/amdgpu/codegen_amdgpu.cpp index 09dcc88099de1..3923de44ed716 100644 --- a/taichi/codegen/amdgpu/codegen_amdgpu.cpp +++ b/taichi/codegen/amdgpu/codegen_amdgpu.cpp @@ -294,7 +294,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { } else if (stmt->task_type == Type::range_for) { create_offload_range_for(stmt); } else if (stmt->task_type == Type::struct_for) { - create_offload_struct_for(stmt, true); + create_offload_struct_for(stmt); } else if (stmt->task_type == Type::mesh_for) { create_offload_mesh_for(stmt); } else if (stmt->task_type == Type::listgen) { @@ -395,6 +395,18 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { } } } + + private: + std::tuple get_spmd_info() override { + auto thread_idx = + builder->CreateIntrinsic(Intrinsic::amdgcn_workitem_id_x, {}, {}); + auto workgroup_dim_ = + call("__ockl_get_local_size", + llvm::ConstantInt::get(llvm::Type::getInt32Ty(*llvm_context), 0)); + auto block_dim = builder->CreateTrunc( + workgroup_dim_, llvm::Type::getInt32Ty(*llvm_context)); + return std::make_tuple(thread_idx, block_dim); + } }; LLVMCompiledTask KernelCodeGenAMDGPU::compile_task( diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index 9ef775e8afe4c..03224e529b155 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -209,6 +209,13 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM { TI_NOT_IMPLEMENTED } } + + private: + std::tuple get_spmd_info() override { + auto thread_idx = tlctx->get_constant(0); + auto block_dim = tlctx->get_constant(1); + return std::make_tuple(thread_idx, block_dim); + } }; } // namespace diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index f160118661109..3413b5fa2f9ed 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -474,7 +474,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { } else if (stmt->task_type == Type::range_for) { create_offload_range_for(stmt); } else if (stmt->task_type == Type::struct_for) { - create_offload_struct_for(stmt, true); + create_offload_struct_for(stmt); } else if (stmt->task_type == Type::mesh_for) { create_offload_mesh_for(stmt); } else if (stmt->task_type == Type::listgen) { @@ -584,6 +584,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); } } + + private: + std::tuple get_spmd_info() override { + auto thread_idx = + builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}); + auto block_dim = + builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}); + return std::make_tuple(thread_idx, block_dim); + } }; LLVMCompiledTask KernelCodeGenCUDA::compile_task( diff --git a/taichi/codegen/dx12/codegen_dx12.cpp b/taichi/codegen/dx12/codegen_dx12.cpp index 0f4ac8e250819..5ed3d1f8b264b 100644 --- a/taichi/codegen/dx12/codegen_dx12.cpp +++ b/taichi/codegen/dx12/codegen_dx12.cpp @@ -191,6 +191,13 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM { TI_NOT_IMPLEMENTED } } + + private: + std::tuple get_spmd_info() override { + auto thread_idx = tlctx->get_constant(0); + auto block_dim = tlctx->get_constant(1); + return std::make_tuple(thread_idx, block_dim); + } }; } // namespace diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index fa45ef184aaae..45a30bbc89990 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2021,8 +2021,7 @@ std::tuple TaskCodeGenLLVM::get_range_for_bounds( return std::tuple(begin, end); } -void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, - bool spmd) { +void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) { using namespace llvm; // TODO: instead of constructing tons of LLVM IR, writing the logic in // runtime.cpp may be a cleaner solution. See @@ -2122,18 +2121,9 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, call("block_barrier"); // "__syncthreads()" } - llvm::Value *thread_idx = nullptr, *block_dim = nullptr; - - if (spmd) { - thread_idx = - builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}); - block_dim = builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}); - builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound), - loop_index); - } else { - builder->CreateStore(lower_bound, loop_index); - } + auto [thread_idx, block_dim] = this->get_spmd_info(); + builder->CreateStore(builder->CreateAdd(thread_idx, lower_bound), + loop_index); auto loop_test_bb = BasicBlock::Create(*llvm_context, "loop_test", func); auto loop_body_bb = BasicBlock::Create(*llvm_context, "loop_body", func); @@ -2216,11 +2206,7 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, // body tail: increment loop_index and jump to loop_test builder->SetInsertPoint(body_tail_bb); - if (spmd) { - create_increment(loop_index, block_dim); - } else { - create_increment(loop_index, tlctx->get_constant(1)); - } + create_increment(loop_index, block_dim); builder->CreateBr(loop_test_bb); builder->SetInsertPoint(func_exit); diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index a991f91650c63..06aed262f5081 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -332,7 +332,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { TI_NOT_IMPLEMENTED; } - void create_offload_struct_for(OffloadedStmt *stmt, bool spmd = false); + void create_offload_struct_for(OffloadedStmt *stmt); void visit(LoopIndexStmt *stmt) override; @@ -410,6 +410,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { const Type *current_type, int ¤t_element, std::vector ¤t_index); + + virtual std::tuple get_spmd_info() = 0; }; } // namespace taichi::lang diff --git a/taichi/codegen/llvm/llvm_codegen_utils.h b/taichi/codegen/llvm/llvm_codegen_utils.h index c05ac6b0a85bf..51b08efc52a49 100644 --- a/taichi/codegen/llvm/llvm_codegen_utils.h +++ b/taichi/codegen/llvm/llvm_codegen_utils.h @@ -10,6 +10,11 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsNVPTX.h" + +#if defined(TI_WITH_AMDGPU) +#include "llvm/IR/IntrinsicsAMDGPU.h" +#endif + #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" diff --git a/taichi/codegen/wasm/codegen_wasm.cpp b/taichi/codegen/wasm/codegen_wasm.cpp index 3ddf8a1c278f6..504bc9065d92c 100644 --- a/taichi/codegen/wasm/codegen_wasm.cpp +++ b/taichi/codegen/wasm/codegen_wasm.cpp @@ -234,6 +234,11 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM { res.module = std::move(this->module); return res; } + + private: + std::tuple get_spmd_info() override { + TI_NOT_IMPLEMENTED; + } }; FunctionType KernelCodeGenWASM::compile_to_function() { diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index eefdff7769817..8318e00194c68 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -46,9 +46,8 @@ #include "llvm_context.h" #include "taichi/runtime/program_impls/llvm/llvm_program.h" #include "taichi/codegen/codegen_utils.h" -#ifdef TI_WITH_AMDGPU + #include "taichi/runtime/llvm/llvm_context_pass.h" -#endif #ifdef _WIN32 // Travis CI seems doesn't support ... @@ -1039,52 +1038,21 @@ void TaichiLLVMContext::add_struct_for_func(llvm::Module *module, if (module->getFunction(func_name)) { return; } - auto struct_for_func = module->getFunction("parallel_struct_for"); - auto &llvm_context = module->getContext(); - auto value_map = llvm::ValueToValueMapTy(); - auto patched_struct_for_func = - llvm::CloneFunction(struct_for_func, value_map); - patched_struct_for_func->setName(func_name); - - int num_found_alloca = 0; - llvm::AllocaInst *alloca = nullptr; - - auto char_type = llvm::Type::getInt8Ty(llvm_context); - - // Find the "1" in "char tls_buffer[1]" and replace it with - // "tls_buffer_size" - for (auto &bb : *patched_struct_for_func) { - for (llvm::Instruction &inst : bb) { - auto now_alloca = llvm::dyn_cast(&inst); - if (!now_alloca || now_alloca->getAlign().value() != 8) - continue; - auto alloca_type = now_alloca->getAllocatedType(); - // Allocated type should be array [1 x i8] - if (alloca_type->isArrayTy() && alloca_type->getArrayNumElements() == 1 && - alloca_type->getArrayElementType() == char_type) { - alloca = now_alloca; - num_found_alloca++; - } - } - } - // There should be **exactly** one replacement. - TI_ASSERT(num_found_alloca == 1 && alloca); - auto new_type = llvm::ArrayType::get(char_type, tls_size); - { - llvm::IRBuilder<> builder(alloca); - auto *new_alloca = builder.CreateAlloca(new_type); - new_alloca->setAlignment(Align(8)); - TI_ASSERT(alloca->hasOneUse()); - auto *gep = llvm::cast(alloca->user_back()); - TI_ASSERT(gep->getPointerOperand() == alloca); - std::vector indices(gep->idx_begin(), gep->idx_end()); - builder.SetInsertPoint(gep); - auto *new_gep = builder.CreateInBoundsGEP(new_type, new_alloca, indices); - gep->replaceAllUsesWith(new_gep); - gep->eraseFromParent(); - alloca->eraseFromParent(); + llvm::legacy::PassManager module_pass_manager; + if (config_.arch == Arch::amdgpu) { +#ifdef TI_WITH_AMDGPU + module_pass_manager.add( + new AMDGPUAddStructForFuncPass(func_name, tls_size)); + module_pass_manager.run(*module); +#else + TI_NOT_IMPLEMENTED +#endif + } else { + module_pass_manager.add(new AddStructForFuncPass(func_name, tls_size)); + module_pass_manager.run(*module); } } + std::string TaichiLLVMContext::get_struct_for_func_name(int tls_size) { return "parallel_struct_for_" + std::to_string(tls_size); } diff --git a/taichi/runtime/llvm/llvm_context_pass.h b/taichi/runtime/llvm/llvm_context_pass.h index 4eb8e4477d87f..ff12fd3284500 100644 --- a/taichi/runtime/llvm/llvm_context_pass.h +++ b/taichi/runtime/llvm/llvm_context_pass.h @@ -10,6 +10,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm/Transforms/Utils/Cloning.h" #if defined(TI_WITH_AMDGPU) #include "taichi/rhi/amdgpu/amdgpu_context.h" @@ -18,6 +20,63 @@ namespace taichi { namespace lang { using namespace llvm; + +struct AddStructForFuncPass : public ModulePass { + static inline char ID{0}; + std::string func_name_; + int tls_size_; + AddStructForFuncPass(std::string func_name, int tls_size) : ModulePass(ID) { + func_name_ = func_name; + tls_size_ = tls_size; + } + bool runOnModule(llvm::Module &M) override { + auto struct_for_func = M.getFunction("parallel_struct_for"); + auto &llvm_context = M.getContext(); + auto value_map = llvm::ValueToValueMapTy(); + auto patched_struct_for_func = + llvm::CloneFunction(struct_for_func, value_map); + patched_struct_for_func->setName(func_name_); + + int num_found_alloca = 0; + llvm::AllocaInst *alloca = nullptr; + + auto char_type = llvm::Type::getInt8Ty(llvm_context); + + // Find the "1" in "char tls_buffer[1]" and replace it with + // "tls_buffer_size" + for (auto &bb : *patched_struct_for_func) { + for (llvm::Instruction &inst : bb) { + auto now_alloca = llvm::dyn_cast(&inst); + if (!now_alloca || now_alloca->getAlign().value() != 8) + continue; + auto alloca_type = now_alloca->getAllocatedType(); + // Allocated type should be array [1 x i8] + if (alloca_type->isArrayTy() && + alloca_type->getArrayNumElements() == 1 && + alloca_type->getArrayElementType() == char_type) { + alloca = now_alloca; + num_found_alloca++; + } + } + } + TI_ASSERT(num_found_alloca == 1 && alloca); + auto new_type = llvm::ArrayType::get(char_type, tls_size_); + llvm::IRBuilder<> builder(alloca); + auto *new_alloca = builder.CreateAlloca(new_type); + new_alloca->setAlignment(Align(8)); + TI_ASSERT(alloca->hasOneUse()); + auto *gep = llvm::cast(alloca->user_back()); + TI_ASSERT(gep->getPointerOperand() == alloca); + std::vector indices(gep->idx_begin(), gep->idx_end()); + builder.SetInsertPoint(gep); + auto *new_gep = builder.CreateInBoundsGEP(new_type, new_alloca, indices); + gep->replaceAllUsesWith(new_gep); + gep->eraseFromParent(); + alloca->eraseFromParent(); + return false; + } +}; + #if defined(TI_WITH_AMDGPU) struct AMDGPUConvertAllocaInstAddressSpacePass : public FunctionPass { static inline char ID{0}; @@ -52,6 +111,69 @@ struct AMDGPUConvertAllocaInstAddressSpacePass : public FunctionPass { } }; +struct AMDGPUAddStructForFuncPass : public ModulePass { + static inline char ID{0}; + std::string func_name_; + int tls_size_; + AMDGPUAddStructForFuncPass(std::string func_name, int tls_size) + : ModulePass(ID) { + func_name_ = func_name; + tls_size_ = tls_size; + } + bool runOnModule(llvm::Module &M) override { + auto struct_for_func = M.getFunction("parallel_struct_for"); + auto &llvm_context = M.getContext(); + auto value_map = llvm::ValueToValueMapTy(); + auto patched_struct_for_func = + llvm::CloneFunction(struct_for_func, value_map); + patched_struct_for_func->setName(func_name_); + + int num_found_alloca = 0; + llvm::AllocaInst *alloca = nullptr; + + auto char_type = llvm::Type::getInt8Ty(llvm_context); + + // Find the "1" in "char tls_buffer[1]" and replace it with + // "tls_buffer_size" + for (auto &bb : *patched_struct_for_func) { + for (llvm::Instruction &inst : bb) { + auto now_alloca = llvm::dyn_cast(&inst); + if (!now_alloca || now_alloca->getAlign().value() != 8) + continue; + auto alloca_type = now_alloca->getAllocatedType(); + // Allocated type should be array [1 x i8] + if (alloca_type->isArrayTy() && + alloca_type->getArrayNumElements() == 1 && + alloca_type->getArrayElementType() == char_type) { + alloca = now_alloca; + num_found_alloca++; + } + } + } + TI_ASSERT(num_found_alloca == 1 && alloca); + auto new_type = llvm::ArrayType::get(char_type, tls_size_); + llvm::IRBuilder<> builder(alloca); + auto *new_alloca = builder.CreateAlloca(new_type, (unsigned)5); + new_alloca->setAlignment(Align(8)); + auto new_ty = llvm::PointerType::get(new_type, unsigned(0)); + auto *new_cast = builder.CreateAddrSpaceCast(new_alloca, new_ty); + new_alloca->setAlignment(Align(8)); + TI_ASSERT(alloca->hasOneUse()); + auto *cast = llvm::cast(alloca->user_back()); + TI_ASSERT(cast->hasOneUse()); + auto *gep = llvm::cast(cast->user_back()); + TI_ASSERT(gep->getPointerOperand() == cast); + std::vector indices(gep->idx_begin(), gep->idx_end()); + builder.SetInsertPoint(gep); + auto *new_gep = builder.CreateInBoundsGEP(new_type, new_cast, indices); + gep->replaceAllUsesWith(new_gep); + gep->eraseFromParent(); + cast->eraseFromParent(); + alloca->eraseFromParent(); + return false; + } +}; + struct AMDGPUConvertFuncParamAddressSpacePass : public ModulePass { static inline char ID{0}; AMDGPUConvertFuncParamAddressSpacePass() : ModulePass(ID) {