Add atomic_add and BlockStripedReduce to bfloat162 #1653

Open · wants to merge 2 commits into base: main
Changes from all commits
31 changes: 31 additions & 0 deletions include/cutlass/block_striped.h
@@ -263,5 +263,36 @@ struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
};


/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Specialization for bfloat16_t. Uses __nv_bfloat162 vectorized-reduction.)
template <
int BlockThreads,
typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, bfloat16_t> :
BlockStriped<
BlockThreads,
ArrayT,
__nv_bfloat162>
{
  static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of bfloat16_t must have an even length");

/// Reduce
CUTLASS_DEVICE
static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
{
cutlass::atomic_add<__nv_bfloat162> reduce;
__nv_bfloat162 *access_output = reinterpret_cast<__nv_bfloat162*>(ptr);
const __nv_bfloat162 *access_data = reinterpret_cast<const __nv_bfloat162*>(&data);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
{
reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
}
}
};


} // namespace cutlass
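
A quick usage sketch (not part of the diff) of how the new specialization could be exercised: each thread of a 128-thread block atomically accumulates an 8-element bfloat16_t fragment into a block-striped global workspace. The kernel name, fragment width, and block size are illustrative assumptions; the workspace is assumed to hold one fragment per thread and to be zero-initialized before launch, and the target must be sm_90 or the underlying atomic_add compiles to a no-op.

```cpp
#include "cutlass/array.h"
#include "cutlass/block_striped.h"

constexpr int kBlockThreads = 128;

// 8 bfloat16_t elements -> kStripes = 4 __nv_bfloat162 accesses per thread
// (even, as the static_assert requires).
using FragmentT = cutlass::Array<cutlass::bfloat16_t, 8>;

// Illustrative kernel: `workspace` must hold kBlockThreads fragments,
// zero-initialized before launch.
__global__ void accumulate_partials(FragmentT *workspace)
{
  FragmentT frag;
  frag.fill(cutlass::bfloat16_t(1.0f));  // each thread contributes all-ones

  // Atomically adds this thread's fragment into the striped workspace,
  // two bfloat16_t values per red.bf16x2 instruction.
  cutlass::BlockStripedReduce<kBlockThreads, FragmentT, cutlass::bfloat16_t>::reduce(
      workspace, frag, threadIdx.x);
}
```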

17 changes: 17 additions & 0 deletions include/cutlass/functional.h
@@ -626,6 +626,23 @@ struct atomic_add<half2>
}
};

template<>
struct atomic_add<__nv_bfloat162>
{
CUTLASS_DEVICE
  void operator()(__nv_bfloat162 *ptr, const __nv_bfloat162 &data)
{
#if !defined(__CUDA_ARCH__) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900))
CUTLASS_UNUSED(ptr);
CUTLASS_UNUSED(data);
#else
// Vector-2 bf16 atomic reduction requires .target sm_90 or higher
uint32_t word = reinterpret_cast<const uint32_t&>(data);
asm volatile ("red.gpu.global.add.noftz.bf16x2 [%0], %1;\n" : : "l"(ptr), "r"(word));
#endif // (__CUDA_ARCH__ >= 900)
}
};

template <typename T>
using red [[deprecated("use atomic_add instead")]] = atomic_add<T>;
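
For completeness, a minimal sketch (again, not part of the diff) of invoking the new functor directly. Assumptions: `buf` points to a zero-initialized __nv_bfloat162 in global memory and the kernel is compiled for sm_90; on older targets the operator is a no-op, as guarded above. The kernel name is illustrative.

```cpp
#include <cuda_bf16.h>
#include "cutlass/functional.h"

__global__ void add_pairs(__nv_bfloat162 *buf)
{
  cutlass::atomic_add<__nv_bfloat162> adder;

  // Pack (1.0, 2.0) into a bf16x2 value and reduce it into *buf; every
  // participating thread issues one atomic red.bf16x2.
  __nv_bfloat162 value = __floats2bfloat162_rn(1.0f, 2.0f);
  adder(buf, value);
}
```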
