@@ -73,7 +73,7 @@ static void pooling_max_pad(const float *channelInput, float *offsetOutput, int
7373
7474static void poolingMax (const float *channelInput, int inputWidth, int inputHeight, float *channelOutput,
7575 int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth,
76- int strideHeight, int padWidth, int padHeight) {
76+ int strideHeight, int padWidth, int padHeight, MNN::PoolPadType padType ) {
7777 int padTop = padHeight <= 0 ? 0 : (padHeight + strideHeight - 1 ) / strideHeight;
7878 int padBottom = (padHeight + inputHeight - kernelHeight) / strideHeight + 1 ;
7979 int padLeft = padWidth <= 0 ? 0 : (padWidth + strideWidth - 1 ) / strideWidth;
@@ -166,7 +166,8 @@ static void poolingMax(const float *channelInput, int inputWidth, int inputHeigh
166166}
167167
168168static void poolingAvgPad (const float *offsetInput, float *offsetOutput, int inputWidth, int inputHeight,
169- int kernelWidth, int kernelHeight, int inputStep4, int iw, int ih) {
169+ int kernelWidth, int kernelHeight, int inputStep4, int iw, int ih, int padWidth,
170+ int padHeight, MNN::PoolPadType padType) {
170171#ifdef MNN_USE_NEON
171172 float32x4_t sum = vdupq_n_f32 (0 );
172173#else
@@ -175,15 +176,23 @@ static void poolingAvgPad(const float *offsetInput, float *offsetOutput, int inp
175176 float sum2 = 0 ;
176177 float sum3 = 0 ;
177178#endif
179+
180+ const int khs = 0 < -ih ? -ih : 0 ; // max
181+ const int khe = kernelHeight < inputHeight - ih ? kernelHeight : inputHeight - ih; // min
182+ const int kws = 0 < -iw ? -iw : 0 ; // max
183+ const int kwe = kernelWidth < inputWidth - iw ? kernelWidth : inputWidth - iw; // min
184+
178185 // sum
179186 int count = 0 ;
187+ if (padType == MNN::PoolPadType_CAFFE) {
188+ count = (ALIMIN (ih + kernelHeight, inputHeight + padHeight) - ih) *
189+ (ALIMIN (iw + kernelWidth, inputWidth + padWidth) - iw);
190+ } else {
191+ count = (khe - khs) * (kwe - kws);
192+ }
180193
181- const int khs = 0 < -ih ? -ih : 0 ; // max
182- const int khe = kernelHeight < inputHeight - ih ? kernelHeight : inputHeight - ih; // min
183194 const float *kernelInput = offsetInput + khs * inputStep4;
184195 for (int kh = khs; kh < khe; kh++, kernelInput += inputStep4) {
185- const int kws = 0 < -iw ? -iw : 0 ; // max
186- const int kwe = kernelWidth < inputWidth - iw ? kernelWidth : inputWidth - iw; // min
187196 const float *cursorInput = kernelInput + kws * 4 ;
188197 for (int kw = kws; kw < kwe; kw++, cursorInput += 4 ) {
189198#ifdef MNN_USE_NEON
@@ -194,7 +203,6 @@ static void poolingAvgPad(const float *offsetInput, float *offsetOutput, int inp
194203 sum2 += cursorInput[2 ];
195204 sum3 += cursorInput[3 ];
196205#endif
197- count++;
198206 }
199207 }
200208
@@ -222,7 +230,7 @@ static void poolingAvgPad(const float *offsetInput, float *offsetOutput, int inp
222230
223231static void poolingAvg (const float *channelInput, int inputWidth, int inputHeight, float *channelOutput,
224232 int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth,
225- int strideHeight, int padWidth, int padHeight) {
233+ int strideHeight, int padWidth, int padHeight, MNN::PoolPadType padType ) {
226234 int padTop = padHeight <= 0 ? 0 : (padHeight + strideHeight - 1 ) / strideHeight;
227235 int padBottom = (padHeight + inputHeight - kernelHeight) / strideHeight + 1 ;
228236 int padLeft = padWidth <= 0 ? 0 : (padWidth + strideWidth - 1 ) / strideWidth;
@@ -243,7 +251,7 @@ static void poolingAvg(const float *channelInput, int inputWidth, int inputHeigh
243251 for (int ow = 0 , iw = -padWidth; ow < outputWidth;
244252 ow++, iw += strideWidth, offsetOutput += 4 , offsetInput += strideWidth4) {
245253 poolingAvgPad (offsetInput, offsetOutput, inputWidth, inputHeight, kernelWidth, kernelHeight, inputStep4,
246- iw, ih);
254+ iw, ih, padWidth, padHeight, padType );
247255 }
248256 }
249257 for (int oh = padTop, ih = -padHeight + oh * strideHeight; oh < padBottom;
@@ -253,14 +261,14 @@ static void poolingAvg(const float *channelInput, int inputWidth, int inputHeigh
253261 for (int ow = 0 , iw = -padWidth; ow < padLeft;
254262 ow++, iw += strideWidth, offsetOutput += 4 , offsetInput += strideWidth4) {
255263 poolingAvgPad (offsetInput, offsetOutput, inputWidth, inputHeight, kernelWidth, kernelHeight, inputStep4,
256- iw, ih);
264+ iw, ih, padWidth, padHeight, padType );
257265 }
258266 offsetInput = lineInput + padRight * strideWidth * 4 ;
259267 offsetOutput = lineOutput + padRight * 4 ;
260268 for (int ow = padRight, iw = -padWidth + ow * strideWidth; ow < outputWidth;
261269 ow++, iw += strideWidth, offsetOutput += 4 , offsetInput += strideWidth4) {
262270 poolingAvgPad (offsetInput, offsetOutput, inputWidth, inputHeight, kernelWidth, kernelHeight, inputStep4,
263- iw, ih);
271+ iw, ih, padWidth, padHeight, padType );
264272 }
265273 }
266274 for (int oh = padBottom, ih = -padHeight + oh * strideHeight; oh < outputHeight;
@@ -270,7 +278,7 @@ static void poolingAvg(const float *channelInput, int inputWidth, int inputHeigh
270278 for (int ow = 0 , iw = -padWidth; ow < outputWidth;
271279 ow++, iw += strideWidth, offsetOutput += 4 , offsetInput += strideWidth4) {
272280 poolingAvgPad (offsetInput, offsetOutput, inputWidth, inputHeight, kernelWidth, kernelHeight, inputStep4,
273- iw, ih);
281+ iw, ih, padWidth, padHeight, padType );
274282 }
275283 }
276284 }
@@ -368,6 +376,8 @@ ErrorCode CPUPool::onResize(const std::vector<Tensor *> &inputs, const std::vect
368376 int padNeededHeight = (output->height () - 1 ) * strideHeight + kernelHeight - input->height ();
369377 padWidth = padNeededWidth > 0 ? padNeededWidth / 2 : 0 ;
370378 padHeight = padNeededHeight > 0 ? padNeededHeight / 2 : 0 ;
379+ } else if (layer->padType () == PoolPadType_VALID) {
380+ padWidth = padHeight = 0 ;
371381 }
372382 auto poolType = layer->type ();
373383 auto planeFunction = poolingMax;
@@ -380,13 +390,14 @@ ErrorCode CPUPool::onResize(const std::vector<Tensor *> &inputs, const std::vect
380390 auto inputPlaneStride = 4 * input->width () * input->height ();
381391 auto outputPlaneStride = 4 * output->width () * output->height ();
382392 int threadNumber = ((CPUBackend *)backend ())->threadNumber ();
393+ auto padType = layer->padType ();
383394 mFunction = [=]() {
384395 MNN_CONCURRENCY_BEGIN (tId, threadNumber) {
385396 for (int channel = (int )tId; channel < totalDepth; channel += threadNumber) {
386397 // run
387398 planeFunction (inputData + channel * inputPlaneStride, input->width (), input->height (),
388399 outputData + outputPlaneStride * channel, output->width (), output->height (), kernelWidth,
389- kernelHeight, strideWidth, strideHeight, padWidth, padHeight);
400+ kernelHeight, strideWidth, strideHeight, padWidth, padHeight, padType );
390401 }
391402 }
392403 MNN_CONCURRENCY_END ();
0 commit comments