Skip to content

Commit 22344f4

Browse files
committed
Provide generic and safe C++ interfaces for warp shuffle: Issue #2976
1 parent c80fce9 commit 22344f4

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
#ifndef _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H
3+
#define _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H
4+
5+
#include <cuda/std/detail/__config>
6+
7+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
8+
# pragma GCC system_header
9+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
10+
# pragma clang system_header
11+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
12+
# pragma system_header
13+
#endif // no system header
14+
15+
#include <cuda/std/type_traits>
16+
#include <cuda/std/bit>
17+
#include <cuda/std/memory>
18+
19+
#include <cuda_fp16.h>
20+
21+
#include <cuda_bf16.h>
22+
23+
24+
#define _CCCL_HAS_CUDA_COMPILER 1 //fix for now -- to be deleted later
25+
26+
#if _CCCL_HAS_CUDA_COMPILER
27+
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
28+
template <typename T>
29+
constexpr bool is_supported_type_v = false;
30+
template <> constexpr bool is_supported_type_v<int> = true;
31+
template <> constexpr bool is_supported_type_v<unsigned int> = true;
32+
template <> constexpr bool is_supported_type_v<long> = true;
33+
template <> constexpr bool is_supported_type_v<unsigned long> = true;
34+
template <> constexpr bool is_supported_type_v<long long> = true;
35+
template <> constexpr bool is_supported_type_v<unsigned long long> = true;
36+
template <> constexpr bool is_supported_type_v<float> = true;
37+
template <> constexpr bool is_supported_type_v<double> = true;
38+
template <> constexpr bool is_supported_type_v<__half> = true;
39+
template <> constexpr bool is_supported_type_v<__half2> = true;
40+
template <> constexpr bool is_supported_type_v<__nv_bfloat16> = true;
41+
template <> constexpr bool is_supported_type_v<__nv_bfloat162> = true;
42+
43+
template <typename T>
44+
T shfl(T var, int srcLane, unsigned mask = 0xFFFFFFFF, int width = warpSize)
45+
{
46+
_CCCL_ASSERT(is_supported_type_v<T>, "T must be a supported type for warp shuffle operations"); // T must be a supported type for warp shuffle operations
47+
_CCCL_ASSERT((width > 0 && (width & (width - 1)) == 0), "Width must be a power of two"); // width must be a power of two
48+
if constexpr(sizeof(T)==4)
49+
{
50+
if constexpr(cuda::std::is_same_v<T, __half2>)//check for __half2
51+
{
52+
__half part_arr[2];
53+
cuda::std::memcpy(part_arr, &var, sizeof(var));
54+
55+
float h2f_one = __half2float(part_arr[0]);
56+
float h2f_two = __half2float(part_arr[1]);
57+
58+
float result_one = __shfl_sync(mask, h2f_one, srcLane, width);
59+
float result_two = __shfl_sync(mask, h2f_two, srcLane, width);
60+
61+
__half f2h_one = __float2half(result_one);
62+
__half f2h_two = __float2half(result_two);
63+
64+
__half2 result = __halves2half2(f2h_one, f2h_two);
65+
}
66+
else if(cuda::std::is_same_v<T, __nv_bfloat162>)//check for __nv_bfloat162
67+
{
68+
__nv_bfloat16 part_arr[2];
69+
cuda::std::memcpy(part_arr, &var, sizeof(var));
70+
71+
float b2f_one = __nv_bfloat162float(part_arr[0]);
72+
float b2f_two = __nv_bfloat162float(part_arr[1]);
73+
74+
float result_one = __shfl_sync(mask, b2f_one, srcLane, width);
75+
float result_two = __shfl_sync(mask, b2f_two, srcLane, width);
76+
77+
__nv_bfloat16 f2b_one = __float2nv_bfloat162(result_one);
78+
__nv_bfloat16 f2b_two = __float2nv_bfloat162(result_two);
79+
80+
__nv_bfloat162 result = __nv_bfloat162(f2b_one, f2b_two);
81+
}
82+
else if constexpr(cuda::std::is_same_v<T, int> || cuda::std::is_same_v<T, unsigned int> || cuda::std::is_same_v<T, long>
83+
|| cuda::std::is_same_v<T, unsigned long>)
84+
{
85+
int var_int = cuda::std::bit_cast<int>(var);
86+
int result = __shfl_sync(mask, var_int, srcLane, width);
87+
T result_t = cuda::std::bit_cast<T>(result);
88+
return result_t;
89+
}
90+
}
91+
}
92+
_LIBCUDACXX_END_NAMESPACE_CUDA
93+
#endif // _CCCL_HAS_CUDA_COMPILER
94+
95+
#endif // _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H

libcudacxx/include/cuda/cmath

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#endif // no system header
2323

2424
#include <cuda/__cmath/ceil_div.h>
25+
#include <cuda/__cmath/shuffle_safety.h>
2526
#include <cuda/std/cmath>
2627

2728
#endif // _CUDA_CMATH

0 commit comments

Comments
 (0)