Skip to content

Commit c67b1c3

Browse files
wmaxeymiscco
andauthored
Backport PR #2046 - Fixing FP16 conversions. (#2222)
* Do not rely on conversions between float and extended floating point types (#2046) The issue we have is that our tests rely extensively on those conversions which makes it incredibly painfull to test * Fix including `<complex>` when bad CUDA bfloat/half macros are used. (#2226) * Add <complex> test for bad macros being defined * Fix <complex> failing upon inclusion when bad macros are defined * Rather use explicit specializations and some evil hackery to get the complex interop to work * Fix typos * Inline everything * Move workarounds together * Use conversion functions instead of explicit specializations * Drop unneeded conversions --------- Co-authored-by: Michael Schellenberger Costa <[email protected]> --------- Co-authored-by: Michael Schellenberger Costa <[email protected]>
1 parent 1251f54 commit c67b1c3

File tree

5 files changed

+223
-42
lines changed

5 files changed

+223
-42
lines changed

libcudacxx/include/cuda/std/__complex/nvbf16.h

+76-11
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,39 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
6363
typedef complex<__nv_bfloat16> _ComplexType;
6464
};
6565

66+
// This is a workaround against the user defining macros __CUDA_NO_BFLOAT16_CONVERSIONS__ __CUDA_NO_BFLOAT16_OPERATORS__
67+
template <>
68+
struct __complex_can_implicitly_construct<__nv_bfloat16, float> : true_type
69+
{};
70+
71+
template <>
72+
struct __complex_can_implicitly_construct<__nv_bfloat16, double> : true_type
73+
{};
74+
75+
template <>
76+
struct __complex_can_implicitly_construct<float, __nv_bfloat16> : true_type
77+
{};
78+
79+
template <>
80+
struct __complex_can_implicitly_construct<double, __nv_bfloat16> : true_type
81+
{};
82+
83+
template <class _Tp>
84+
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const _Tp& __value) noexcept
85+
{
86+
return __value;
87+
}
88+
89+
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const float& __value) noexcept
90+
{
91+
return __float2bfloat16(__value);
92+
}
93+
94+
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const double& __value) noexcept
95+
{
96+
return __double2bfloat16(__value);
97+
}
98+
6699
template <>
67100
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
68101
{
@@ -80,14 +113,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
80113

81114
template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
82115
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
83-
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
116+
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
84117
{}
85118

86119
template <class _Up,
87120
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
88121
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
89122
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
90-
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
123+
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
91124
{}
92125

93126
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
@@ -100,8 +133,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
100133
template <class _Up>
101134
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
102135
{
103-
__repr_.x = __c.real();
104-
__repr_.y = __c.imag();
136+
__repr_.x = __convert_to_bfloat16(__c.real());
137+
__repr_.y = __convert_to_bfloat16(__c.imag());
105138
return *this;
106139
}
107140

@@ -155,24 +188,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
155188

156189
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
157190
{
158-
__repr_.x += __re;
191+
__repr_.x = __hadd(__repr_.x, __re);
159192
return *this;
160193
}
161194
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
162195
{
163-
__repr_.x -= __re;
196+
__repr_.x = __hsub(__repr_.x, __re);
164197
return *this;
165198
}
166199
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
167200
{
168-
__repr_.x *= __re;
169-
__repr_.y *= __re;
201+
__repr_.x = __hmul(__repr_.x, __re);
202+
__repr_.y = __hmul(__repr_.y, __re);
170203
return *this;
171204
}
172205
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
173206
{
174-
__repr_.x /= __re;
175-
__repr_.y /= __re;
207+
__repr_.x = __hdiv(__repr_.x, __re);
208+
__repr_.y = __hdiv(__repr_.y, __re);
176209
return *this;
177210
}
178211

@@ -195,9 +228,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
195228
}
196229
};
197230

231+
template <> // complex<float>
232+
template <> // complex<__half>
233+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__nv_bfloat16>& __c)
234+
: __re_(__bfloat162float(__c.real()))
235+
, __im_(__bfloat162float(__c.imag()))
236+
{}
237+
238+
template <> // complex<double>
239+
template <> // complex<__half>
240+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__nv_bfloat16>& __c)
241+
: __re_(__bfloat162float(__c.real()))
242+
, __im_(__bfloat162float(__c.imag()))
243+
{}
244+
245+
template <> // complex<float>
246+
template <> // complex<__nv_bfloat16>
247+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__nv_bfloat16>& __c)
248+
{
249+
__re_ = __bfloat162float(__c.real());
250+
__im_ = __bfloat162float(__c.imag());
251+
return *this;
252+
}
253+
254+
template <> // complex<double>
255+
template <> // complex<__nv_bfloat16>
256+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__nv_bfloat16>& __c)
257+
{
258+
__re_ = __bfloat162float(__c.real());
259+
__im_ = __bfloat162float(__c.imag());
260+
return *this;
261+
}
262+
198263
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
199264
{
200-
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
265+
return _CUDA_VSTD::atan2(__int2bfloat16_rn(0), __re);
201266
}
202267

203268
// We have performance issues with some trigonometric functions with __nv_bfloat16

libcudacxx/include/cuda/std/__complex/nvfp16.h

+76-11
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,39 @@ struct __libcpp_complex_overload_traits<__half, false, false>
6060
typedef complex<__half> _ComplexType;
6161
};
6262

