diff --git a/csrc/cpu/comm/arm64/shm.h b/csrc/cpu/comm/arm64/shm.h
new file mode 100644
index 000000000000..f6bdc41c6d43
--- /dev/null
+++ b/csrc/cpu/comm/arm64/shm.h
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+// NOTE:
+// This shared-memory implementation targets AArch64 CPUs.
+// Minimum supported architecture is ARMv8-A with NEON (Advanced SIMD) support.
+// Systems without NEON are not supported.
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+// 128 bits = 16 bytes -> fits 8 fp16/bf16 or 4 fp32 elements.
+static int vector_length_in_bytes = 16;
+// When widening fp16/bf16 -> fp32, 4 elements fit in one 128-bit register.
+// Using 8 would require two 128-bit registers, so limit to 4.
+static constexpr int full_precision_elements_in_fixed_vector = 4;
+
+static inline float32x4_t cvt_bf16_to_fp32(const uint16x4_t input)
+{
+    // Zero-extend 16-bit to 32-bit and shift left by 16 bits.
+    // BF16 has the same sign/exponent bits as FP32, just missing the lower mantissa bits.
+    uint32x4_t result_32 = vshll_n_u16(input, 16);
+    return vreinterpretq_f32_u32(result_32);
+}
+
+static inline float32x4_t cvt_fp16_to_fp32(float16x4_t input)
+{
+    // Converts 4 FP16 values to 4 FP32 values.
+    return vcvt_f32_f16(input);
+}
+
+// Converts 4 float32 -> 4 bfloat16. Before the lower 16 bits are truncated, the value is
+// rounded to nearest even (RNE); NaN inputs are mapped to a bf16 NaN pattern.
+static inline uint16x4_t cvt_fp32_to_bf16(float32x4_t src)
+{
+    // Reinterpret float32 bits as uint32.
+    uint32x4_t u32 = vreinterpretq_u32_f32(src);
+
+    const uint32x4_t ones = vdupq_n_u32(0x1);
+    const uint32x4_t vec_bias = vdupq_n_u32(0x7FFF);  // one less than half of the dropped-bits range
+    const uint16x4_t nan_bf16 = vdup_n_u16(0xFFFF);
+
+    // RNE: lsb = (input >> 16) & 1
+    uint32x4_t lsb = vandq_u32(vshrq_n_u32(u32, 16), ones);
+
+    // rounding_bias = 0x7FFF + lsb; lsb can be 0 or 1.
+    uint32x4_t bias = vaddq_u32(vec_bias, lsb);
+
+    // input += rounding_bias
+    u32 = vaddq_u32(u32, bias);
+
+    // >> 16 to get bfloat16.
+    // vshrq_n_u32 keeps 32-bit lanes after the shift; vshrn_n_u32 narrows to 16-bit lanes.
+    uint16x4_t bf16 = vshrn_n_u32(u32, 16);
+
+    // NaN mask: ~(src == src) is all ones for NaN lanes, all zeros for normal numbers
+    // (vmvnq_u32 is bitwise NOT).
+    uint32x4_t isnan = vmvnq_u32(vceqq_f32(src, src));
+
+    // Narrow the 32-bit mask to 16 bits per lane and select nan_bf16 where it is set.
+    uint16x4_t mask = vmovn_u32(isnan);
+    return vbsl_u16(mask, nan_bf16, bf16);
+}
+
+// fp32 and fp16 are both IEEE formats. Converting fp32 to fp16 is handled by vcvt_f16_f32,
+// which rounds to nearest rather than arbitrarily truncating the lsb.
+static inline float16x4_t cvt_fp32_to_fp16(float32x4_t input)
+{
+    // Converts 4 FP32 values to 4 FP16 values with rounding.
+    return vcvt_f16_f32(input);
+}
+
+// The reduce functions below use a vectorized algorithm; the number of bytes processed per
+// iteration depends on the vector length. A 128-bit vector == 16 bytes; we stick to the
+// NEON 128-bit path.
+
+void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
+void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
+void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers);
+
+void parallel_memcpy(void* to, void* from, size_t n_bytes);
+
+#define VLOAD_U8(X) vld1q_u8((uint8_t*)(X))
+#define VLOAD_U16(X) vld1_u16((uint16_t*)(X))
+#define VLOAD_F16(X) vld1_f16((float16_t*)(X))
+#define VLOAD_F32(X) vld1q_f32((float32_t*)(X))
+
+#define VSTORE_U8(A, B) vst1q_u8((uint8_t*)(A), B)
+#define VSTORE_U16(A, B) vst1_u16((uint16_t*)(A), B)
+#define VSTORE_F16(A, B) vst1_f16((float16_t*)(A), B)  // fp16 supported from armv8.2-a+fp16
+#define VSTORE_F32(A, B) vst1q_f32((float32_t*)(A), B)
+
+#define VADD_F32(A, B) vaddq_f32(A, B)
+#define VADD_F32_2VL(A, B) vaddq_f32(A, B)
+
+#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
+#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
+#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
+#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)
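
The rounding performed by cvt_fp32_to_bf16 is easier to follow in scalar form. The sketch below is illustrative only (the helper name fp32_to_bf16_scalar is made up here and is not part of the patch); it applies the same 0x7FFF + lsb bias before truncating the low 16 bits, and maps NaN to 0xFFFF:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar equivalent of the NEON cvt_fp32_to_bf16 rounding (illustrative sketch only).
static uint16_t fp32_to_bf16_scalar(float f)
{
    if (std::isnan(f)) return 0xFFFF;  // same NaN pattern the vector code selects

    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));  // reinterpret the float bits as uint32

    uint32_t lsb = (bits >> 16) & 1;  // lowest bit that survives truncation
    bits += 0x7FFF + lsb;             // round to nearest, ties to even
    return static_cast<uint16_t>(bits >> 16);
}

int main()
{
    // 1.0f (0x3F800000) is exactly representable in bf16 and stays 0x3F80.
    std::printf("bf16(1.0f) = 0x%04X\n", (unsigned)fp32_to_bf16_scalar(1.0f));
    return 0;
}
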
diff --git a/csrc/cpu/comm/shm.cpp b/csrc/cpu/comm/shm.cpp
index d30b4637b50e..40073e6863f2 100644
--- a/csrc/cpu/comm/shm.cpp
+++ b/csrc/cpu/comm/shm.cpp
@@ -14,6 +14,9 @@
 #if defined(__riscv)
 #define TARGET_RISCV 1
 #include "riscv64/shm.h"
+#elif defined(__aarch64__)
+#define TARGET_ARM 1
+#include "arm64/shm.h"
 #else
 #include "x86_64/shm.h"
 #endif
@@ -154,7 +157,10 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e16m1(num_elements);
     vector_length_in_bytes = vl * element_size;
-#else
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
+#else // x86_64
     const int vl = vector_length_in_bytes / element_size;
 #endif
     int main_elements = num_elements - (num_elements % vl);
@@ -214,7 +220,10 @@ void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e16m1(num_elements);
     vector_length_in_bytes = vl * element_size;
-#else
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
+#else // x86_64
     const int vl = vector_length_in_bytes / element_size;
 #endif
     int main_elements = num_elements - (num_elements % vl);
@@ -274,7 +283,10 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
 #if TARGET_RISCV
     size_t vl = __riscv_vsetvl_e32m1(num_elements);
     vector_length_in_bytes = vl * element_size;
-#else
+#elif TARGET_ARM
+    const int vl = full_precision_elements_in_fixed_vector;
+    vector_length_in_bytes = vl * element_size;
+#else // x86_64
     const int vl = vector_length_in_bytes / element_size;
 #endif
     int main_elements = num_elements - (num_elements % vl);
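
The vl selected above feeds the usual main-loop/tail split: main_elements = num_elements - (num_elements % vl) elements go through the vectorized path and the rest are handled one element at a time. The fp32 sketch below shows the shape of that split for two input buffers using the macros from arm64/shm.h; it is a simplified illustration under that assumption, not the actual body of reduce_fp32_buffers:

#include "arm64/shm.h"  // VLOAD_F32 / VADD_F32 / VSTORE_F32, full_precision_elements_in_fixed_vector

// Simplified two-buffer fp32 reduction illustrating the main/tail split (sketch only).
static void reduce_fp32_two_buffers_sketch(int num_elements, float* out, const float* a, const float* b)
{
    const int vl = full_precision_elements_in_fixed_vector;  // 4 fp32 lanes per 128-bit NEON vector
    const int main_elements = num_elements - (num_elements % vl);

    // Vectorized main loop: 4 elements per iteration.
    for (int i = 0; i < main_elements; i += vl) {
        float32x4_t sum = VADD_F32(VLOAD_F32(a + i), VLOAD_F32(b + i));
        VSTORE_F32(out + i, sum);
    }

    // Scalar tail for the remainder (e.g. all 3 elements when num_elements == 3).
    for (int i = main_elements; i < num_elements; i++) { out[i] = a[i] + b[i]; }
}

With the test sizes added below, 128 elements exercise only the vectorized loop (128 % 4 == 0), while 3 elements fall entirely into the scalar tail.
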
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 19f2ba2d42b4..8e821f2fdd6d 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -167,7 +167,11 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
 
     def inference_all_reduce(self, tensor, op, group=None):
-        if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
+        use_ds_op = hasattr(torch.ops, 'deepspeed') and hasattr(torch.ops.deepspeed, 'inference_all_reduce_')
+        world_size = torch.distributed.get_world_size(group=group)
+        if world_size <= 1:
+            return tensor
+        if not use_ds_op:
             op = self._reduce_op(op)
             return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
         else:
diff --git a/tests/unit/comm/test_dist.py b/tests/unit/comm/test_dist.py
index 861ba5c7be1a..0cbc611dc38c 100644
--- a/tests/unit/comm/test_dist.py
+++ b/tests/unit/comm/test_dist.py
@@ -110,6 +110,7 @@ def test(self, distributed_fixture, class_tmpdir, val1, val2):
         assert int(os.environ["WORLD_SIZE"]) == 1
 
 
+@pytest.mark.parametrize("num_elements", [128, 3])
 class TestDistAllReduce(DistributedTest):
     device_count = get_accelerator().device_count()
     if device_count >= 4:
@@ -119,15 +120,16 @@ class TestDistAllReduce(DistributedTest):
     else:
         world_size = [1]
 
-    def test(self):
-        x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
+    def test(self, num_elements):
+        x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
         sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
-        result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
+        result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
         dist.all_reduce(x)
         assert torch.all(x == result)
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("num_elements", [128, 3])
 class TestDistInferenceAllReduce(DistributedTest):
     device_count = get_accelerator().device_count()
     if device_count >= 4:
@@ -137,10 +139,10 @@ class TestDistInferenceAllReduce(DistributedTest):
     else:
         world_size = [1]
 
-    def test(self, dtype):
-        x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
+    def test(self, dtype, num_elements):
+        x = torch.ones(1, num_elements).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
         sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
-        result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
+        result = torch.ones(1, num_elements).to(get_accelerator().device_name()) * sum_of_ranks
         result = result.to(dtype)
         x = x.to(dtype)
         dist.inference_all_reduce(x)