1
+
2
+ #ifndef _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H
3
+ #define _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H
4
+
5
+ #include < cuda/std/detail/__config>
6
+
7
+ #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
8
+ # pragma GCC system_header
9
+ #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
10
+ # pragma clang system_header
11
+ #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
12
+ # pragma system_header
13
+ #endif // no system header
14
+
15
+ #include < cuda/std/type_traits>
16
+ #include < cuda/std/bit>
17
+ #include < cuda/std/memory>
18
+
19
+ #include < cuda_fp16.h>
20
+
21
+ #include < cuda_bf16.h>
22
+
23
+
24
+ #define _CCCL_HAS_CUDA_COMPILER 1 // fix for now -- to be deleted later
25
+
26
+ #if _CCCL_HAS_CUDA_COMPILER
27
+ _LIBCUDACXX_BEGIN_NAMESPACE_CUDA
28
+ template <typename T>
29
+ constexpr bool is_supported_type_v = false ;
30
+ template <> constexpr bool is_supported_type_v<int > = true ;
31
+ template <> constexpr bool is_supported_type_v<unsigned int > = true ;
32
+ template <> constexpr bool is_supported_type_v<long > = true ;
33
+ template <> constexpr bool is_supported_type_v<unsigned long > = true ;
34
+ template <> constexpr bool is_supported_type_v<long long > = true ;
35
+ template <> constexpr bool is_supported_type_v<unsigned long long > = true ;
36
+ template <> constexpr bool is_supported_type_v<float > = true ;
37
+ template <> constexpr bool is_supported_type_v<double > = true ;
38
+ template <> constexpr bool is_supported_type_v<__half> = true ;
39
+ template <> constexpr bool is_supported_type_v<__half2> = true ;
40
+ template <> constexpr bool is_supported_type_v<__nv_bfloat16> = true ;
41
+ template <> constexpr bool is_supported_type_v<__nv_bfloat162> = true ;
42
+
43
+ template <typename T>
44
+ T shfl (T var, int srcLane, unsigned mask = 0xFFFFFFFF , int width = warpSize)
45
+ {
46
+ _CCCL_ASSERT (is_supported_type_v<T>, " T must be a supported type for warp shuffle operations" ); // T must be a supported type for warp shuffle operations
47
+ _CCCL_ASSERT ((width > 0 && (width & (width - 1 )) == 0 ), " Width must be a power of two" ); // width must be a power of two
48
+ if constexpr (sizeof (T)==4 )
49
+ {
50
+ if constexpr (cuda::std::is_same_v<T, __half2>)// check for __half2
51
+ {
52
+ __half part_arr[2 ];
53
+ cuda::std::memcpy (part_arr, &var, sizeof (var));
54
+
55
+ float h2f_one = __half2float (part_arr[0 ]);
56
+ float h2f_two = __half2float (part_arr[1 ]);
57
+
58
+ float result_one = __shfl_sync (mask, h2f_one, srcLane, width);
59
+ float result_two = __shfl_sync (mask, h2f_two, srcLane, width);
60
+
61
+ __half f2h_one = __float2half (result_one);
62
+ __half f2h_two = __float2half (result_two);
63
+
64
+ __half2 result = __halves2half2 (f2h_one, f2h_two);
65
+ }
66
+ else if (cuda::std::is_same_v<T, __nv_bfloat162>)// check for __nv_bfloat162
67
+ {
68
+ __nv_bfloat16 part_arr[2 ];
69
+ cuda::std::memcpy (part_arr, &var, sizeof (var));
70
+
71
+ float b2f_one = __nv_bfloat162float (part_arr[0 ]);
72
+ float b2f_two = __nv_bfloat162float (part_arr[1 ]);
73
+
74
+ float result_one = __shfl_sync (mask, b2f_one, srcLane, width);
75
+ float result_two = __shfl_sync (mask, b2f_two, srcLane, width);
76
+
77
+ __nv_bfloat16 f2b_one = __float2nv_bfloat162 (result_one);
78
+ __nv_bfloat16 f2b_two = __float2nv_bfloat162 (result_two);
79
+
80
+ __nv_bfloat162 result = __nv_bfloat162 (f2b_one, f2b_two);
81
+ }
82
+ else if constexpr (cuda::std::is_same_v<T, int > || cuda::std::is_same_v<T, unsigned int > || cuda::std::is_same_v<T, long >
83
+ || cuda::std::is_same_v<T, unsigned long >)
84
+ {
85
+ int var_int = cuda::std::bit_cast<int >(var);
86
+ int result = __shfl_sync (mask, var_int, srcLane, width);
87
+ T result_t = cuda::std::bit_cast<T>(result);
88
+ return result_t ;
89
+ }
90
+ }
91
+ }
92
+ _LIBCUDACXX_END_NAMESPACE_CUDA
93
+ #endif // _CCCL_HAS_CUDA_COMPILER
94
+
95
+ #endif // _CUDA_FUNCTIONAL_SHUFFLE_SAFETY_H
0 commit comments