Skip to content

Commit

Permalink
op unittest for scatter_assign (PaddlePaddle#1390)
Browse files Browse the repository at this point in the history
* op unittest for scatter_assign

* add dtype float16
  • Loading branch information
zzk0 authored and jiahy0825 committed May 25, 2023
1 parent 6dca946 commit 1e3e1f4
Showing 1 changed file with 186 additions and 52 deletions.
238 changes: 186 additions & 52 deletions python/tests/ops/test_scatter_assign_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest, OpTestTool
import paddle
import paddle.nn.functional as F
import cinn
from cinn.frontend import *
from cinn.common import *
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper


@OpTestTool.skip_if(not is_compiled_with_cuda(),
"x86 test will be skipped due to timeout.")
class TestScatterAssignOp(OpTest):
class TestScatterAssignOpBase(OpTest):
def setUp(self):
self.init_case()
self.target = DefaultNVGPUTarget()
print(f"\nRunning {self.__class__.__name__}: {self.case}")
self.inputs = {}
self.prepare_inputs()

def init_case(self):
self.axis = 0
self.inputs = {
"x": np.random.random([10, 5]).astype("float32"),
"y": np.random.random([3, 5]).astype("float32"),
"index": np.random.randint(0, 10, size=3).astype("int32")
}
def prepare_inputs(self):
self.inputs["x"] = self.random(self.case["x_shape"]).astype(
self.case["x_dtype"])
self.inputs["y"] = self.random(self.case["y_shape"]).astype(
self.case["y_dtype"])
self.inputs["index"] = np.random.randint(
0, self.case["index_upper"],
size=self.case["index_size"]).astype("int32")
self.axis = self.case["axis"]

def build_paddle_program(self, target):
x = self.inputs["x"].copy()
Expand All @@ -60,6 +61,13 @@ def build_paddle_program(self, target):
for j in range(self.inputs["x"].shape[1]):
for k in range(self.inputs["index"].shape[0]):
out[i][j][self.inputs["index"][k]] = y[i][j][k]
elif axis == 3:
for i in range(self.inputs["x"].shape[0]):
for j in range(self.inputs["x"].shape[1]):
for k in range(self.inputs["x"].shape[2]):
for l in range(self.inputs["index"].shape[0]):
out[i][j][k][self.inputs["index"]
[l]] = y[i][j][k][l]
else:
self.assertTrue(False, "Axis {} No Implement".format(self.axis))

Expand All @@ -68,10 +76,15 @@ def build_paddle_program(self, target):

def build_cinn_program(self, target):
builder = NetBuilder("scatter_assign")
x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
y = builder.create_input(Float(32), self.inputs["y"].shape, "y")
x = builder.create_input(
OpTest.nptype2cinntype(self.inputs["x"].dtype),
self.inputs["x"].shape, "x")
y = builder.create_input(
OpTest.nptype2cinntype(self.inputs["y"].dtype),
self.inputs["y"].shape, "y")
index = builder.create_input(
Int(32), self.inputs["index"].shape, "index")
OpTest.nptype2cinntype(self.inputs["index"].dtype),
self.inputs["index"].shape, "index")
out = builder.scatter_assign(x, y, index, self.axis)

prog = builder.build()
Expand All @@ -85,45 +98,166 @@ def test_check_results(self):
self.check_outputs_and_grads(all_equal=True)


class TestScatterAssignCase1(TestScatterAssignOp):
def init_case(self):
self.inputs = {
"x": np.random.random([10, 5]).astype("float32"),
"y": np.random.random([10, 3]).astype("float32"),
"index": np.random.randint(0, 5, size=3).astype("int32")
}
self.axis = 1


