ReduceScatter with DID loop split #3504

Merged
merged 6 commits on Dec 11, 2024
33 changes: 21 additions & 12 deletions csrc/multidevice/communication.cpp
@@ -429,27 +429,36 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
scattered_axis >= 0,
"scattered_axis is expected to be non-negative: ",
scattered_axis);
// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.

std::vector<at::Tensor> input_tensors = at::tensor_split(
input_tensor, communication->team_size(), scattered_axis);
// We could have checked the output shape as well if reduction_axis is
// available. It's not always available via
// `communication->out()->getReductionAxis()` for manually constructed host
// IRs like
// https://github.com/NVIDIA/Fuser/blob/89c47f695b296eb4ffd27984bd4c953fc3f3264b/tests/cpp/test_multidevice_overlap.cpp#L347.
assertBuffersHaveSameSize(input_tensors, {});

// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
if (scattered_axis == 0 &&
backend->getBackendName() == c10d::NCCL_BACKEND_NAME) {
return backend->_reduce_scatter_base(
output_tensor, input_tensor, {.reduceOp = communication->reduceOp()});
}
#endif
std::vector<std::vector<at::Tensor>> input_tensors(1);
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
std::vector<std::vector<at::Tensor>> input_tensors_vec({input_tensors});
std::vector<at::Tensor> output_tensor_vec({output_tensor});
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
output_tensor_vec,
input_tensors_vec,
{.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postSendRecv(
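The comment in postReduceScatter above distinguishes a copy-free contiguous path from the general chunked path. A minimal stand-alone sketch of the same decision at the Python torch.distributed level (reduce_scatter_tensor is the Python counterpart of _reduce_scatter_base); the helper name and setup are illustrative and assume an already-initialized process group, one GPU per rank, and an evenly divisible scattered axis:

import torch
import torch.distributed as dist

def reduce_scatter_along(inp: torch.Tensor, scattered_axis: int) -> torch.Tensor:
    world_size = dist.get_world_size()
    out_shape = list(inp.shape)
    out_shape[scattered_axis] //= world_size  # assumes even divisibility
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    if scattered_axis == 0 and dist.get_backend() == "nccl":
        # Contiguous case: no staging copy, mirrors _reduce_scatter_base.
        dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)
    else:
        # General case: per-rank chunks may be non-contiguous views, so c10d
        # copies them into an internal source buffer before reducing.
        chunks = list(torch.tensor_split(inp, world_size, dim=scattered_axis))
        dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    return out
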
19 changes: 14 additions & 5 deletions tests/cpp/test_multidevice_lower_communication.cpp
@@ -7,6 +7,7 @@
// clang-format on

#include <gmock/gmock-matchers.h>
#include <gmock/gmock-more-matchers.h>
#include <gtest/gtest.h>

#include <ops/all_ops.h>
@@ -16,15 +17,23 @@

namespace nvfuser {

using testing::Each;
using testing::IsTrue;
using testing::Pointer;
using testing::Property;

namespace {
void assertIsCompiledToHostIrContainer(
const FusionExecutorCache& executor_cache) {
FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
EXPECT_TRUE(runtime->executors().size() == 1);
for (const auto& ea : runtime->executors()) {
EXPECT_TRUE(ea->isA<HostIrExecutor>())
<< "failed to compile to a HostIrContainer with Communications";
}
EXPECT_EQ(runtime->executors().size(), 1);
EXPECT_THAT(
runtime->executors(),
Each(Pointer(Property(
"is a HostIrExecutor",
&ExecutorAbstract::isA<HostIrExecutor>,
IsTrue()))))
<< "failed to compile to a HostIrContainer with Communications";
}
} // namespace

112 changes: 106 additions & 6 deletions tests/python/test_communication.py
@@ -15,13 +15,13 @@

@pytest.mark.mpi
def test_allgather(mpi_test):
num_devices = mpi_test.size
mesh = nvfuser.DeviceMesh(range(num_devices))
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(num_devices * 4,), contiguity=True, dtype=DataType.Float
(d * 4,), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.set(self.inp)
self.add_output(self.out)
@@ -30,16 +30,116 @@ def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.split(self.inp, 0, num_devices, False)
self.sched.split(self.inp, 0, d, False)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.inp)

self.sched.split(self.out, 0, num_devices, False)
self.sched.split(self.out, 0, d, False)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(num_devices * 4)
unsharded = torch.randn(d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded)
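
shard_tensor above is the test fixture's helper for producing the per-rank input; a rough equivalent under the assumption that it keeps only the calling rank's slice along the given dimension (illustrative, not the fixture's actual implementation):

import torch
import torch.distributed as dist

def shard_like(unsharded: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep the slice owned by this rank along `dim` and move it to the local GPU.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return torch.tensor_split(unsharded, world_size, dim=dim)[rank].cuda()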


@pytest.mark.mpi
def test_allreduce(mpi_test):
Collaborator Author:

This allreduce test is merely for DID logical split. I don't think allreduce can support DID loop split because sum's reduction axes can only be logical. But I'd be happy to know otherwise.

Collaborator:

> This allreduce test is merely for DID logical split.

Just to be clear, you meant DID parallelization of logical domains, right? I'm not sure what you meant by DID logical split otherwise.

Collaborator:

Assuming I understand what you meant correctly, I think this is where TensorView::rFactor could be used. That's what we use for intra-device hierarchical reductions. For example, I'd think that for multi-GPU reductions, we would have something like:

(I'm mixing the C++ and Python APIs)

self.out->split(0, num_devices, /*inner=*/false);
auto intermediate_result = self.out->rFactor({1});
intermediate_result->axis(0)->parallelize(DIDx);
self.out->axis(0)->parallelize(DIDx);

Here, intermediate_result would be the partial result of the per-device reduction, which would then be reduced across all the devices and saved to self.out.
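
A data-level reading of that rFactor schedule, in plain torch rather than nvFuser scheduling (illustrative only, assuming a [d, n] input sharded on axis 0):

import torch
import torch.distributed as dist

def two_stage_sum(local_slice: torch.Tensor) -> torch.Tensor:
    # Stage 1: per-device partial reduction (the rFactor'd intermediate_result).
    partial = local_slice.sum(0)
    # Stage 2: reduce the partials across devices (the allreduce).
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    return partial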

Collaborator Author:

I agree it's something like rFactor, and did look into how TensorView::rFactor works (e.g., TensorView* tv2 = tv1_copy->rFactor({0});). However, I failed to see how it applies here.

If we want to loop (but not logical) split an allreduce, the input would have a logical shape like [D*2,3] and the output a logical shape like [2,3]. Regardless of scheduling, what ops in fusion IR could do that? (Not a sum, because that reduces an entire dimension to 1.)

Collaborator:

Let's talk offline. It seems we are not using the same vocabulary (e.g., I don't understand what "loop split" and "logical split" mean).

Collaborator Author:

#3543 is my failed attempt. It triggered an assertion at NVF_THROW("Unexpected producer RF ID: ", producer_rf_id->toString()). Is it because the code there has been assuming that reductions are innermost?

Collaborator Author:

Anyhow, this isn't a blocker. As we discussed yesterday, we'll probably stick with logical split for reductions in Allreduce and ReduceScatter due to MatmulOp's implementation.

d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor((d, 4), contiguity=True, dtype=DataType.Float)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

unsharded = torch.randn(d, 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded.sum(0))


@pytest.mark.mpi
def test_reduce_scatter(mpi_test):
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(d, d * 4), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

self.sched.split(self.out, -1, d, False)
self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(d, d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(
outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 0, mesh)
)


@pytest.mark.mpi
def test_reduce_scatter_noncontiguous(mpi_test):
d = mpi_test.size
mesh = nvfuser.DeviceMesh(range(d))

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(d, 3, d * 4), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.sum(self.inp, [0])
self.add_output(self.out)

def multidevice_schedule(self):
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

# inp: [iDID{d}, i{3}, i{d*4}]
# out: [r{d}, i{3}, i{d*4}]
# / \
# iDID{d} i{4}
#
# Unlike test_reduce_scatter, this leads to extra data copy because
# the scattered axis is not outermost in allocation.
# ProcessGroupNCCL::reduce_scatter was able to handle
# non-contiguous scattering in a functional but suboptimal way.
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)

self.sched.split(self.out, -1, d, False)
self.sched.parallelize(self.out, -2, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.out)

unsharded = torch.randn(d, 3, d * 4)
sharded = mpi_test.shard_tensor(unsharded, 0, mesh)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(
outputs[0], mpi_test.shard_tensor(unsharded.sum(0), 1, mesh)
)
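
The extra copy in the non-contiguous case can be seen from the chunk layouts alone; a small pure-torch sketch (no communication) of why scattering the innermost axis produces non-contiguous chunks while scattering the outermost axis does not:

import torch

d = 2
x = torch.randn(d, 3, d * 4)
# Chunks along the innermost axis are strided views that c10d must stage into
# a contiguous internal buffer.
print([c.is_contiguous() for c in torch.tensor_split(x, d, dim=-1)])  # [False, False]
# Chunks along the outermost axis are already contiguous.
print([c.is_contiguous() for c in torch.tensor_split(x, d, dim=0)])   # [True, True]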