Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add gather_nd tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed May 6, 2023
1 parent d25dba7 commit 38abe6a
Showing 1 changed file with 63 additions and 45 deletions.
108 changes: 63 additions & 45 deletions python/tests/ops/test_gather_nd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,76 +14,94 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
# Copyright (c) 2022 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest, OpTestTool
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *
import logging
import os
from itertools import product

logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'INFO').upper())
logger = logging.getLogger(name="gather_nd")


@OpTestTool.skip_if(not is_compiled_with_cuda(),
"x86 test will be skipped due to timeout.")
class TestGatherNdOp(OpTest):
def setUp(self):
self.data = []
self.init_case()

def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[1]], dtype='int32')
}
self.inputs = [{"x": [3, 4, 3], "index": [4, 1]}]
self.dtypes = ["float32"]

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
index = paddle.to_tensor(self.inputs["index"], stop_gradient=False)
out = paddle.gather_nd(x, index)
self.paddle_outputs = [out]
for inputs, dtype in product(self.inputs, self.dtypes):
x_shape = inputs["x"]
index_shape = inputs["index"]
x = np.random.randn(*x_shape).astype(dtype)
index = np.random.randint(0, min(x_shape),
index_shape).astype("int32")
self.data.append([x, index])
x = paddle.to_tensor(x, stop_gradient=True)
index = paddle.to_tensor(index, stop_gradient=True)
out = paddle.gather_nd(x, index)
logger.debug(" -- The output of Paddle:\n{}".format(out))
self.paddle_outputs.append(out)

def build_cinn_program(self, target):
builder = NetBuilder("GatherNd")
x = builder.create_input(
self.nptype2cinntype(self.inputs["x"].dtype),
self.inputs["x"].shape, "x")
index = builder.create_input(
self.nptype2cinntype(self.inputs["index"].dtype),
self.inputs["index"].shape, "index")
out = builder.gather_nd(x, index)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x, index],
[self.inputs["x"], self.inputs["index"]],
[out])

self.cinn_outputs = [res[0]]
for i, (inputs, dtype) in enumerate(product(self.inputs, self.dtypes)):
builder = NetBuilder("gather")
x = builder.create_input(
self.nptype2cinntype(dtype), inputs["x"], "x")
index = builder.create_input(Int(32), inputs["index"], "index")
out = builder.gather_nd(x, index)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x, index], self.data[i],
[out])
logger.debug(" -- The output of CINN:\n{}".format(res))
self.cinn_outputs.extend(res)

def test_check_results(self):
self.check_outputs_and_grads(all_equal=True)


class TestGatherNdCase1(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[0, 2]], dtype='int32')
}


class TestGatherNdCase2(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float32'),
'index': np.array([[1, 2, 3]], dtype='int32')
}


class TestGatherNdCase3(TestGatherNdOp):
class TestGatherOpAll(TestGatherNdOp):
def init_case(self):
self.inputs = {
'x': self.random([2, 3, 4], 'float64'),
'index': np.array([[1, 2, 3]], dtype='int64')
}
self.inputs = []
for x_shape in [
[16],
[8, 16],
[4, 8, 16],
[2, 4, 8, 16],
[2, 4, 8, 1],
[2, 4, 8, 1024],
]:
for j in range(1, len(x_shape)):
self.inputs.append({"x": x_shape, "index": [8, j]})

self.dtypes = [
"float32", "float64", "int16", "int32", "int64", "uint8"
]


if __name__ == "__main__":
Expand Down

0 comments on commit 38abe6a

Please sign in to comment.