Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

op unittest for sort & enhance test helper #1411

Merged
merged 18 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/argmax_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void TestGenerateCodeCpu_Argmax_Keep(void* _args, int32_t num_args)
for (int32_t j = 0; j < 3; j += 1) {
for (int32_t k = 0; k < 28; k += 1) {
for (int32_t a = 0; a < 28; a += 1) {
test_argmax_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmax_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784);
test_argmax_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_next_smallest_int32(_test_argmax_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784);
};
};
};
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/argmin_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void TestGenerateCodeCpu_Argmin_Keep(void* _args, int32_t num_args)
for (int32_t j = 0; j < 3; j += 1) {
for (int32_t k = 0; k < 28; k += 1) {
for (int32_t a = 0; a < 28; a += 1) {
test_argmin_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_find_int_nd(_test_argmin_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784);
test_argmin_in_index[((2352 * i) + ((784 * j) + ((28 * k) + a)))] = cinn_host_next_smallest_int32(_test_argmin_in_index_temp, 3, j, ((2352 * i) + ((28 * k) + a)), 784);
};
};
};
Expand Down
4 changes: 2 additions & 2 deletions cinn/hlir/op/contrib/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ std::vector<ir::Tensor> ArgSort(const ir::Tensor &A,
std::string find_func_name;
std::string index_func_name;
if (target.arch == common::Target::Arch::NVGPU) {
find_func_name.assign("cinn_cuda_find_int_nd");
find_func_name.assign("cinn_nvgpu_next_smallest_int32");
zzk0 marked this conversation as resolved.
Show resolved Hide resolved
} else if (target.arch == common::Target::Arch::X86) {
find_func_name.assign("cinn_host_find_int_nd");
find_func_name.assign("cinn_host_next_smallest_int32");
} else {
LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! Please Check.\n";
}
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void TestGenerateCodeCpu_Sort(void* _args, int32_t num_args)
};
for (int32_t i = 0; i < 4; i += 1) {
for (int32_t j = 0; j < 28; j += 1) {
test_sort_out_index[((28 * i) + j)] = cinn_host_find_int_nd(_test_sort_out_index_temp, 28, j, (28 * i), 1);
test_sort_out_index[((28 * i) + j)] = cinn_host_next_smallest_int32(_test_sort_out_index_temp, 28, j, (28 * i), 1);
};
};
for (int32_t i = 0; i < 4; i += 1) {
Expand Down
23 changes: 23 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num

#undef __cinn_host_find_kernel

inline int cinn_host_next_smallest_int32(cinn_buffer_t* buf, int size, int num, int begin, int stride) {
int id = -1;
for (int i = begin; i < begin + size * stride; i += stride) {
if (id == -1 || reinterpret_cast<int*>(buf->memory)[i] < reinterpret_cast<int*>(buf->memory)[id]) {
id = i;
}
}
if (id != -1) {
reinterpret_cast<int*>(buf->memory)[id] = 2147483647;
return (id - begin) / stride;
}
return -1;
}

#define CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \
inline int cinn_host_lt_num_##TYPE_SUFFIX( \
const cinn_buffer_t* buf, const int size, const TYPE num, const int offset, const int stride) { \
Expand Down Expand Up @@ -349,6 +363,15 @@ CINN_REGISTER_HELPER(host_intrinsics) {
.AddInputType<int>()
.End();

REGISTER_EXTERN_FUNC_HELPER(cinn_host_next_smallest_int32, host_target)
.SetRetType<int>()
.AddInputType<cinn_buffer_t*>()
.AddInputType<int>()
.AddInputType<int>()
.AddInputType<int>()
.AddInputType<int>()
.End();

#define _REGISTER_CINN_HOST_LT_NUM(TYPE_SUFFIX, TYPE) \
REGISTER_EXTERN_FUNC_HELPER(cinn_host_lt_num_##TYPE_SUFFIX, host_target) \
.SetRetType<int>() \
Expand Down
14 changes: 14 additions & 0 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,20 @@ __device__ inline int cinn_cuda_find_float_nd(const float *buf, int size, float

#undef __cinn_cuda_find_kernel

__device__ inline int cinn_nvgpu_next_smallest_int32(int *buf, int size, int num, int begin, int stride) {
int id = -1;
for (int i = begin; i < begin + size * stride; i += stride) {
if (id == -1 || buf[i] < buf[id]) {
id = i;
}
}
if (id != -1) {
buf[id] = 2147483647;
zzk0 marked this conversation as resolved.
Show resolved Hide resolved
return (id - begin) / stride;
}
return -1;
}

#define __cinn_cuda_find_from_kernel(buf, size, num, begin) \
do { \
for (int i = begin; i < size; ++i) { \
Expand Down
9 changes: 9 additions & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,15 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
.AddInputType<int>()
.End();

REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_nvgpu_next_smallest_int32, target)
.SetRetType<int>()
.AddInputType<cinn_buffer_t *>()
.AddInputType<int>()
.AddInputType<int>()
.AddInputType<int>()
.AddInputType<int>()
.End();

#define _REGISTER_CINN_CUDA_LT_NUM(TYPE_SUFFIX, TYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_lt_num_##TYPE_SUFFIX, target) \
.SetRetType<int>() \
Expand Down
17 changes: 17 additions & 0 deletions python/tests/ops/op_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import unittest
import re

from unittest import suite
from typing import Union, List

parser = argparse.ArgumentParser(description="Argparse for op test helper")
parser.add_argument(
"--case",
Expand Down Expand Up @@ -104,3 +107,17 @@ def run(self):
res = runner.run(test_suite)
if not res.wasSuccessful():
sys.exit(not res.wasSuccessful())


def run_test(test_class: Union[suite.TestSuite, List[suite.TestSuite]]):
test_suite = unittest.TestSuite()
test_loader = unittest.TestLoader()
if isinstance(test_class, type):
test_suite.addTests(test_loader.loadTestsFromTestCase(test_class))
else:
for cls in test_class:
test_suite.addTests(test_loader.loadTestsFromTestCase(cls))
runner = unittest.TextTestRunner()
res = runner.run(test_suite)
if not res.wasSuccessful():
sys.exit(not res.wasSuccessful())
Loading