Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add elementwise_mul op and tests (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
wenming2014 authored Sep 3, 2020
1 parent a040133 commit 671c08a
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 50 deletions.
71 changes: 59 additions & 12 deletions cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,28 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
const Target &target) {
framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A_expr = a[0];
ir::Expr B_expr = a[1];
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
auto attr_store = attrs.attr_store;
auto iter = attr_store.find("axis");
ir::Expr axis;
Expr axis;
if (iter != attr_store.end()) {
axis = ir::Expr(std::get<int>(iter->second));
axis = Expr(std::get<int>(iter->second));
}

auto out = pe::Add(A, B, UniqName("C"), axis);

auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});
Expand All @@ -52,14 +52,52 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
return strategy;
}

std::vector<std::vector<int>> InferShapeForElementwiseAdd(const std::vector<std::vector<int>> &inputs_shape,
const framework::NodeAttr &attrs) {
/**
 * @brief Build the x86 OpStrategy for the elementwise_mul operator.
 *
 * @param attrs    Node attributes; an optional integer "axis" entry controls
 *                 broadcast alignment between the two operands.
 * @param inputs   Input tensors (unused here; operands arrive via compute args).
 * @param out_type Expected output types (unused here).
 * @param target   Compilation target (unused here; impl is registered as x86).
 * @return Strategy holding one (compute, schedule) implementation.
 */
std::shared_ptr<OpStrategy> StrategyForElementwiseMul(const framework::NodeAttr &attrs,
                                                      const std::vector<ir::Tensor> &inputs,
                                                      const std::vector<Type> &out_type,
                                                      const Target &target) {
  // Compute: unpack the two tensor operands and emit pe::Multiply.
  framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) {
    CINNValuePack a = args[0];
    Expr A_expr     = a[0];
    Expr B_expr     = a[1];
    CHECK(A_expr.as_tensor());
    CHECK(B_expr.as_tensor());
    ir::Tensor A = A_expr.as_tensor_ref();
    ir::Tensor B = B_expr.as_tensor_ref();

    // "axis" is optional: when absent, axis stays an undefined Expr and
    // pe::Multiply uses its default broadcast behavior.
    auto attr_store = attrs.attr_store;
    auto iter       = attr_store.find("axis");
    Expr axis;
    if (iter != attr_store.end()) {
      axis = Expr(std::get<int>(iter->second));
    }

    auto out    = pe::Multiply(A, B, UniqName("C"), axis);
    auto stages = CreateStages({out});
    *ret        = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
  });

  // Schedule: pass-through, no transformation is applied yet.
  framework::CINNSchedule mul_schedule([](lang::Args args, lang::RetValue *ret) {
    CINNValuePack arg_pack = args[0];
    // Validate the pack size BEFORE reading an element; the original indexed
    // arg_pack[0] first, so an undersized pack was dereferenced before the
    // CHECK could fire.
    CHECK_EQ(arg_pack.size(), 2UL);
    Expr A [[maybe_unused]] = arg_pack[0];
    *ret                    = arg_pack;
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(mul_compute, mul_schedule, "strategy.elementwise_mul.x86", 1);

  return strategy;
}

/**
 * @brief Shape inference shared by the binary elementwise ops
 * (elementwise_add / elementwise_mul): the output takes the shape of the
 * first operand.
 *
 * @param inputs_shape Shapes of the operator's inputs; must be non-empty
 *                     with a non-empty first entry.
 * @param attrs        Node attributes (unused).
 * @return A single-element vector holding the output shape.
 */
std::vector<std::vector<int>> InferShapeForElementwise(const std::vector<std::vector<int>> &inputs_shape,
                                                       const framework::NodeAttr &attrs) {
  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
  return {inputs_shape[0]};
}

std::vector<Type> InferDtypeForElementwiseAdd(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
std::vector<Type> InferDtypeForElementwise(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
std::vector<Type> res{inputs_type[0]};
return res;
Expand All @@ -75,7 +113,16 @@ CINN_REGISTER_HELPER(broadcast_ops) {
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForElementwiseAdd)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwiseAdd))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwiseAdd))
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise))
.set_support_level(4);

