106 changes: 106 additions & 0 deletions csrc/cpu/comm/arm64/shm.h
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

// NOTE:
// This shared-memory implementation targets AArch64 CPUs.
// Minimum supported architecture is ARMv8-A with NEON (Advanced SIMD) support.
// Systems without NEON are not supported.

#include <arm_neon.h>
#include <stddef.h>
#include <stdint.h>
#include <cmath>

// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements.
static int vector_length_in_bytes = 16;
// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register.
// Using 8 would require two 128-bit registers, so limit to 4.
static constexpr int full_precision_elements_in_fixed_vector = 4;

static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input)
{
// Zero-extend 16-bit to 32-bit and shift left by 16 bits
// BF16 has the same exponent/sign bits as FP32, just missing lower mantissa bits
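// Example: bf16 0x3F80 (1.0) widens to 0x3F800000, which reinterprets as 1.0f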
uint32x4_t result_32 = vshll_n_u16(input, 16);
return vreinterpretq_f32_u32(result_32);
}

static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input)
{
// Converts 4 FP16 values to 4 FP32 values
return vcvt_f32_f16(input);
}

// When converting fp32 to bf16, the dropped lower 16 bits must be rounded to nearest even rather
// than simply truncated. Converts 4 float32 -> 4 bfloat16 with round-to-nearest-even (RNE)
// rounding and NaN handling.
static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src)
{
// Reinterpret float32 bits as uint32
uint32x4_t u32 = vreinterpretq_u32_f32(src);

const uint32x4_t ones = vdupq_n_u32(0x1);
const uint32x4_t vec_bias =
vdupq_n_u32(0x7FFF); // one less than half of the dropped bits range
const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF);

// RNE: lsb = (input >> 16) & 1
uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones);

// rounding_bias = 0x7FFF + lsb, lsb can be 0 or 1.
uint32x4_t bias = vaddq_u32(vec_bias, lsb);

// input += rounding_bias
u32 = vaddq_u32(u32, bias);

// >> 16 to get bfloat16
// vshrq_n_u32 keeps the 32-bit width after the shift;
// vshrn_n_u32 narrows the result to 16 bits, which is what we want here
uint16x4_t bf16 = vshrn_n_u32(u32, 16);

// vmvnq_u32 is bitwise NOT
// NaN mask: ~(src == src) -> all ones in lanes where src is NaN, all zeros otherwise
uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src));

// Narrow each 32-bit mask lane to 16 bits so it lines up with the bf16 lanes,
// then select nan_bf16 wherever the input was NaN
uint16x4_t mask = vmovn_u32(isnan);
return vbsl_u16(mask, nan_bf16, bf16);
}
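
// For reference, a scalar sketch of the same RNE rounding (illustrative only, not used by the kernels):
//   static inline uint16_t fp32_to_bf16_scalar(float f) {
//       uint32_t u; memcpy(&u, &f, sizeof(u));
//       if (f != f) return 0xFFFF;          // NaN -> canonical bf16 NaN, as above
//       u += 0x7FFF + ((u >> 16) & 1);      // round to nearest, ties to even
//       return (uint16_t)(u >> 16);
//   }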

// fp32 and fp16 are both IEEE 754 formats, so the conversion is handled by vcvt_f16_f32
// internally, which rounds to nearest rather than arbitrarily truncating the lsb.
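// For example, 0.3f has no exact fp16 encoding: rounding to nearest gives 0x34CD (~0.300049),
// whereas truncating the mantissa would give 0x34CC (~0.299805).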
static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input)
{
// Converts 4 FP32 values to 4 FP16 values with rounding
return vcvt_f16_f32(input);
}

// The reduce functions below use a vectorized algorithm; the number of bytes processed per
// iteration depends on the vector length. We stick to NEON's fixed 128-bit vectors (16 bytes).
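//
// Rough sketch of the loop shape these functions follow (a hedged illustration only -- the real
// implementation lives in shm.cpp, and the buffer-count name used here is assumed):
//   for (int i = start_elements; i < start_elements + main_elements; i += vl) {
//       float32x4_t acc = VLOAD_F32(buffers[0] + i * sizeof(float));
//       for (int b = 1; b < num_buffers; ++b)
//           acc = VADD_F32(acc, VLOAD_F32(buffers[b] + i * sizeof(float)));
//       VSTORE_F32(to_buffer + i * sizeof(float), acc);
//   }
//   // the remaining (num_elements % vl) elements are reduced with a scalar tail loop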

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);

void parallel_memcpy(void* to, void* from, size_t n_bytes);

#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X))
#define VLOAD_U16(X) vld1_u16((uint16_t*)(X))
#define VLOAD_F16(X) vld1_f16((float16_t*)(X))
#define VLOAD_F32(X) vld1q_f32((float32_t*)(X))

#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B)
#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B)
#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B) // fp16 supported from armv8.2-a+fp16
#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B)

#define VADD_F32(A, B) vaddq_f32(A, B)
#define VADD_F32_2VL(A, B) vaddq_f32(A, B)

#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)
18 changes: 15 additions & 3 deletions csrc/cpu/comm/shm.cpp
@@ -14,6 +14,9 @@
#if defined(__riscv)
#define TARGET_RISCV 1
#include "riscv64/shm.h"
#elif defined(__aarch64__)
#define TARGET_ARM 1
#include "arm64/shm.h"
#else
#include "x86_64/shm.h"
#endif
@@ -154,7 +157,10 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
#if TARGET_RISCV
size_t vl = __riscv_vsetvl_e16m1(num_elements);
vector_length_in_bytes = vl * element_size;
#else
#elif TARGET_ARM
const int vl = full_precision_elements_in_fixed_vector;
vector_length_in_bytes = vl * element_size;
#else // x86_64
const int vl = vector_length_in_bytes / element_size;
#endif
int main_elements = num_elements - (num_elements % vl);
@@ -214,7 +220,10 @@ void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer,
#if TARGET_RISCV
size_t vl = __riscv_vsetvl_e16m1(num_elements);
vector_length_in_bytes = vl * element_size;
#else
#elif TARGET_ARM
const int vl = full_precision_elements_in_fixed_vector;
vector_length_in_bytes = vl * element_size;
#else // x86_64
const int vl = vector_length_in_bytes / element_size;
#endif
int main_elements = num_elements - (num_elements % vl);
@@ -274,7 +283,10 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
#if TARGET_RISCV
size_t vl = __riscv_vsetvl_e32m1(num_elements);
vector_length_in_bytes = vl * element_size;
#else
#elif TARGET_ARM
const int vl = full_precision_elements_in_fixed_vector;
vector_length_in_bytes = vl * element_size;
#else // x86_64
const int vl = vector_length_in_bytes / element_size;
#endif
int main_elements = num_elements - (num_elements % vl);
6 changes: 5 additions & 1 deletion deepspeed/comm/torch.py
@@ -167,7 +167,11 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op, group=None):
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
use_ds_op = hasattr(torch.ops, 'deepspeed') and hasattr(torch.ops.deepspeed, 'inference_all_reduce_')
world_size = torch.distributed.get_world_size(group=group)
if world_size <= 1:
return tensor
if not use_ds_op:
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
14 changes: 8 additions & 6 deletions tests/unit/comm/test_dist.py
@@ -110,6 +110,7 @@ def test(self, distributed_fixture, class_tmpdir, val1, val2):
assert int(os.environ["WORLD_SIZE"]) == 1


@pytest.mark.parametrize("num_elements", [128, 3])
class TestDistAllReduce(DistributedTest):
device_count = get_accelerator().device_count()
if device_count >= 4:
@@ -119,15 +120,16 @@ class TestDistAllReduce(DistributedTest):
else:
world_size = [1]

def test(self):
x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
def test(self, num_elements):
x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
dist.all_reduce(x)
assert torch.all(x == result)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_elements", [128, 3])
class TestDistInferenceAllReduce(DistributedTest):
device_count = get_accelerator().device_count()
if device_count >= 4:
@@ -137,10 +139,10 @@ class TestDistInferenceAllReduce(DistributedTest):
else:
world_size = [1]

def test(self, dtype):
x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
def test(self, dtype, num_elements):
x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
result = result.to(dtype)
x = x.to(dtype)
dist.inference_all_reduce(x)