Skip to content

Commit

Permalink
[SYCL] Enforce constraints from sycl_ext_oneapi_reduction_properties (
Browse files Browse the repository at this point in the history
  • Loading branch information
aelovikov-intel authored Dec 9, 2024
1 parent 9a467fa commit f2c7869
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 24 deletions.
90 changes: 66 additions & 24 deletions sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ struct initialize_to_identity_key
};
inline constexpr initialize_to_identity_key::value_t initialize_to_identity;

namespace detail {
struct reduction_property_check_anchor {};
} // namespace detail

template <>
struct is_property_key_of<deterministic_key,
detail::reduction_property_check_anchor>
: std::true_type {};

template <>
struct is_property_key_of<initialize_to_identity_key,
detail::reduction_property_check_anchor>
: std::true_type {};

} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down Expand Up @@ -83,60 +97,88 @@ template <typename BinaryOperation>
struct IsDeterministicOperator<DeterministicOperatorWrapper<BinaryOperation>>
: std::true_type {};

template <typename PropertyList>
inline constexpr bool is_valid_reduction_prop_list =
ext::oneapi::experimental::detail::all_are_properties_of_v<
ext::oneapi::experimental::detail::reduction_property_check_anchor,
PropertyList>;

template <typename BinaryOperation, typename PropertyList, typename... Args>
auto convert_reduction_properties(BinaryOperation combiner,
PropertyList properties, Args &&...args) {
if constexpr (is_valid_reduction_prop_list<PropertyList>) {
auto WrappedOp = WrapOp(combiner, properties);
auto RuntimeProps = GetReductionPropertyList(properties);
return sycl::reduction(std::forward<Args>(args)..., WrappedOp,
RuntimeProps);
} else {
// Invalid, will be disabled by SFINAE at the caller side. Make sure no hard
// error is emitted from here.
}
}
} // namespace detail

template <typename BufferT, typename BinaryOperation, typename PropertyList>
auto reduction(BufferT vars, handler &cgh, BinaryOperation combiner,
PropertyList properties) {
PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, vars, cgh))> {
detail::CheckReductionIdentity<typename BufferT::value_type, BinaryOperation>(
properties);
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(vars, cgh, WrappedOp, RuntimeProps);
return detail::convert_reduction_properties(combiner, properties, vars, cgh);
}

template <typename T, typename BinaryOperation, typename PropertyList>
auto reduction(T *var, BinaryOperation combiner, PropertyList properties) {
auto reduction(T *var, BinaryOperation combiner, PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, var))> {
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(var, WrappedOp, RuntimeProps);
return detail::convert_reduction_properties(combiner, properties, var);
}

template <typename T, size_t Extent, typename BinaryOperation,
typename PropertyList>
auto reduction(span<T, Extent> vars, BinaryOperation combiner,
PropertyList properties) {
PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, vars))> {
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(vars, WrappedOp, RuntimeProps);
return detail::convert_reduction_properties(combiner, properties, vars);
}

template <typename BufferT, typename BinaryOperation, typename PropertyList>
auto reduction(BufferT vars, handler &cgh,
const typename BufferT::value_type &identity,
BinaryOperation combiner, PropertyList properties) {
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(vars, cgh, identity, WrappedOp, RuntimeProps);
BinaryOperation combiner, PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, vars, cgh, identity))> {
return detail::convert_reduction_properties(combiner, properties, vars, cgh,
identity);
}

template <typename T, typename BinaryOperation, typename PropertyList>
auto reduction(T *var, const T &identity, BinaryOperation combiner,
PropertyList properties) {
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(var, identity, WrappedOp, RuntimeProps);
PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, var, identity))> {
return detail::convert_reduction_properties(combiner, properties, var,
identity);
}

template <typename T, size_t Extent, typename BinaryOperation,
typename PropertyList>
auto reduction(span<T, Extent> vars, const T &identity,
BinaryOperation combiner, PropertyList properties) {
auto WrappedOp = detail::WrapOp(combiner, properties);
auto RuntimeProps = detail::GetReductionPropertyList(properties);
return reduction(vars, identity, WrappedOp, RuntimeProps);
BinaryOperation combiner, PropertyList properties)
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
decltype(detail::convert_reduction_properties(
combiner, properties, vars, identity))> {
return detail::convert_reduction_properties(combiner, properties, vars,
identity);
}

} // namespace _V1
Expand Down
28 changes: 28 additions & 0 deletions sycl/test/extensions/properties/properties_reduction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsyntax-only -Xclang -verify -Xclang -verify-ignore-unexpected=note %s

#include <sycl/sycl.hpp>

int main() {
int *r = nullptr;
// Must not use `sycl_ext_oneapi_reduction_properties`'s overloads:
std::ignore =
sycl::reduction(r, sycl::plus<int>{},
sycl::property::reduction::initialize_to_identity{});

namespace sycl_exp = sycl::ext::oneapi::experimental;
std::ignore =
sycl::reduction(r, sycl::plus<int>{},
sycl_exp::properties(sycl_exp::initialize_to_identity));

// Not a property list:
// expected-error@+2 {{no matching function for call to 'reduction'}}
std::ignore =
sycl::reduction(r, sycl::plus<int>{}, sycl_exp::initialize_to_identity);

// Not a reduction property:
// expected-error@+2 {{no matching function for call to 'reduction'}}
std::ignore =
sycl::reduction(r, sycl::plus<int>{},
sycl_exp::properties(sycl_exp::initialize_to_identity,
sycl_exp::full_group));
}

0 comments on commit f2c7869

Please sign in to comment.