Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

op unittest for cbrt/ceil/cholesky/concat/constant/fill_constant #1495

Merged
merged 11 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,17 +246,6 @@ Placeholder NetBuilder::CreateInput(const Variable& var) {
return Placeholder(var);
}

Variable NetBuilder::FillConstant(
const std::vector<int>& shape, float value, const std::string& name, const std::string& dtype, bool force_cpu) {
auto out =
CustomInstr("fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}})
.front();
if (!name.empty()) {
out.set_id(cinn::utils::TransValidVarName(name));
}
return out;
}

Variable NetBuilder::FillConstant(const std::vector<int>& shape,
const std::string& str_value,
const std::string& name,
Expand Down
18 changes: 14 additions & 4 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ class NetBuilder {
const std::string& id_hint = "");

/**
* @brief Create constant tensor with the specific value/vector and type, the type is infered from value.
* @brief Create constant tensor with the specific value/vector and type
* @param value The constant value to be set.
* @param name The name of output variable.
* @return The result variable.
Expand Down Expand Up @@ -408,11 +408,21 @@ class NetBuilder {
* @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false.
* @return The result variable.
*/
template <typename T = float>
Variable FillConstant(const cinn::utils::ShapeType& shape,
float value,
T value,
const std::string& name,
const std::string& dtype,
bool force_cpu = false);
bool force_cpu = false) {
auto out =
CustomInstr(
"fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}})
.front();
if (!name.empty()) {
out.set_id(cinn::utils::TransValidVarName(name));
}
return out;
}

/**
* @brief The op return a variable with the specific string value, shape and type.
Expand Down Expand Up @@ -442,7 +452,7 @@ class NetBuilder {
T value,
const std::string& name = "",
bool force_cpu = false) {
return FillConstant(shape, static_cast<float>(value), name, common::Type2Str(common::type_of<T>()), force_cpu);
return FillConstant<T>(shape, value, name, common::Type2Str(common::type_of<T>()), force_cpu);
}

/**
Expand Down
25 changes: 19 additions & 6 deletions cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(const framework::NodeAttr &at
framework::CINNCompute const_scalar_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check.";
auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
auto scalar_type = out_type.at(0);
CINNValuePack pack_args = args[0];
std::string tensor_name = UniqName("const_scalar_Out");
if (FLAGS_cinn_ir_schedule) {
Expand All @@ -210,7 +211,12 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(const framework::NodeAttr &at
}

auto out = lang::Compute(
{Expr(1)}, [=](const std::vector<Expr> &indice) { return scalar; }, tensor_name);
{Expr(1)},
[=](const std::vector<Expr> &indice) {
auto res = (scalar_type == scalar->type()) ? scalar : ir::Cast::Make(scalar_type, scalar);
return res;
},
tensor_name);
CHECK(out.defined()) << "can't create const scalar with the given type " << out_type[0];
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
Expand All @@ -229,9 +235,16 @@ std::vector<shape_t> InferShapeForConstScalar(const std::vector<shape_t> &inputs
}

std::vector<Type> InferDtypeForConstScalar(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
CHECK(attrs.count("value"));
auto scalar = GetScalarExpr(attrs.at("value"));
auto out_type = scalar->type();
Type out_type;
if (attrs.find("dtype") != attrs.end()) {
auto dtype_str = absl::get<std::string>(attrs.at("dtype"));
if (!dtype_str.empty()) {
out_type = common::Str2Type(dtype_str);
}
} else {
auto scalar = GetScalarExpr(attrs.at("value"));
out_type = scalar->type();
}
VLOG(3) << "scalar type: " << out_type;
return {out_type};
}
Expand Down Expand Up @@ -356,10 +369,10 @@ std::vector<std::vector<std::string>> InferLayoutForFillConstant(const std::vect

#define EXPAND_ATTR_TYPE(MACRO) \
MACRO(bool) \
MACRO(float) \
MACRO(int) \
MACRO(int64_t) \
MACRO(double)
MACRO(double) \
MACRO(float)

std::shared_ptr<OpStrategy> StrategyForAssignValue(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
Expand Down
36 changes: 17 additions & 19 deletions cinn/pybind/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ static const char *SnakeName(const char *name) {

#define EXPAND_CINN_SUPPORT_TYPE(EXPAND_MACRO) \
EXPAND_MACRO(bool) \
EXPAND_MACRO(float) \
EXPAND_MACRO(int) \
EXPAND_MACRO(int64_t) \
EXPAND_MACRO(double)

thisjiang marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -405,14 +403,23 @@ void BindFrontend(pybind11::module *m) {
#undef EXPAND_QUINTIC_VECTOR
#undef EXPAND_SEXTIC_VECTOR
#undef PY_REGISTER_CONSTANT_OP
#define PY_REGISTER_FILLCONSTANT_OP(TYPE__) \
.def("fill_constant", \
static_cast<Variable (NetBuilder::*)( \
const std::vector<int> &, TYPE__, const std::string &, bool)>( \
&NetBuilder::template FillConstant<TYPE__>), \
py::arg("shape"), \
py::arg("value"), \
py::arg("name") = "", \
#define PY_REGISTER_FILLCONSTANT_OP(TYPE__) \
.def("fill_constant", \
static_cast<Variable (NetBuilder::*)( \
const std::vector<int> &, TYPE__, const std::string &, const std::string &, bool)>( \
&NetBuilder::FillConstant<TYPE__>), \
py::arg("shape"), \
py::arg("value"), \
py::arg("name") = "", \
py::arg("dtype"), \
py::arg("force_cpu") = false) \
.def("fill_constant", \
static_cast<Variable (NetBuilder::*)( \
const std::vector<int> &, TYPE__, const std::string &, bool)>( \
&NetBuilder::template FillConstant<TYPE__>), \
py::arg("shape"), \
py::arg("value"), \
py::arg("name") = "", \
py::arg("force_cpu") = false)
EXPAND_CINN_SUPPORT_TYPE(PY_REGISTER_FILLCONSTANT_OP)
#undef PY_REGISTER_FILLCONSTANT_OP
Expand Down Expand Up @@ -462,15 +469,6 @@ void BindFrontend(pybind11::module *m) {
.def("name", &NetBuilder::name)
.def("__str__", [](NetBuilder &self) { return self.name(); })
.def("append_instruction", &NetBuilder::AppendInstruction, py::arg("instr"))
.def("fill_constant",
static_cast<Variable (NetBuilder::*)(
const std::vector<int> &, float, const std::string &, const std::string &, bool)>(
&NetBuilder::FillConstant),
py::arg("shape"),
py::arg("value"),
py::arg("name") = "",
py::arg("dtype"),
py::arg("force_cpu") = false)
.def("fill_constant",
static_cast<Variable (NetBuilder::*)(
const std::vector<int> &, const std::string &, const std::string &, const std::string &, bool)>(
Expand Down
137 changes: 97 additions & 40 deletions python/tests/ops/test_cbrt_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest, OpTestTool
import paddle
import paddle.nn.functional as F
import cinn
from cinn.frontend import *
import numpy as np
from cinn.common import *
from cinn.frontend import *
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper


@OpTestTool.skip_if(not is_compiled_with_cuda(),
"x86 test will be skipped due to timeout.")
class TestCbrtOp(OpTest):
def setUp(self):
self.init_case()
print(f"\nRunning {self.__class__.__name__}: {self.case}")
self.inputs = {}
self.prepare_inputs()

def init_case(self):
def prepare_inputs(self):
self.inputs = {
"x": np.array([0, 1, 0.01, 27, 1000000,
0.970299]).astype("float32")
"x":
self.random(self.case["shape"], self.case["dtype"], -100.0, 100.0),
}

def build_paddle_program(self, target):
Expand All @@ -43,44 +43,101 @@ def build_paddle_program(self, target):

def build_cinn_program(self, target):
builder = NetBuilder("cbrt")
x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
x = builder.create_input(
self.nptype2cinntype(self.inputs["x"].dtype),
self.inputs["x"].shape, "x")
out = builder.cbrt(x)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]],
[out])

self.cinn_outputs = [res[0]]
self.cinn_outputs = res

def test_check_results(self):
self.check_outputs_and_grads()


class TestCbrtCase1(TestCbrtOp):
def init_case(self):
self.inputs = {
"x":
np.array([0, 1, 0.01, 27, 1000000, 0.970299, 124483,
13.7396]).astype("float32")
}


class TestCbrtCase2(TestCbrtOp):
def init_case(self):
self.inputs = {
"x":
np.array([[0, 1, 0.01, 27], [1000000, 0.970299, 124483,
13.7396]]).astype("float32"),
}


class TestCbrtCase3(TestCbrtOp):
def init_case(self):
np.random.seed(0)
self.inputs = {
"x": np.random.random((32, 64)).astype("float32"),
}
self.check_outputs_and_grads(
max_relative_error=1e-3, max_absolute_error=1e-3)


class TestCbrtOpShape(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestCbrtOpShape"
self.cls = TestCbrtOp
self.inputs = [
{
"shape": [10],
},
{
"shape": [8, 5],
},
{
"shape": [10, 3, 5],
},
{
"shape": [80, 40, 5, 7],
},
{
"shape": [80, 1, 5, 7],
},
{
"shape": [80, 3, 1024, 7],
},
{
"shape": [10, 5, 1024, 2048],
},
{
"shape": [1],
},
{
"shape": [512],
},
{
"shape": [1024],
},
{
"shape": [2048],
},
{
"shape": [1, 1, 1, 1],
},
]
self.dtypes = [
{
"dtype": "float32"
},
]
self.attrs = []


class TestCbrtOpDtype(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestCbrtOpDtype"
self.cls = TestCbrtOp
self.inputs = [
{
"shape": [1],
},
{
"shape": [5],
},
{
"shape": [80, 40, 5, 7],
},
]
self.dtypes = [
{
"dtype": "float16"
},
{
"dtype": "float32"
},
{
"dtype": "float64"
},
]
self.attrs = []


if __name__ == "__main__":
unittest.main()
TestCbrtOpShape().run()
TestCbrtOpDtype().run()
Loading