diff --git a/deepspeed/inference/v2/kernels/includes/conversion_utils.h b/deepspeed/inference/v2/kernels/includes/conversion_utils.h index 3a90a3e91ddf..d6d8f11e0854 100644 --- a/deepspeed/inference/v2/kernels/includes/conversion_utils.h +++ b/deepspeed/inference/v2/kernels/includes/conversion_utils.h @@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val) template <> DS_D_INLINE __nv_bfloat16 to(int64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ll2double_rn(val)); +#else return __ll2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __short2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ull2double_rn(val)); +#else return __ull2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __ushort2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } #endif @@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val) template <> DS_D_INLINE __nv_bfloat162 to(float val) { +#ifdef __HIP_PLATFORM_AMD__ + return __bfloat162bfloat162(__float2bfloat16(val)); +#else return __float2bfloat162_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat162 to(__half2 val) @@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val) template <> DS_D_INLINE int64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ll_rn(__bfloat162float(val)); +#else return __bfloat162ll_rn(val); +#endif } #endif @@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val) template <> DS_D_INLINE int32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val) template <> DS_D_INLINE int16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val) template <> DS_D_INLINE int8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val) template <> DS_D_INLINE uint64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ull_rn(__bfloat162float(val)); +#else return __bfloat162ull_rn(val); +#endif } #endif @@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val) template <> DS_D_INLINE uint32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val) template <> DS_D_INLINE uint16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val) template <> DS_D_INLINE uint8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif