Skip to content

Commit b0dbe49

Browse files
author
xiaying
committed
Revert "GeometryInnerProduct: embed bias to MatMul if batch size is 1"
This reverts commit 557540a.
1 parent afa1b48 commit b0dbe49

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

source/geometry/GeometryInnerProduct.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class GeometryInnerProduct : public GeometryComputer {
2222
auto parameter = op->main_as_InnerProduct();
2323
int outputCount = parameter->outputCount();
2424
int srcCount = parameter->weight()->size() / outputCount;
25-
bool hasBias = parameter->biasTerm() > 0;
2625

2726
MNN_ASSERT(inputs.size() == 1);
2827
MNN_ASSERT(outputs.size() == 1);
@@ -77,24 +76,22 @@ class GeometryInnerProduct : public GeometryComputer {
7776
res.extras.emplace_back(tmpInput);
7877
}
7978

80-
std::shared_ptr<Tensor> tmpOutput;
79+
std::shared_ptr<Tensor> tmpOutput(new Tensor);
8180
std::shared_ptr<Tensor> C(new Tensor);
8281
auto& constTensors = context.searchConst(op);
8382
Tensor* weight = nullptr;
8483
Tensor* bias = nullptr;
8584
if (!constTensors.empty()) {
86-
MNN_ASSERT(constTensors.size() == (hasBias ? 2 : 1));
85+
MNN_ASSERT(constTensors.size() == 2);
8786
weight = constTensors[0].get();
88-
bias = hasBias ? constTensors[1].get() : nullptr;
87+
bias = constTensors[1].get();
8988
} else {
9089
auto weightTensor = context.allocConst(op, {outputCount, srcCount}, halide_type_of<float>());
9190
::memcpy(weightTensor.get()->host<float>(), parameter->weight()->data(), parameter->weight()->size()*sizeof(float));
9291
weight = weightTensor.get();
93-
if (hasBias) {
94-
auto biasTensor = context.allocConst(op, {batch, outputCount}, halide_type_of<float>());
95-
::memcpy(biasTensor.get()->host<float>(), parameter->bias()->data(), parameter->bias()->size() * sizeof(float));
96-
bias = biasTensor.get();
97-
}
92+
auto biasTensor = context.allocConst(op, {batch, outputCount}, halide_type_of<float>());
93+
::memcpy(biasTensor.get()->host<float>(), parameter->bias()->data(), parameter->bias()->size()*sizeof(float));
94+
bias = biasTensor.get();
9895
}
9996
{
10097
B = weight;
@@ -104,13 +101,12 @@ class GeometryInnerProduct : public GeometryComputer {
104101
C->setLength(0, batch);
105102
C->setLength(1, outputCount);
106103

107-
auto cmd = GeometryComputerUtils::makeMatMul(A, B, C.get(), bias, false, true);
104+
auto cmd = GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, false, true);
108105
res.extras.emplace_back(C);
109106
res.command.emplace_back(std::move(cmd));
110107
}
111108

112-
if (hasBias && batch > 1) {
113-
tmpOutput.reset(new Tensor);
109+
{
114110
tmpOutput->buffer().type = halide_type_of<float>();
115111
tmpOutput->buffer().dimensions = 2;
116112
tmpOutput->setLength(0, batch);
@@ -119,8 +115,6 @@ class GeometryInnerProduct : public GeometryComputer {
119115
auto cmd = GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C.get(), bias, tmpOutput.get());
120116
res.extras.emplace_back(tmpOutput);
121117
res.command.emplace_back(std::move(cmd));
122-
} else {
123-
tmpOutput = C;
124118
}
125119

126120
{

0 commit comments

Comments
 (0)