Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions deepspeed/inference/v2/kernels/includes/conversion_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
// int64 -> bf16 conversion (round-to-nearest, per the _rn intrinsics used).
template <>
DS_D_INLINE __nv_bfloat16 to(int64_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA exposes a direct long-long -> bf16 intrinsic.
    return __ll2bfloat16_rn(val);
#else
    // ROCm has no direct intrinsic; widen to double first, then narrow.
    return __double2bfloat16(__ll2double_rn(val));
#endif
}
// int32 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(int32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
    // Widen through double, not float: every int32_t is exactly representable
    // in double, so the only rounding is the final double -> bf16 step. A
    // float intermediate would double-round for |val| > 2^24 and could pick a
    // different bf16 than CUDA's direct __int2bfloat16_rn path.
    return __double2bfloat16(__int2double_rn(val));
#else
    return __int2bfloat16_rn(val);
#endif
}
// int16 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(int16_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: dedicated short -> bf16 intrinsic.
    return __short2bfloat16_rn(val);
#else
    // ROCm: widen to float (exact for any 16-bit integer), then narrow.
    return __float2bfloat16(__int2float_rn(val));
#endif
}
// int8 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(int8_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __int2bfloat16_rn(val);
#else
    // ROCm: widen to float (exact for any 8-bit integer), then narrow.
    return __float2bfloat16(__int2float_rn(val));
#endif
}
// uint64 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(uint64_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: dedicated unsigned-long-long -> bf16 intrinsic.
    return __ull2bfloat16_rn(val);
#else
    // ROCm: widen to double first, then narrow to bf16.
    return __double2bfloat16(__ull2double_rn(val));
#endif
}
// uint32 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(uint32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
    // Widen through double, not float: every uint32_t is exactly representable
    // in double, so only the final double -> bf16 step rounds. A float
    // intermediate would double-round for val > 2^24 and could diverge from
    // CUDA's direct __uint2bfloat16_rn result.
    return __double2bfloat16(__uint2double_rn(val));
#else
    return __uint2bfloat16_rn(val);
#endif
}
// uint16 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(uint16_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: dedicated unsigned-short -> bf16 intrinsic.
    return __ushort2bfloat16_rn(val);
#else
    // ROCm: widen to float (exact for any 16-bit integer), then narrow.
    return __float2bfloat16(__uint2float_rn(val));
#endif
}
// uint8 -> bf16 conversion (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat16 to(uint8_t val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __uint2bfloat16_rn(val);
#else
    // ROCm: widen to float (exact for any 8-bit integer), then narrow.
    return __float2bfloat16(__uint2float_rn(val));
#endif
}
#endif

Expand All @@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
// Broadcast a single float into both halves of a bf162 (round-to-nearest).
template <>
DS_D_INLINE __nv_bfloat162 to(float val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: one intrinsic converts and duplicates in a single step.
    return __float2bfloat162_rn(val);
#else
    // ROCm: narrow once, then replicate the bf16 scalar into both lanes.
    return __bfloat162bfloat162(__float2bfloat16(val));
#endif
}
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
Expand Down Expand Up @@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val)
// bf16 -> int64 conversion (round-to-nearest, per the _rn intrinsics).
template <>
DS_D_INLINE int64_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: direct bf16 -> long-long intrinsic.
    return __bfloat162ll_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to long long.
    return __float2ll_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val)
// bf16 -> int32 conversion (round-to-nearest).
template <>
DS_D_INLINE int32_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162int_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to int.
    return __float2int_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val)
// bf16 -> int16 conversion (round-to-nearest; the int-returning intrinsic
// result is implicitly narrowed to int16_t, matching the existing scheme).
template <>
DS_D_INLINE int16_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162int_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to int.
    return __float2int_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val)
// bf16 -> int8 conversion (round-to-nearest; int result implicitly narrowed
// to int8_t, matching the existing scheme).
template <>
DS_D_INLINE int8_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162int_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to int.
    return __float2int_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val)
// bf16 -> uint64 conversion (round-to-nearest).
template <>
DS_D_INLINE uint64_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    // CUDA: direct bf16 -> unsigned-long-long intrinsic.
    return __bfloat162ull_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to unsigned long long.
    return __float2ull_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val)
// bf16 -> uint32 conversion (round-to-nearest).
template <>
DS_D_INLINE uint32_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162uint_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to unsigned int.
    return __float2uint_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val)
// bf16 -> uint16 conversion (round-to-nearest; unsigned result implicitly
// narrowed to uint16_t, matching the existing scheme).
template <>
DS_D_INLINE uint16_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162uint_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to unsigned int.
    return __float2uint_rn(__bfloat162float(val));
#endif
}
#endif

Expand All @@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val)
// bf16 -> uint8 conversion (round-to-nearest; unsigned result implicitly
// narrowed to uint8_t, matching the existing scheme).
template <>
DS_D_INLINE uint8_t to(__nv_bfloat16 val)
{
#ifndef __HIP_PLATFORM_AMD__
    return __bfloat162uint_rn(val);
#else
    // ROCm: widen to float (exact for bf16), then round to unsigned int.
    return __float2uint_rn(__bfloat162float(val));
#endif
}
#endif

Expand Down
Loading