Skip to content

Commit

Permalink
fix bug and solve conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
galeselee committed Dec 22, 2022
1 parent 1e8c1f4 commit 325732a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 2 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,11 @@ void TaichiLLVMContext::update_runtime_jit_module(
}

if (arch_ == Arch::amdgpu) {
#ifdef TI_WITH_AMDGPU
llvm::legacy::PassManager module_pass_manager;
module_pass_manager.add(new AMDGPUConvertFuncParamAddressSpacePass());
module_pass_manager.run(*module);
#endif
}

eliminate_unused_functions(module.get(), [](std::string func_name) {
Expand Down
13 changes: 9 additions & 4 deletions taichi/runtime/llvm/llvm_context_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,23 @@ struct AMDGPUConvertFuncParamAddressSpacePass : public ModulePass {
const std::string func_name = f.getName().str();
if (starts_with(func_name, "runtime_")) {
f.setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
f.addFnAttr("amdgpu-flat-work-group-size", "1, 256");
// ref https://llvm.org/docs/AMDGPUUsage.html
// "amdgpu-flat-work-group-size"="min,max"
// Specify the minimum and maximum flat work group sizes that will be specified when the kernel is dispatched.
// Generated by the amdgpu_flat_work_group_size CLANG attribute [CLANG-ATTR].
// The implied default value is 1,1024.
f.addFnAttr("amdgpu-flat-work-group-size", "1, 1024");
is_kernel = true;
}
if (!is_kernel && !f.isDeclaration())
f.setLinkage(llvm::Function::PrivateLinkage);
}
std::vector<llvm::Function *> global_func;
std::vector<llvm::Function *> kernel_function;
for (auto &f : M) {
if (f.getCallingConv() == llvm::CallingConv::AMDGPU_KERNEL)
global_func.push_back(&f);
kernel_function.push_back(&f);
}
for (auto &f : global_func) {
for (auto &f : kernel_function) {
llvm::FunctionType *func_type = f->getFunctionType();
std::vector<llvm::Type*> new_func_params;
for (auto &arg : f->args()) {
Expand Down

0 comments on commit 325732a

Please sign in to comment.