Skip to content

Commit cfc5e47

Browse files
authored
Merge pull request #724 from lanchongyizu/bugfix/_concatbatch
[MNN:Bugfix] fix _concatBatch bug
2 parents 208dac8 + aa3817f commit cfc5e47

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

source/backend/cpu/CPUConcat.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,17 @@ static int _concatHeight(const Tensor* outputTensor, const vector<Tensor*>& inpu
8080

8181
static int _concatBatch(const Tensor* outputTensor, const vector<Tensor*>& inputTensors) {
8282
auto outputDim = outputTensor->buffer().dim;
83-
const int batchSize = outputDim[0].extent;
84-
for (int batchIndex = 0; batchIndex < batchSize; ++batchIndex) {
85-
float* outputOrigin = reinterpret_cast<float*>(outputTensor->buffer().host) + outputDim[0].stride * batchIndex;
86-
for (size_t b = 0; b < inputTensors.size(); b++) {
87-
auto& inputTensor = inputTensors[b]->buffer();
83+
int currentPositionB = 0;
84+
for (size_t b = 0; b < inputTensors.size(); b++) {
85+
auto& inputTensor = inputTensors[b]->buffer();
86+
const int batchSize = inputTensor.dim[0].extent;
87+
for (int batchIndex = 0; batchIndex < batchSize; ++batchIndex) {
8888
float* inputOrigin = reinterpret_cast<float*>(inputTensor.host) + inputTensor.dim[0].stride * batchIndex;
89+
float* outputOrigin = reinterpret_cast<float*>(outputTensor->buffer().host) +
90+
outputDim[0].stride * (currentPositionB + batchIndex);
8991
::memcpy(outputOrigin, inputOrigin, inputTensor.dim[0].stride * sizeof(float));
9092
}
93+
currentPositionB += batchSize;
9194
}
9295
return 0;
9396
}

0 commit comments

Comments
 (0)