Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 4, 2023
1 parent 5af3ccd commit d755838
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
12 changes: 7 additions & 5 deletions taichi/jit/jit_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class JITModule {
}

inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers() {
return std::make_tuple(std::vector<void *>(), std::vector<int>() );
return std::make_tuple(std::vector<void *>(), std::vector<int>());
}

template <typename... Args, typename T>
inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers(T &t, Args &...args) {
inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers(
T &t,
Args &...args) {
auto [arg_pointers, arg_sizes] = get_arg_pointers(args...);
arg_pointers.insert(arg_pointers.begin(), &t);
arg_sizes.insert(arg_sizes.begin(), sizeof(t));
Expand All @@ -53,8 +55,7 @@ class JITModule {
void call(const std::string &name, Args... args) {
if (direct_dispatch()) {
get_function<Args...>(name)(args...);
}
else {
} else {
auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...);
call(name, arg_pointers, arg_sizes);
}
Expand All @@ -75,7 +76,8 @@ class JITModule {
std::size_t shared_mem_bytes,
Args... args) {
auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...);
launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers, arg_sizes);
launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers,
arg_sizes);
}

virtual void launch(const std::string &name,
Expand Down
23 changes: 12 additions & 11 deletions taichi/rhi/amdgpu/amdgpu_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,23 @@ int AMDGPUContext::get_args_byte(std::vector<int> arg_sizes) {
naive_add += size;
if (size < 32) {
if ((byte_cnt + size) % 32 > (byte_cnt) % 32 ||
(byte_cnt + size) % 32 == 0) byte_cnt += size;
else byte_cnt += 32 - byte_cnt % 32 + size;
}
else {
if (byte_cnt % 32 != 0)
(byte_cnt + size) % 32 == 0)
byte_cnt += size;
else
byte_cnt += 32 - byte_cnt % 32 + size;
} else {
if (byte_cnt % 32 != 0)
byte_cnt += 32 - byte_cnt % 32 + size;
else
else
byte_cnt += size;
}
}
return byte_cnt;
}

void AMDGPUContext::pack_args(std::vector<void *> arg_pointers,
std::vector<int> arg_sizes, char *arg_packed) {
std::vector<int> arg_sizes,
char *arg_packed) {
int byte_cnt = 0;
for (int ii = 0; ii < arg_pointers.size(); ii++) {
// The parameter is taken as a vec4
Expand All @@ -95,16 +97,15 @@ void AMDGPUContext::pack_args(std::vector<void *> arg_pointers,
int padding_size = 32 - byte_cnt % 32;
byte_cnt += padding_size;
std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]);
byte_cnt += arg_sizes[ii];
byte_cnt += arg_sizes[ii];
}
} else {
if (byte_cnt % 32 != 0) {
int padding_size = 32 - byte_cnt % 32;
byte_cnt+= padding_size;
byte_cnt += padding_size;
std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]);
byte_cnt += arg_sizes[ii];
}
else {
} else {
std::memcpy(arg_packed + byte_cnt, arg_pointers[ii], arg_sizes[ii]);
byte_cnt += arg_sizes[ii];
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/rhi/amdgpu/amdgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class AMDGPUContext {
}

void pack_args(std::vector<void *> arg_pointers,
std::vector<int> arg_sizes, char *arg_packed);
std::vector<int> arg_sizes,
char *arg_packed);

int get_args_byte(std::vector<int> arg_sizes);

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cuda/jit_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ class JITModuleCUDA : public JITModule {
const std::vector<void *> &arg_pointers,
const std::vector<int> &arg_sizes) override {
auto func = lookup_function(name);
CUDAContext::get_instance().launch(func, name, arg_pointers, arg_sizes, grid_dim,
block_dim, dynamic_shared_mem_bytes);
CUDAContext::get_instance().launch(func, name, arg_pointers, arg_sizes,
grid_dim, block_dim,
dynamic_shared_mem_bytes);
}

bool direct_dispatch() const override {
Expand Down

0 comments on commit d755838

Please sign in to comment.