Skip to content

Commit

Permalink
[SYCL] Fix SYCL kernel lambda argument type detection (intel#11679)
Browse files Browse the repository at this point in the history
We have a helper which is used to extract the type of the first argument
of a SYCL kernel lambda to do some error-checking and special handling
based on that.

That check, however, was missing the case where a kernel lambda also
accepts a `kernel_handler` argument, always falling back to the suggested
type in that case. This led to situations where we couldn't compile
code like:

```c++
sycl::queue q;
q.parallel_for(sycl::range{1}, [=](sycl::item<1, false>, kernel_handler) {});
```

This patch adds extra specializations of some internal helpers to fix
the error.

This is a follow-up to intel#11625
  • Loading branch information
AlexeySachkov authored Oct 27, 2023
1 parent 4156f78 commit 6ba7b52
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
16 changes: 10 additions & 6 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,15 @@ static Arg member_ptr_helper(RetType (Func::*)(Arg) const);
// Reports the parameter type of a single-argument 'operator()' that is not
// declared 'const'. This is a declaration only: the result is carried by the
// return type, so it is meant to be inspected through decltype.
template <typename Ret, typename Callable, typename Param>
static Param member_ptr_helper(Ret (Callable::*)(Param));

// template <typename RetType, typename Func>
// static void member_ptr_helper(RetType (Func::*)() const);
// Two-parameter overload covering kernel lambdas that additionally receive a
// kernel_handler. Only the type of the first parameter is reported back; the
// type of the second one is deduced but otherwise ignored.
template <typename Ret, typename Callable, typename First, typename Second>
static First member_ptr_helper(Ret (Callable::*)(First, Second) const);

// template <typename RetType, typename Func>
// static void member_ptr_helper(RetType (Func::*)());
// Counterpart of the const-qualified two-parameter overload: matches functors
// whose 'operator()' is declared without the 'const' qualifier. As there, only
// the type of the first parameter is reported back.
template <typename Ret, typename Callable, typename First, typename Second>
static First member_ptr_helper(Ret (Callable::*)(First, Second));

// Overload taking 'int'. Its return type is whatever member_ptr_helper reports
// for 'KernelFunctor::operator()'. FallbackArgType is not used by this
// declaration.
template <typename KernelFunctor, typename FallbackArgType>
auto argument_helper(int)
    -> decltype(member_ptr_helper(&KernelFunctor::operator()));
Expand Down Expand Up @@ -1280,7 +1284,7 @@ class __SYCL_EXPORT handler {
using KName = std::conditional_t<std::is_same<KernelType, NameT>::value,
decltype(Wrapper), NameWT>;

kernel_parallel_for_wrapper<KName, item<Dims>, decltype(Wrapper),
kernel_parallel_for_wrapper<KName, TransformedArgType, decltype(Wrapper),
PropertiesT>(Wrapper);
#ifndef __SYCL_DEVICE_ONLY__
// We are executing over the rounded range, but there are still
Expand All @@ -1290,7 +1294,7 @@ class __SYCL_EXPORT handler {
// of the user range, instead of the rounded range.
detail::checkValueRange<Dims>(UserRange);
MNDRDesc.set(*RoundedRange);
StoreLambda<KName, decltype(Wrapper), Dims, item<Dims>>(
StoreLambda<KName, decltype(Wrapper), Dims, TransformedArgType>(
std::move(Wrapper));
setType(detail::CG::Kernel);
#endif
Expand Down
16 changes: 15 additions & 1 deletion sycl/test/basic_tests/handler/parallel_for_args.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ int main() {
q.parallel_for(r2, [=](sycl::item<2> it) {});
q.parallel_for(r3, [=](sycl::item<3> it) {});

q.parallel_for(r1, [=](sycl::item<1, false> it) {});
q.parallel_for(r2, [=](sycl::item<2, false> it) {});
q.parallel_for(r3, [=](sycl::item<3, false> it) {});

// int, size_t -> sycl::item
q.parallel_for(r1, [=](int it) {});
q.parallel_for(r1, [=](size_t it) {});

// sycl::item -> sycl::id
q.parallel_for(r1, [=](sycl::id<1> it) {});
q.parallel_for(r2, [=](sycl::id<2> it) {});
Expand All @@ -51,6 +59,13 @@ int main() {
q.parallel_for(r2, [=](sycl::item<2> it, sycl::kernel_handler kh) {});
q.parallel_for(r3, [=](sycl::item<3> it, sycl::kernel_handler kh) {});

q.parallel_for(r1, [=](int it, sycl::kernel_handler kh) {});
q.parallel_for(r1, [=](size_t it, sycl::kernel_handler kh) {});

q.parallel_for(r1, [=](sycl::item<1, false> it, sycl::kernel_handler kh) {});
q.parallel_for(r2, [=](sycl::item<2, false> it, sycl::kernel_handler kh) {});
q.parallel_for(r3, [=](sycl::item<3, false> it, sycl::kernel_handler kh) {});

q.parallel_for(r1, [=](sycl::id<1> it, sycl::kernel_handler kh) {});
q.parallel_for(r2, [=](sycl::id<2> it, sycl::kernel_handler kh) {});
q.parallel_for(r3, [=](sycl::id<3> it, sycl::kernel_handler kh) {});
Expand Down Expand Up @@ -90,5 +105,4 @@ int main() {
[=](ConvertibleFromNDItem<3> it, sycl::kernel_handler kh) {});

// TODO: consider adding test cases for hierarchical parallelism
// TODO: consider adding cases for sycl::item with offset
}

0 comments on commit 6ba7b52

Please sign in to comment.