@@ -63,6 +63,39 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
63
63
typedef complex<__nv_bfloat16> _ComplexType;
64
64
};
65
65
66
+ // Workaround for users defining the macros __CUDA_NO_BFLOAT16_CONVERSIONS__ and/or __CUDA_NO_BFLOAT16_OPERATORS__.
// NOTE(review): presumably those macros remove the implicit conversions / operators of
// __nv_bfloat16, so the trait is forced to true_type here instead of being deduced - confirm
// against the primary definition of __complex_can_implicitly_construct.
67
// complex<__nv_bfloat16> takes the non-explicit converting constructor for complex<float>.
+ template <>
68
+ struct __complex_can_implicitly_construct <__nv_bfloat16, float > : true_type
69
+ {};
70
+
71
// complex<__nv_bfloat16> takes the non-explicit converting constructor for complex<double>.
+ template <>
72
+ struct __complex_can_implicitly_construct <__nv_bfloat16, double > : true_type
73
+ {};
74
+
75
// Reverse direction: float as the value type, __nv_bfloat16 as the source type.
+ template <>
76
+ struct __complex_can_implicitly_construct <float , __nv_bfloat16> : true_type
77
+ {};
78
+
79
// Reverse direction: double as the value type, __nv_bfloat16 as the source type.
+ template <>
80
+ struct __complex_can_implicitly_construct <double , __nv_bfloat16> : true_type
81
+ {};
82
+
83
// Generic fallback: converts __value to __nv_bfloat16 through the implicit conversion of the
// return statement. The non-template overloads for float and double are preferred for those
// types by overload resolution.
// NOTE(review): for _Tp other than __nv_bfloat16 this still depends on an implicit conversion
// being available - confirm this is intended.
+ template <class _Tp >
84
+ inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16 (const _Tp& __value) noexcept
85
+ {
86
+ return __value;
87
+ }
88
+
89
// float overload: converts explicitly through __float2bfloat16 rather than relying on an
// implicit float -> __nv_bfloat16 conversion.
+ inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16 (const float & __value) noexcept
90
+ {
91
+ return __float2bfloat16 (__value);
92
+ }
93
+
94
// double overload: converts explicitly through __double2bfloat16 rather than relying on an
// implicit double -> __nv_bfloat16 conversion.
+ inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16 (const double & __value) noexcept
95
+ {
96
+ return __double2bfloat16 (__value);
97
+ }
98
+
66
99
template <>
67
100
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS (alignof(__nv_bfloat162)) complex<__nv_bfloat16>
68
101
{
@@ -80,14 +113,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
80
113
81
114
// Implicit converting constructor from complex<_Up>; enabled only when
// __complex_can_implicitly_construct<value_type, _Up> is true. Both components are converted
// with __convert_to_bfloat16 (replacing the previous static_cast<value_type>).
template <class _Up , __enable_if_t <__complex_can_implicitly_construct<value_type, _Up>::value, int > = 0 >
82
115
_LIBCUDACXX_INLINE_VISIBILITY complex (const complex<_Up>& __c)
83
- : __repr_ (static_cast <value_type> (__c.real ()), static_cast <value_type> (__c.imag ()))
116
+ : __repr_ (__convert_to_bfloat16 (__c.real ()), __convert_to_bfloat16 (__c.imag ()))
84
117
{}
85
118
86
119
// Explicit converting constructor from complex<_Up>; selected when the implicit-construction
// trait is false but value_type is still constructible from _Up. Both components are converted
// with __convert_to_bfloat16 (replacing the previous static_cast<value_type>).
template <class _Up ,
87
120
__enable_if_t <!__complex_can_implicitly_construct<value_type, _Up>::value, int > = 0 ,
88
121
__enable_if_t <_CCCL_TRAIT (is_constructible, value_type, _Up), int > = 0 >
89
122
_LIBCUDACXX_INLINE_VISIBILITY explicit complex (const complex<_Up>& __c)
90
- : __repr_ (static_cast <value_type> (__c.real ()), static_cast <value_type> (__c.imag ()))
123
+ : __repr_ (__convert_to_bfloat16 (__c.real ()), __convert_to_bfloat16 (__c.imag ()))
91
124
{}
92
125
93
126
_LIBCUDACXX_INLINE_VISIBILITY complex& operator =(const value_type& __re)
@@ -100,8 +133,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
100
133
// Converting assignment from complex<_Up>: stores the real part in __repr_.x and the imaginary
// part in __repr_.y, each converted through __convert_to_bfloat16 instead of an implicit
// conversion on assignment. Returns *this.
template <class _Up >
101
134
_LIBCUDACXX_INLINE_VISIBILITY complex& operator =(const complex<_Up>& __c)
102
135
{
103
- __repr_.x = __c.real ();
104
- __repr_.y = __c.imag ();
136
+ __repr_.x = __convert_to_bfloat16 ( __c.real () );
137
+ __repr_.y = __convert_to_bfloat16 ( __c.imag () );
105
138
return *this ;
106
139
}
107
140
@@ -155,24 +188,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
155
188
156
189
// Adds the real scalar __re: only the real component (__repr_.x) changes. Uses the __hadd
// function call instead of the += operator on __nv_bfloat16. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator +=(const value_type& __re)
157
190
{
158
- __repr_.x += __re;
191
+ __repr_.x = __hadd (__repr_. x , __re) ;
159
192
return *this ;
160
193
}
161
194
// Subtracts the real scalar __re: only the real component (__repr_.x) changes. Uses the __hsub
// function call instead of the -= operator on __nv_bfloat16. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator -=(const value_type& __re)
162
195
{
163
- __repr_.x -= __re;
196
+ __repr_.x = __hsub (__repr_. x , __re) ;
164
197
return *this ;
165
198
}
166
199
// Scales by the real scalar __re: both components (__repr_.x and __repr_.y) are multiplied.
// Uses the __hmul function call instead of the *= operator on __nv_bfloat16. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator *=(const value_type& __re)
167
200
{
168
- __repr_.x *= __re;
169
- __repr_.y *= __re;
201
+ __repr_.x = __hmul (__repr_. x , __re) ;
202
+ __repr_.y = __hmul (__repr_. y , __re) ;
170
203
return *this ;
171
204
}
172
205
// Divides by the real scalar __re: both components (__repr_.x and __repr_.y) are divided.
// Uses the __hdiv function call instead of the /= operator on __nv_bfloat16. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator /=(const value_type& __re)
173
206
{
174
- __repr_.x /= __re;
175
- __repr_.y /= __re;
207
+ __repr_.x = __hdiv (__repr_. x , __re) ;
208
+ __repr_.y = __hdiv (__repr_. y , __re) ;
176
209
return *this ;
177
210
}
178
211
@@ -195,9 +228,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
195
228
}
196
229
};
197
230
231
// Specialization of complex<float>'s converting constructor for a complex<__nv_bfloat16>
// source: each component is converted explicitly with __bfloat162float.
+ template <> // complex<float>
232
+ template <> // complex<__nv_bfloat16>
233
+ inline _LIBCUDACXX_INLINE_VISIBILITY complex<float >::complex(const complex<__nv_bfloat16>& __c)
234
+ : __re_(__bfloat162float(__c.real()))
235
+ , __im_(__bfloat162float(__c.imag()))
236
+ {}
237
+
238
// Specialization of complex<double>'s converting constructor for a complex<__nv_bfloat16>
// source: each component is converted explicitly with __bfloat162float; the float result then
// initializes __re_ / __im_ (presumably widened to double by the member initialization).
+ template <> // complex<double>
239
+ template <> // complex<__nv_bfloat16>
240
+ inline _LIBCUDACXX_INLINE_VISIBILITY complex<double >::complex(const complex<__nv_bfloat16>& __c)
241
+ : __re_(__bfloat162float(__c.real()))
242
+ , __im_(__bfloat162float(__c.imag()))
243
+ {}
244
+
245
// Specialization of complex<float>'s converting assignment for a complex<__nv_bfloat16>
// source: each component is converted explicitly with __bfloat162float. Returns *this.
+ template <> // complex<float>
246
+ template <> // complex<__nv_bfloat16>
247
+ inline _LIBCUDACXX_INLINE_VISIBILITY complex<float >& complex<float >::operator =(const complex<__nv_bfloat16>& __c)
248
+ {
249
+ __re_ = __bfloat162float (__c.real ());
250
+ __im_ = __bfloat162float (__c.imag ());
251
+ return *this ;
252
+ }
253
+
254
// Specialization of complex<double>'s converting assignment for a complex<__nv_bfloat16>
// source: each component is converted explicitly with __bfloat162float before being assigned
// to __re_ / __im_ (presumably widened from float to double on assignment). Returns *this.
+ template <> // complex<double>
255
+ template <> // complex<__nv_bfloat16>
256
+ inline _LIBCUDACXX_INLINE_VISIBILITY complex<double >& complex<double >::operator =(const complex<__nv_bfloat16>& __c)
257
+ {
258
+ __re_ = __bfloat162float (__c.real ());
259
+ __im_ = __bfloat162float (__c.imag ());
260
+ return *this ;
261
+ }
262
+
198
263
// arg() overload for a real __nv_bfloat16 argument: evaluates atan2(0, __re), i.e. treats __re
// as a complex value with zero imaginary part. The zero is built with __int2bfloat16_rn(0)
// rather than the __nv_bfloat16(0) constructor call, and atan2 replaces atan2f.
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg (__nv_bfloat16 __re)
199
264
{
200
- return _CUDA_VSTD::atan2f ( __nv_bfloat16 (0 ), __re);
265
+ return _CUDA_VSTD::atan2 ( __int2bfloat16_rn (0 ), __re);
201
266
}
202
267
203
268
// We have performance issues with some trigonometric functions with __nv_bfloat16
0 commit comments