CINN_REGISTER_OP(elementwise_mul)
.describe("multiply two tensors")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForElementwiseMul)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise))
.set_support_level(4);
}
36 changes: 18 additions & 18 deletions cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,
const Target &target) {
framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
Expr A = a[0];
CHECK(A.as_tensor());
auto out = pe::Relu<float>(A.as_tensor_ref(), 0.0, UniqName("Relu_output"));
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule relu_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});
Expand Down Expand Up @@ -62,16 +62,16 @@ std::shared_ptr<OpStrategy> StrategyForRelu6(const framework::NodeAttr &attrs,
const Target &target) {
framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
Expr A = a[0];
CHECK(A.as_tensor());
auto out = pe::Relu6<float>(A.as_tensor_ref(), 0.0, UniqName("Relu6_output"));
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule relu_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});
Expand Down Expand Up @@ -108,8 +108,8 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
}
framework::CINNCompute conv2d_compute([=](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
ir::Expr B = a[1];
Expr A = a[0];
Expr B = a[1];
CHECK(A.as_tensor());
CHECK(B.as_tensor());
CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d op is not 2! Please check.";
Expand All @@ -126,15 +126,15 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
auto stages = CreateStages(out);
std::vector<CINNValue> res;
for (auto &t : out) {
res.push_back(CINNValue(ir::Expr(t.get())));
res.push_back(CINNValue(Expr(t.get())));
}
res.push_back(CINNValue(stages));
*ret = CINNValuePack{res};
});

framework::CINNSchedule conv2d_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 4UL);
*ret = arg_pack;
});
Expand Down Expand Up @@ -189,18 +189,18 @@ std::shared_ptr<OpStrategy> StrategyForBatchNorm(const framework::NodeAttr &attr
}
framework::CINNCompute batchnorm_compute([=](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
ir::Expr B = a[1];
Expr A = a[0];
Expr B = a[1];
CHECK(A.as_tensor());
CHECK(B.as_tensor());
auto out = pe::BatchNorm_NCHW(A.as_tensor_ref(), B.as_tensor_ref(), epsilon, UniqName("BatchNorm_output"));
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule batchnorm_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});
Expand Down
10 changes: 5 additions & 5 deletions cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
const Target &target) {
framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CINNValuePack a = args[0];
ir::Expr A = a[0];
ir::Expr B = a[1];
Expr A = a[0];
Expr B = a[1];
CHECK(A.as_tensor());
CHECK(B.as_tensor());
auto attr_store = attrs.attr_store;
Expand All @@ -43,12 +43,12 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
A.as_tensor_ref(), B.as_tensor_ref(), trans_a, trans_b, x_num_col_dims, y_num_col_dims, UniqName("C"));

auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(ir::Expr(out.get())), CINNValue(stages)}};
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
ir::Expr A [[maybe_unused]] = arg_pack[0];
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
});
Expand Down
16 changes: 8 additions & 8 deletions cinn/hlir/pe/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace pe {
*
* @return The result Tensor or Expr.
*/
#define HLIR_DCL_BC_PE(name__) \
ir::Tensor name__(const ir::Tensor& A, \
const ir::Tensor& B, \
const std::string& output_name = "T_" #name__ "_out", \
const ir::Expr& axis = ir::Expr()); \
ir::Tensor name__(const ir::Expr& A, const ir::Tensor& B, const std::string& output_name = "T_" #name__ "_out"); \
ir::Tensor name__(const ir::Tensor& A, const ir::Expr& B, const std::string& output_name = "T_" #name__ "_out"); \
ir::Expr name__(const ir::Expr& A, const ir::Expr& B);
#define HLIR_DCL_BC_PE(name__) \
ir::Tensor name__(const ir::Tensor& A, \
const ir::Tensor& B, \
const std::string& output_name = "T_" #name__ "_out", \
const Expr& axis = Expr()); \
ir::Tensor name__(const Expr& A, const ir::Tensor& B, const std::string& output_name = "T_" #name__ "_out"); \
ir::Tensor name__(const ir::Tensor& A, const Expr& B, const std::string& output_name = "T_" #name__ "_out"); \
Expr name__(const Expr& A, const Expr& B);

//! Compute A + B with auto-broadcasting.
HLIR_DCL_BC_PE(Add);
Expand Down
14 changes: 7 additions & 7 deletions cinn/hlir/pe/reduction.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once
#include <vector>

#include "cinn/ir/node.h"
#include "cinn/ir/ir.h"

namespace cinn {
namespace hlir {
Expand All @@ -22,9 +22,9 @@ namespace pe {
*/
std::vector<ir::Tensor> Sum(const ir::Tensor& A,
poly::StageMap stages,
const std::vector<ir::Expr>& axis,
const std::vector<Expr>& axis,
bool keep_dims = false,
const ir::Expr& initial = ir::Expr(0),
const Expr& initial = Expr(0),
const std::string& output_name = "T_Reduce_Sum_out");

/**
Expand All @@ -43,9 +43,9 @@ std::vector<ir::Tensor> Sum(const ir::Tensor& A,
*/
std::vector<ir::Tensor> Prod(const ir::Tensor& A,
poly::StageMap stages,
const std::vector<ir::Expr>& axis,
const std::vector<Expr>& axis,
bool keep_dims = false,
const ir::Expr& initial = ir::Expr(1),
const Expr& initial = Expr(1),
const std::string& output_name = "T_Reduce_Prod_out");

/**
Expand All @@ -63,7 +63,7 @@ std::vector<ir::Tensor> Prod(const ir::Tensor& A,
*/
ir::Tensor Max(const ir::Tensor& A,
poly::StageMap stages,
const std::vector<ir::Expr>& axis,
const std::vector<Expr>& axis,
bool keep_dims = false,
const std::string& output_name = "T_Reduce_Max_out");

Expand All @@ -82,7 +82,7 @@ ir::Tensor Max(const ir::Tensor& A,
*/
ir::Tensor Min(const ir::Tensor& A,
poly::StageMap stages,
const std::vector<ir::Expr>& axis,
const std::vector<Expr>& axis,
bool keep_dims = false,
const std::string& output_name = "T_Reduce_Min_out");

Expand Down
23 changes: 23 additions & 0 deletions python/tests/test_op_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,28 @@ def test_op(self):
self.to_test_op([[3, 2], [2]], [[3, 2]], "elementwise_add", attrs)


class OpTest_mul_0(SingleOpTester):
    """elementwise_mul with two identically-shaped [100, 32] inputs, axis=0."""

    def create_target_data(self, inputs_data):
        # Reference result: the elementwise product of the two inputs.
        lhs, rhs = inputs_data
        return lhs * rhs

    def test_op(self):
        attrs = framework.NodeAttr()
        attrs.attr_store = {"axis": 0}
        self.to_test_op([[100, 32], [100, 32]], [[100, 32]],
                        "elementwise_mul", attrs)


class OpTest_mul_1(SingleOpTester):
    """elementwise_mul broadcasting a [2] input against a [3, 2] input, axis=1."""

    def create_target_data(self, inputs_data):
        # Reference result relies on NumPy broadcasting for the product.
        lhs, rhs = inputs_data
        return lhs * rhs

    def test_op(self):
        attrs = framework.NodeAttr()
        attrs.attr_store = {"axis": 1}
        self.to_test_op([[3, 2], [2]], [[3, 2]], "elementwise_mul", attrs)


# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 671c08a

Please sign in to comment.