diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index af122ee6e3d..edcc40e4d5f 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -429,12 +429,22 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
       scattered_axis >= 0,
       "scattered_axis is expected to be non-negative: ",
       scattered_axis);
-  // reduce_scatter primitive in c10d induces extra buffering time to copy the
-  // user's input tensors to an internal source buffer. It is therefore always
-  // preferable to use _reduce_scatter_base (which does not perform any extra
-  // copy) when the tensors are stored contiguously (i.e., when
-  // scattered_axis==0). Note however than only nccl supports
-  // _reduce_scatter_base, not ucc.
+
+  std::vector<at::Tensor> input_tensors = at::tensor_split(
+      input_tensor, communication->team_size(), scattered_axis);
+  // We could have checked the output shape as well if the reduction axis were
+  // available. It's not always available via
+  // `communication->out()->getReductionAxis()` for manually constructed host
+  // IRs like
+  // https://github.com/NVIDIA/Fuser/blob/89c47f695b296eb4ffd27984bd4c953fc3f3264b/tests/cpp/test_multidevice_overlap.cpp#L347.
+  assertBuffersHaveSameSize(input_tensors, {});
+
+  // The reduce_scatter primitive in c10d induces extra buffering time to copy
+  // the user's input tensors to an internal source buffer. It is therefore
+  // always preferable to use _reduce_scatter_base (which does not perform any
+  // extra copy) when the tensors are stored contiguously (i.e., when
+  // scattered_axis==0). Note however that only nccl supports
+  // _reduce_scatter_base, not ucc.
 #if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
   if (scattered_axis == 0 &&
       backend->getBackendName() == c10d::NCCL_BACKEND_NAME) {
@@ -442,14 +452,13 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
         output_tensor, input_tensor, {.reduceOp = communication->reduceOp()});
   }
 #endif
-  std::vector<std::vector<at::Tensor>> input_tensors(1);
-  input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);
-
-  std::vector<at::Tensor> output_tensors({output_tensor});
-  assertBufferCount(input_tensors[0], communication->team().size());
+  std::vector<std::vector<at::Tensor>> input_tensors_vec({input_tensors});
+  std::vector<at::Tensor> output_tensor_vec({output_tensor});
   return backend->reduce_scatter(
-      output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
+      output_tensor_vec,
+      input_tensors_vec,
+      {.reduceOp = communication->reduceOp()});
 }
 
 c10::intrusive_ptr<c10d::Work> postSendRecv(
diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp
index d1f06d80e1d..d89f0a3f3a4 100644
--- a/tests/cpp/test_multidevice_lower_communication.cpp
+++ b/tests/cpp/test_multidevice_lower_communication.cpp
@@ -7,6 +7,7 @@
 // clang-format on
 
 #include <gtest/gtest.h>
+#include <gmock/gmock.h>
 
 #include <fusion.h>
 #include <ops/all_ops.h>
@@ -16,15 +17,23 @@
 
 namespace nvfuser {
 
+using testing::Each;
+using testing::IsTrue;
+using testing::Pointer;
+using testing::Property;
+
 namespace {
 void assertIsCompiledToHostIrContainer(
     const FusionExecutorCache& executor_cache) {
   FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
-  EXPECT_TRUE(runtime->executors().size() == 1);
-  for (const auto& ea : runtime->executors()) {
-    EXPECT_TRUE(ea->isA<HostIrExecutor>())
-        << "failed to compile to a HostIrContainer with Communications";
-  }
+  EXPECT_EQ(runtime->executors().size(), 1);
+  EXPECT_THAT(
+      runtime->executors(),
+      Each(Pointer(Property(
+          "is a HostIrExecutor",
+          &ExecutorAbstract::isA<HostIrExecutor>,
+          IsTrue()))))
+      << "failed to compile to a HostIrContainer with Communications";
with Communications"; } } // namespace diff --git a/tests/python/test_communication.py b/tests/python/test_communication.py index d0cea846669..75a94f6cec4 100644 --- a/tests/python/test_communication.py +++ b/tests/python/test_communication.py @@ -15,13 +15,13 @@ @pytest.mark.mpi def test_allgather(mpi_test): - num_devices = mpi_test.size - mesh = nvfuser.DeviceMesh(range(num_devices)) + d = mpi_test.size + mesh = nvfuser.DeviceMesh(range(d)) class Model(FusionDefinition): def definition(self): self.inp = self.define_tensor( - (num_devices * 4,), contiguity=True, dtype=DataType.Float + (d * 4,), contiguity=True, dtype=DataType.Float ) self.out = self.ops.set(self.inp) self.add_output(self.out) @@ -30,16 +30,116 @@ def multidevice_schedule(self): self.sched._set_device_mesh(self.inp, mesh) self.sched._set_device_mesh(self.out, mesh) - self.sched.split(self.inp, 0, num_devices, False) + self.sched.split(self.inp, 0, d, False) self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x) self.sched.set_allocation_as_loop(self.inp) - self.sched.split(self.out, 0, num_devices, False) + self.sched.split(self.out, 0, d, False) self.sched.set_allocation_as_loop(self.out) - unsharded = torch.randn(num_devices * 4) + unsharded = torch.randn(d * 4) sharded = mpi_test.shard_tensor(unsharded, 0, mesh) fd = Model() outputs = fd.execute([sharded]) torch.testing.assert_close(outputs[0].cpu(), unsharded) + + +@pytest.mark.mpi +def test_allreduce(mpi_test): + d = mpi_test.size + mesh = nvfuser.DeviceMesh(range(d)) + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float) + self.out = self.ops.sum(self.inp, [0]) + self.add_output(self.out) + + def multidevice_schedule(self): + self.sched._set_device_mesh(self.inp, mesh) + self.sched._set_device_mesh(self.out, mesh) + + self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x) + + unsharded = torch.randn(d, 4) + sharded = mpi_test.shard_tensor(unsharded, 0, mesh) + + fd = Model() + outputs = fd.execute([sharded]) + torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(0)) + + +@pytest.mark.mpi +def test_reduce_scatter(mpi_test): + d = mpi_test.size + mesh = nvfuser.DeviceMesh(range(d)) + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor( + (d, d * 4), contiguity=True, dtype=DataType.Float + ) + self.out = self.ops.sum(self.inp, [0]) + self.add_output(self.out) + + def multidevice_schedule(self): + self.sched._set_device_mesh(self.inp, mesh) + self.sched._set_device_mesh(self.out, mesh) + + self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x) + + self.sched.split(self.out, -1, d, False) + self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x) + self.sched.set_allocation_as_loop(self.out) + + unsharded = torch.randn(d, d * 4) + sharded = mpi_test.shard_tensor(unsharded, 0, mesh) + + fd = Model() + outputs = fd.execute([sharded]) + torch.testing.assert_close( + outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 0, mesh) + ) + + +@pytest.mark.mpi +def test_reduce_scatter_noncontiguous(mpi_test): + d = mpi_test.size + mesh = nvfuser.DeviceMesh(range(d)) + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor( + (d, 3, d * 4), contiguity=True, dtype=DataType.Float + ) + self.out = self.ops.sum(self.inp, [0]) + self.add_output(self.out) + + def multidevice_schedule(self): + self.sched._set_device_mesh(self.inp, mesh) + self.sched._set_device_mesh(self.out, mesh) + + # inp: 
+            # inp: [iDID{d}, i{3}, i{d*4}]
+            # out: [r{d},    i{3}, i{d*4}]
+            #                      /     \
+            #                 iDID{d}    i{4}
+            #
+            # Unlike test_reduce_scatter, this leads to an extra data copy
+            # because the scattered axis is not the outermost in allocation.
+            # ProcessGroupNCCL::reduce_scatter is able to handle such
+            # non-contiguous scattering in a functional but suboptimal way.
+            self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
+
+            self.sched.split(self.out, -1, d, False)
+            self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
+            self.sched.set_allocation_as_loop(self.out)
+
+    unsharded = torch.randn(d, 3, d * 4)
+    sharded = mpi_test.shard_tensor(unsharded, 0, mesh)
+
+    fd = Model()
+    outputs = fd.execute([sharded])
+    torch.testing.assert_close(
+        outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 1, mesh)
+    )
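
For reference, and not part of the patch above: a minimal standalone sketch of the fast/slow path selection that postReduceScatter performs, expressed with plain torch.distributed. The helper name reduce_scatter_along is hypothetical; the sketch assumes an already-initialized NCCL process group and a scattered axis whose extent is evenly divisible by the world size.

# Illustrative only -- not part of this PR. Assumes torch.distributed has been
# initialized with the NCCL backend.
import torch
import torch.distributed as dist


def reduce_scatter_along(input_tensor: torch.Tensor, scattered_axis: int) -> torch.Tensor:
    d = dist.get_world_size()
    out_shape = list(input_tensor.shape)
    out_shape[scattered_axis] //= d  # assumes even divisibility by d
    output = input_tensor.new_empty(out_shape)

    if scattered_axis == 0 and input_tensor.is_contiguous():
        # Fast path: with the scattered axis outermost, each rank's chunk is a
        # contiguous slice, so the "base" variant (reduce_scatter_tensor, the
        # public successor of _reduce_scatter_base) avoids the staging copy.
        dist.reduce_scatter_tensor(output, input_tensor, op=dist.ReduceOp.SUM)
    else:
        # Slow path: splitting along an inner axis yields non-contiguous views;
        # c10d copies them into an internal buffer before communicating, which
        # is functional but suboptimal (what test_reduce_scatter_noncontiguous
        # exercises).
        chunks = list(torch.tensor_split(input_tensor, d, dim=scattered_axis))
        dist.reduce_scatter(output, chunks, op=dist.ReduceOp.SUM)
    return output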