63+
// This is a workaround against the user defining macros __CUDA_NO_HALF_CONVERSIONS__ __CUDA_NO_HALF_OPERATORS__
64+
template <>
65+
struct __complex_can_implicitly_construct<__half, float> : true_type
66+
{};
67+
68+
template <>
69+
struct __complex_can_implicitly_construct<__half, double> : true_type
70+
{};
71+
72+
template <>
73+
struct __complex_can_implicitly_construct<float, __half> : true_type
74+
{};
75+
76+
template <>
77+
struct __complex_can_implicitly_construct<double, __half> : true_type
78+
{};
79+
80+
template <class _Tp>
81+
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const _Tp& __value) noexcept
82+
{
83+
return __value;
84+
}
85+
86+
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const float& __value) noexcept
87+
{
88+
return __float2half(__value);
89+
}
90+
91+
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const double& __value) noexcept
92+
{
93+
return __double2half(__value);
94+
}
95+
6396
template <>
6497
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
6598
{
@@ -77,14 +110,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
77110

78111
template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
79112
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
80-
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
113+
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
81114
{}
82115

83116
template <class _Up,
84117
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
85118
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
86119
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
87-
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
120+
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
88121
{}
89122

90123
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
@@ -97,8 +130,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
97130
template <class _Up>
98131
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
99132
{
100-
__repr_.x = __c.real();
101-
__repr_.y = __c.imag();
133+
__repr_.x = __convert_to_half(__c.real());
134+
__repr_.y = __convert_to_half(__c.imag());
102135
return *this;
103136
}
104137

@@ -152,24 +185,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
152185

153186
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
154187
{
155-
__repr_.x += __re;
188+
__repr_.x = __hadd(__repr_.x, __re);
156189
return *this;
157190
}
158191
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
159192
{
160-
__repr_.x -= __re;
193+
__repr_.x = __hsub(__repr_.x, __re);
161194
return *this;
162195
}
163196
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
164197
{
165-
__repr_.x *= __re;
166-
__repr_.y *= __re;
198+
__repr_.x = __hmul(__repr_.x, __re);
199+
__repr_.y = __hmul(__repr_.y, __re);
167200
return *this;
168201
}
169202
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
170203
{
171-
__repr_.x /= __re;
172-
__repr_.y /= __re;
204+
__repr_.x = __hdiv(__repr_.x, __re);
205+
__repr_.y = __hdiv(__repr_.y, __re);
173206
return *this;
174207
}
175208

@@ -192,9 +225,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
192225
}
193226
};
194227

228+
template <> // complex<float>
229+
template <> // complex<__half>
230+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__half>& __c)
231+
: __re_(__half2float(__c.real()))
232+
, __im_(__half2float(__c.imag()))
233+
{}
234+
235+
template <> // complex<double>
236+
template <> // complex<__half>
237+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__half>& __c)
238+
: __re_(__half2float(__c.real()))
239+
, __im_(__half2float(__c.imag()))
240+
{}
241+
242+
template <> // complex<float>
243+
template <> // complex<__half>
244+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__half>& __c)
245+
{
246+
__re_ = __half2float(__c.real());
247+
__im_ = __half2float(__c.imag());
248+
return *this;
249+
}
250+
251+
template <> // complex<double>
252+
template <> // complex<__half>
253+
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__half>& __c)
254+
{
255+
__re_ = __half2float(__c.real());
256+
__im_ = __half2float(__c.imag());
257+
return *this;
258+
}
259+
195260
inline _LIBCUDACXX_INLINE_VISIBILITY __half arg(__half __re)
196261
{
197-
return _CUDA_VSTD::atan2f(__half(0), __re);
262+
return _CUDA_VSTD::atan2(__int2half_rn(0), __re);
198263
}
199264

200265
// We have performance issues with some trigonometric functions with __half

libcudacxx/include/cuda/std/__cuda/cmath_nvbf16.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -37,47 +37,47 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
3737
// trigonometric functions
3838
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
3939
{
40-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __nv_bfloat16(::sin(float(__v)));))
40+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sin(__bfloat162float(__v)));))
4141
}
4242

4343
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
4444
{
45-
return __nv_bfloat16(::sinh(float(__v)));
45+
return __float2bfloat16(::sinh(__bfloat162float(__v)));
4646
}
4747

4848
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
4949
{
50-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
50+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cos(__bfloat162float(__v)));))
5151
}
5252

5353
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
5454
{
55-
return __nv_bfloat16(::cosh(float(__v)));
55+
return __float2bfloat16(::cosh(__bfloat162float(__v)));
5656
}
5757

5858
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
5959
{
60-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __nv_bfloat16(::exp(float(__v)));))
60+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::exp(__bfloat162float(__v)));))
6161
}
6262

6363
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
6464
{
65-
return __nv_bfloat16(::hypot(float(__x), float(__y)));
65+
return __float2bfloat16(::hypot(__bfloat162float(__x), __bfloat162float(__y)));
6666
}
6767

6868
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
6969
{
70-
return __nv_bfloat16(::atan2(float(__x), float(__y)));
70+
return __float2bfloat16(::atan2(__bfloat162float(__x), __bfloat162float(__y)));
7171
}
7272

7373
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
7474
{
75-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __nv_bfloat16(::log(float(__x)));))
75+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::log(__bfloat162float(__x)));))
7676
}
7777

7878
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
7979
{
80-
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __nv_bfloat16(::sqrt(float(__x)));))
80+
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrt(__bfloat162float(__x)));))
8181
}
8282

8383
// floating point helper
@@ -123,7 +123,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY bool isfinite(__nv_bfloat16 __v)
123123

124124
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
125125
{
126-
return __nv_bfloat16(::copysignf(float(__x), float(__y)));
126+
return __float2bfloat16(::copysignf(__bfloat162float(__x), __bfloat162float(__y)));
127127
}
128128

129129
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y)

0 commit comments

Comments
 (0)