From bd60e696cc8233b0ebef3b4db932923d668e8c50 Mon Sep 17 00:00:00 2001 From: Terry <38135104+TR666@users.noreply.github.com> Date: Wed, 26 Jun 2024 20:30:19 +0800 Subject: [PATCH] [XPU] Add bool type for concat op (#10527) --- lite/kernels/xpu/concat_compute.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lite/kernels/xpu/concat_compute.cc b/lite/kernels/xpu/concat_compute.cc index 00cf5bb1012..8b4d4965338 100644 --- a/lite/kernels/xpu/concat_compute.cc +++ b/lite/kernels/xpu/concat_compute.cc @@ -104,6 +104,8 @@ using concati64 = paddle::lite::kernels::xpu::ConcatCompute; using concati8 = paddle::lite::kernels::xpu::ConcatCompute; +using concatbool = + paddle::lite::kernels::xpu::ConcatCompute; REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatfp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))}) @@ -147,3 +149,9 @@ REGISTER_LITE_KERNEL(concat, kXPU, kInt8, kNCHW, concati8, concat_INT8) {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))}) .Finalize(); +REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatbool, concat_BOOL) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))}) + .Finalize();