class TestScatterAssignCase2(TestScatterAssignOp):
def init_case(self):
self.inputs = {
"x": np.random.random([10, 5, 5]).astype("float32"),
"y": np.random.random([10, 5, 3]).astype("float32"),
"index": np.random.randint(0, 5, size=3).astype("int32")
}
self.axis = -1
class TestScatterAssignOp(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestScatterAssignOp"
self.cls = TestScatterAssignOpBase
self.inputs = [
{
"x_shape": [10],
"y_shape": [1],
"index_upper": 10,
"index_size": 1,
"axis": -1
},
{
"x_shape": [10, 5],
"y_shape": [3, 5],
"index_upper": 10,
"index_size": 3,
"axis": 0
},
{
"x_shape": [10, 5, 5],
"y_shape": [10, 5, 4],
"index_upper": 5,
"index_size": 4,
"axis": -1
},
{
"x_shape": [10, 5, 5, 7],
"y_shape": [10, 5, 2, 7],
"index_upper": 5,
"index_size": 2,
"axis": -2
},
{
"x_shape": [10, 5, 1024, 2048],
"y_shape": [10, 5, 2, 2048],
"index_upper": 5,
"index_size": 2,
"axis": -2
},
]
self.dtypes = [
{
"x_dtype": "float32",
"y_dtype": "float32"
},
]
self.attrs = []


class TestScatterAssignCase3(TestScatterAssignOp):
def init_case(self):
self.inputs = {
"x": np.random.random([10]).astype("float32"),
"y": np.random.random([1]).astype("float32"),
"index": np.random.randint(0, 10, size=1).astype("int32")
}
self.axis = -1
class TestScatterAssignOpAttribute(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestScatterAssignOpAttribute"
self.cls = TestScatterAssignOpBase
self.inputs = [
{
"x_shape": [1, 1, 1, 1],
"y_shape": [1, 1, 1, 1],
"index_upper": 1,
"index_size": 1,
"axis": 0,
},
{
"x_shape": [1, 10, 10, 3],
"y_shape": [1, 4, 10, 3],
"index_upper": 10,
"index_size": 4,
"axis": 1,
},
{
"x_shape": [10, 4, 8, 3],
"y_shape": [10, 4, 5, 3],
"index_upper": 8,
"index_size": 5,
"axis": 2,
},
{
"x_shape": [10, 4, 5, 6],
"y_shape": [10, 4, 5, 3],
"index_upper": 6,
"index_size": 3,
"axis": 3,
},
{
"x_shape": [10, 4, 5, 1024],
"y_shape": [10, 4, 5, 768],
"index_upper": 1024,
"index_size": 768,
"axis": -1,
},
{
"x_shape": [1024, 4, 12, 10],
"y_shape": [1024, 4, 5, 10],
"index_upper": 12,
"index_size": 5,
"axis": -2,
},
{
"x_shape": [10, 8192, 12, 10],
"y_shape": [10, 4096, 12, 10],
"index_upper": 8192,
"index_size": 4096,
"axis": -3,
},
{
"x_shape": [2048, 10, 12, 10],
"y_shape": [1024, 10, 12, 10],
"index_upper": 2048,
"index_size": 1024,
"axis": -4,
},
]
self.dtypes = [
{
"x_dtype": "float32",
"y_dtype": "float32"
},
]
self.attrs = []


class TestScatterAssignCase4(TestScatterAssignOp):
def init_case(self):
self.inputs = {
"x": np.random.random([10, 5]).astype("float32"),
"y": np.random.random([3, 5]).astype("float32"),
"index": np.array([0, 5, 0]).astype("int32")
}
self.axis = 0
class TestScatterAssignOpDtype(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestScatterAssignOpDtype"
self.cls = TestScatterAssignOpBase
self.inputs = [
{
"x_shape": [10, 5, 20, 7],
"y_shape": [10, 5, 15, 7],
"index_upper": 20,
"index_size": 15,
"axis": -2
},
]
self.dtypes = [
{
"x_dtype": "float16",
"y_dtype": "float16"
},
{
"x_dtype": "float32",
"y_dtype": "float32"
},
{
"x_dtype": "float64",
"y_dtype": "float64"
},
{
"x_dtype": "int32",
"y_dtype": "int32"
},
{
"x_dtype": "int64",
"y_dtype": "int64"
},
]
self.attrs = []


if __name__ == "__main__":
unittest.main()
TestScatterAssignOp().run()
TestScatterAssignOpAttribute().run()
TestScatterAssignOpDtype().run()

0 comments on commit 1e3e1f4

Please sign in to comment.