diff --git a/lite/kernels/xpu/concat_compute.cc b/lite/kernels/xpu/concat_compute.cc index 00cf5bb1012..8b4d4965338 100644 --- a/lite/kernels/xpu/concat_compute.cc +++ b/lite/kernels/xpu/concat_compute.cc @@ -104,6 +104,8 @@ using concati64 = paddle::lite::kernels::xpu::ConcatCompute; using concati8 = paddle::lite::kernels::xpu::ConcatCompute; +using concatbool = + paddle::lite::kernels::xpu::ConcatCompute; REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatfp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))}) @@ -147,3 +149,9 @@ REGISTER_LITE_KERNEL(concat, kXPU, kInt8, kNCHW, concati8, concat_INT8) {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))}) .Finalize(); +REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatbool, concat_BOOL) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .Finalize();