This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

add op unit test for trunc (#1506)
MayYouBeProsperous authored Jun 7, 2023
1 parent 51da94a commit 6c341cd
Showing 3 changed files with 115 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -149,6 +149,7 @@ __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }
__device__ inline int FN_INT32(popc)(int a) { return __popc(a); }
__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return ((unsigned int)a >> b); }
__device__ inline int FN_INT32(trunc)(int a) { return a; }

__device__ inline int FN_INT32(max)(int a, int b) { return max(a, b); }
__device__ inline int FN_INT32(min)(int a, int b) { return min(a, b); }
@@ -170,6 +171,7 @@ __device__ inline long long int FN_INT64(bitwise_xor)(long long int a, long long int b) { return a ^ b; }
__device__ inline long long int FN_INT64(bitwise_not)(long long int a) { return ~a; }
__device__ inline long long int FN_INT64(clz)(long long int a) { return __clzll(a); }
__device__ inline long long int FN_INT64(popc)(long long int a) { return __popcll(a); }
__device__ inline long long int FN_INT64(trunc)(long long int a) { return a; }
__device__ inline long long int FN_INT64(mod)(long long int a, long long int b) {
long long int res = a % b;
if ((res != 0) && ((b ^ res) < 0)) res += b;
2 changes: 2 additions & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
@@ -204,6 +204,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
REGISTER_EXTERN_FUNC_1_IN_1_INT32(bitwise_not)
REGISTER_EXTERN_FUNC_1_IN_1_INT32(clz)
REGISTER_EXTERN_FUNC_1_IN_1_INT32(popc)
REGISTER_EXTERN_FUNC_1_IN_1_INT32(trunc)

#undef REGISTER_EXTERN_FUNC_1_IN_1_INT32

@@ -213,6 +214,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
REGISTER_EXTERN_FUNC_1_IN_1_INT64(bitwise_not)
REGISTER_EXTERN_FUNC_1_IN_1_INT64(clz)
REGISTER_EXTERN_FUNC_1_IN_1_INT64(popc)
REGISTER_EXTERN_FUNC_1_IN_1_INT64(trunc)

#undef REGISTER_EXTERN_FUNC_1_IN_1_INT64

111 changes: 111 additions & 0 deletions python/tests/ops/test_trunc_op.py
@@ -0,0 +1,111 @@
#!/usr/bin/env python3

# Copyright (c) 2023 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *


@OpTestTool.skip_if(not is_compiled_with_cuda(),
                    "x86 test will be skipped due to timeout.")
class TestTruncOp(OpTest):
    def setUp(self):
        print(f"\nRunning {self.__class__.__name__}: {self.case}")
        self.prepare_inputs()

    def prepare_inputs(self):
        self.x_np = self.random(
            shape=self.case["x_shape"],
            dtype=self.case["x_dtype"],
            low=-1000.0,
            high=1000.0)

    def build_paddle_program(self, target):
        x = paddle.to_tensor(self.x_np, stop_gradient=True)
        out = paddle.trunc(x)
        self.paddle_outputs = [out]

    def build_cinn_program(self, target):
        builder = NetBuilder("unary_elementwise_test")
        x = builder.create_input(
            self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
            "x")
        out = builder.trunc(x)
        prog = builder.build()
        res = self.get_cinn_output(prog, target, [x], [self.x_np], [out])

        self.cinn_outputs = [res[0]]

    def test_check_results(self):
        self.check_outputs_and_grads()


class TestTruncOpShape(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "TestTruncOpShape"
        self.cls = TestTruncOp
        self.inputs = [{
            "x_shape": [1],
        }, {
            "x_shape": [1024],
        }, {
            "x_shape": [1, 2048],
        }, {
            "x_shape": [1, 1, 1],
        }, {
            "x_shape": [32, 64],
        }, {
            "x_shape": [16, 8, 4, 2],
        }, {
            "x_shape": [16, 8, 4, 2, 1],
        }]
        self.dtypes = [{
            "x_dtype": "float32",
        }]
        self.attrs = []


class TestTruncOpDtype(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "TestTruncOpDtype"
        self.cls = TestTruncOp
        self.inputs = [{
            "x_shape": [32, 64],
        }]
        self.dtypes = [
            {
                "x_dtype": "int32",
            },
            {
                "x_dtype": "int64",
            },
            {
                "x_dtype": "float32",
            },
            {
                "x_dtype": "float64",
            },
        ]
        self.attrs = []


if __name__ == "__main__":
    TestTruncOpShape().run()
    TestTruncOpDtype().run()
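
For context, a minimal sketch (not part of the commit) of the semantics this test exercises: trunc drops the fractional part of each element, rounding toward zero, and the test compares CINN's builder.trunc against the Paddle baseline paddle.trunc elementwise. The sample values below are illustrative only.

# Illustrative only -- not part of the commit. Shows the reference
# behaviour that build_paddle_program computes with paddle.trunc.
import numpy as np
import paddle

x_np = np.array([-1.7, -0.5, 0.5, 1.7], dtype="float32")
x = paddle.to_tensor(x_np)

# trunc rounds toward zero: -1.7 -> -1.0, -0.5 -> -0.0, 0.5 -> 0.0, 1.7 -> 1.0
print(paddle.trunc(x).numpy())
# numpy's trunc has the same semantics, so it can serve as a sanity check
print(np.trunc(x_np))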
