Skip to content

Commit

Permalink
[XPU] Add bool type for concat op (#10527)
Browse files Browse the repository at this point in the history
  • Loading branch information
TR666 authored Jun 26, 2024
1 parent 79b0a58 commit bd60e69
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lite/kernels/xpu/concat_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ using concati64 =
paddle::lite::kernels::xpu::ConcatCompute<int64_t, PRECISION(kFloat)>;
using concati8 =
paddle::lite::kernels::xpu::ConcatCompute<int8_t, PRECISION(kInt8)>;
using concatbool =
paddle::lite::kernels::xpu::ConcatCompute<bool, PRECISION(kFloat)>;

REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatfp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
Expand Down Expand Up @@ -147,3 +149,9 @@ REGISTER_LITE_KERNEL(concat, kXPU, kInt8, kNCHW, concati8, concat_INT8)
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatbool, concat_BOOL)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))})
.Finalize();

0 comments on commit bd60e69

Please sign in to comment.