Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 31, 2022
1 parent b905efa commit ff92f83
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
39 changes: 19 additions & 20 deletions taichi/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
#define UNARY_STD(x) \
else if (op == UnaryOpType::x) { \
if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) { \
llvm_val[stmt] = call("__ocml_" #x "_f16", input); \
llvm_val[stmt] = call("__ocml_" #x "_f16", input); \
} else if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \
llvm_val[stmt] = call("__ocml_" #x "_f32", input); \
llvm_val[stmt] = call("__ocml_" #x "_f32", input); \
} else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \
llvm_val[stmt] = call("__ocml_" #x "_f64", input); \
llvm_val[stmt] = call("__ocml_" #x "_f64", input); \
} else { \
TI_NOT_IMPLEMENTED \
} \
Expand Down Expand Up @@ -190,7 +190,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
TI_ASSERT(fast_reductions.at(prim_type).find(op) !=
fast_reductions.at(prim_type).end());
return call(fast_reductions.at(prim_type).at(op),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
{llvm_val[stmt->dest], llvm_val[stmt->val]});
}

#ifndef TI_LLVM_15
Expand Down Expand Up @@ -227,8 +227,8 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {

auto [begin, end] = get_range_for_bounds(stmt);
call("gpu_parallel_range_for",
{get_arg(0), begin, end, tls_prologue, body, epilogue,
tlctx->get_constant(stmt->tls_size)});
{get_arg(0), begin, end, tls_prologue, body, epilogue,
tlctx->get_constant(stmt->tls_size)});
}

void create_offload_mesh_for(OffloadedStmt *stmt) override {
Expand Down Expand Up @@ -343,7 +343,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
}
current_offload = nullptr;
#else
TI_NOT_IMPLEMENTED
TI_NOT_IMPLEMENTED
#endif
}

Expand All @@ -359,8 +359,8 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
const auto arg_id = stmt->arg_id;
const auto axis = stmt->axis;
llvm_val[stmt] = call("RuntimeContext_get_extra_args",
{get_context(), tlctx->get_constant(arg_id),
tlctx->get_constant(axis)});
{get_context(), tlctx->get_constant(arg_id),
tlctx->get_constant(axis)});
}

void visit(BinaryOpStmt *stmt) override {
Expand Down Expand Up @@ -434,11 +434,11 @@ FunctionType AMDGPUModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledKernel data) const {
auto &mod = data.module;
auto &tasks = data.tasks;
auto jit = tlctx_->jit.get();
auto amdgpu_module =
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);
auto &mod = data.module;
auto &tasks = data.tasks;
auto jit = tlctx_->jit.get();
auto amdgpu_module =
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);

return [amdgpu_module, kernel_name, args, offloaded_tasks = tasks,
executor = this->executor_](RuntimeContext &context) {
Expand Down Expand Up @@ -489,10 +489,10 @@ FunctionType AMDGPUModuleToFunctionConverter::convert(

for (auto &task : offloaded_tasks) {
TI_TRACE("Launching kernel {}<<<{}, {}>>>", task.name, task.grid_dim,
task.block_dim);
task.block_dim);
amdgpu_module->launch(task.name, task.grid_dim, task.block_dim, 0,
(void *)&context_pointer,
(int)sizeof(RuntimeContext *));
(void *)&context_pointer,
(int)sizeof(RuntimeContext *));
}
AMDGPUDriver::get_instance().stream_synchronize(nullptr);
TI_TRACE("Launching kernel");
Expand All @@ -511,6 +511,5 @@ FunctionType AMDGPUModuleToFunctionConverter::convert(
};
}


}
}
} // namespace lang
} // namespace taichi
4 changes: 2 additions & 2 deletions taichi/codegen/amdgpu/codegen_amdgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ class AMDGPUModuleToFunctionConverter : public ModuleToFunctionConverter {
LLVMCompiledKernel data) const override;
};

}
}
} // namespace lang
} // namespace taichi

0 comments on commit ff92f83

Please sign in to comment.