Allgather with DID loop split (#3284)
Another baby step towards #2563
wujingyue authored Dec 9, 2024
1 parent 5a2184c commit 4a897a4
Showing 9 changed files with 326 additions and 141 deletions.
4 changes: 2 additions & 2 deletions csrc/multidevice/communication.cpp
@@ -328,8 +328,8 @@ c10::intrusive_ptr<c10d::Work> postAllgather(
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);
assertBufferCount(splits, communication->team().size());
auto splits =
at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize({input_tensor}, splits);

// allgather primitive in c10d induces extra buffering time to copy out the
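For reference, a minimal standalone ATen sketch (not part of this commit) contrasting the old and new splitting of the allgather output, assuming a team of 2 ranks and 8 gathered elements along dim 0:

#include <ATen/ATen.h>
#include <vector>

int main() {
  at::Tensor output = at::arange(8);  // allgather output buffer, shape [8]
  // Old call: chunks of size 1 along dim 0, i.e. 8 slices -- this matches the
  // team size only when dim 0 holds exactly one element per rank.
  std::vector<at::Tensor> by_size =
      at::split(output, /*split_size=*/1, /*dim=*/0);
  // New call: exactly team_size() chunks, one per rank, however many elements
  // each rank contributes; here two slices of shape [4].
  std::vector<at::Tensor> by_rank =
      at::tensor_split(output, /*sections=*/2, /*dim=*/0);
  return 0;
}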
5 changes: 5 additions & 0 deletions csrc/multidevice/communication.h
@@ -90,6 +90,11 @@ class Communication : public Expr {
return attribute<Team>(1);
}

// A convenience helper so the user doesn't need to convert size_t to int64_t.
int64_t team_size() const {
return static_cast<int64_t>(team().size());
}

DeviceIdxType root() const {
return attribute<DeviceIdxType>(2);
}
2 changes: 1 addition & 1 deletion csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
std::vector<Communication*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
auto reduction_axis = output_tv->getReductionAxis().value();
auto scattered_axis = getShardedAxis(output_tv);
auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx);
// The output tensor is sharded on scattered_axis and needs to be mapped
// back onto the input. The input has a reduced axis, so the scattered axis
// is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The
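A hypothetical sketch of the adjustment the comment describes (the actual code is below the shown context, so this is illustrative only): with output [r0, DIDx(i1)], getShardedLogicalAxis skips the reduction and reports axis 0, while the matching input axis of [DIDx(i0), i1] is 1, so the index is bumped past the reduced axis.

// Illustrative only; variable names follow the hunk above.
if (scattered_axis >= static_cast<int64_t>(reduction_axis)) {
  ++scattered_axis;
}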
188 changes: 121 additions & 67 deletions csrc/multidevice/utils.cpp
@@ -121,48 +121,133 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
namespace {
// Collect device-parallel IterDomains in `domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId(
const std::vector<IterDomain*>& domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* id : domain) {
const ParallelType parallel_type = id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
parallel_type_to_id.try_emplace(parallel_type, id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(domain));
}
return parallel_type_to_id;
}

std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis(
const std::vector<IterDomain*>& domain) {
std::unordered_map<IterDomain*, int64_t> id_to_axis;
int64_t axis = 0;
for (auto* id : domain) {
// Reduction IterDomains are not materialized as an at::Tensor axis.
if (id->isReduction()) {
continue;
}
id_to_axis[id] = axis;
axis++;
}
return id_to_axis;
}

} // namespace

int64_t getShardedLogicalAxis(
const TensorView* tv,
const ParallelType parallel_type) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
if (alloc_id == nullptr) {
return -1;
}

std::unordered_map<IterDomain*, int64_t> logical_id_to_axis =
mapIterDomainToTensorAxis(tv->getLogicalDomain());
IterDomain* id = alloc_id;
while (logical_id_to_axis.count(id) == 0) {
Expr* def = id->definition();
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
def != nullptr,
"Failed to find a non-reduction logical IterDomain that produces ",
alloc_id);
if (auto* split = dynamic_cast<Split*>(def)) {
// Returning just which tensor axis is sharded isn't sufficient to let
// shardTensor, a user of this function, know how to shard the tensor.
// For example,
//
// t = makeContigConcreteTensor({6});
// t->split(0, 2, /*inner_split=*/true);
// t->axis(-1)->parallelize(DIDx);
// // [i{3}, iDIDx{2}]
//
// and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the
// stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3,
// 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3,
// 4, 5], assuming the axis is sharded outermost.
//
// One potential way to solve the general problem is to replay and rewind
// the splits on the at::Tensor. For example,
//
// t = makeContigConcreteTensor({30});
// t->split(0, 5);
// t->split(0, 3);
// t->axis(0)->parallelize(Host);
// t->axis(1)->parallelize(DIDx);
// // [iHost{2}, iDIDx{3}, i{5}]
//
// Given an unsharded at::Tensor of shape [30], we'll first replay the
// splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we
// `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then,
// we rewind the splits (and therefore apply merging) using
// `torch.reshape` to get a sharded tensor of shape [10].
NVF_ERROR(
split->outer() == id,
"Currently, we don't support DID on inner splits: ",
split);
id = split->in();
} else if (auto* merge = dynamic_cast<Merge*>(def)) {
// For example,
//
// t = makeContigTensor(2);
// t->merge(0, 1);
// t->axis(0)->parallelize(DIDx);
//
// When `unshardedSizes` is given a local tensor of shape [1, 1], it's
// unclear whether the global shape is [1, D], [D, 1], or even [2, D/2], etc.
NVF_THROW(
"Failed to attribute the sharding to a single tensor axis and therefore bailed out: ",
merge);
} else {
NVF_THROW(
"Unexpected transforms from logical to a DID-parallel allocation IterDomain: ",
def);
}
}

return logical_id_to_axis.at(id);
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();
for (ParallelType parallel_type : kParallelTypeDIDs) {
const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type);
if (sharded_axis == -1) {
continue;
}
unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type);
}
return unsharded_sizes;
}

@@ -174,27 +259,6 @@ int64_t numDeviceDims(const TensorView* tv) {
}

namespace {
// Collect device-parallel IterDomains in `loop_domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId(
const std::vector<IterDomain*>& loop_domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* loop_id : loop_domain) {
const ParallelType parallel_type = loop_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

NVF_ERROR(
parallel_type_to_id.try_emplace(parallel_type, loop_id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(loop_domain));
}
return parallel_type_to_id;
}

std::vector<IterDomain*> getInputsInTargetDomain(
IterDomain* loop_id,
@@ -294,9 +358,9 @@ bool haveDifferentShardings(
// 3. Check if the two loop IterDomains are almost-exactly mapped in the
// IdModel.
std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
mapParallelTypeToId(producer->getLoopDomain());
mapDeviceParallelTypeToId(producer->getLoopDomain());
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
mapParallelTypeToId(consumer->getLoopDomain());
mapDeviceParallelTypeToId(consumer->getLoopDomain());

for (const auto parallel_type : kParallelTypeDIDs) {
IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type);
@@ -502,16 +566,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) {
return ret;
}

int64_t getShardedAxis(TensorView* tv) {
auto ids = TensorDomain::noReductions(tv->getLogicalDomain());
for (size_t i = 0; i < ids.size(); ++i) {
if (ids[i]->getParallelType() == ParallelType::DIDx) {
return static_cast<int64_t>(i);
}
}
return -1;
}

void reorderDIDToFront(TensorView* tv) {
// new position to old position
std::unordered_map<int64_t, int64_t> order_map;
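The replay-and-rewind idea described in the getShardedLogicalAxis comment above, written out as a standalone ATen sketch (hypothetical; not code in this commit):

#include <ATen/ATen.h>

// Shard a [30] tensor whose loop domain is [iHost{2}, iDIDx{3}, i{5}] for
// device index `didx`, following the comment's view/slice/reshape recipe.
at::Tensor shardByReplay(const at::Tensor& unsharded, int64_t didx) {
  at::Tensor replayed = unsharded.view({2, 3, 5});  // replay the two splits
  at::Tensor sliced =
      replayed.slice(/*dim=*/1, didx, didx + 1);    // keep this device's DIDx slice: [2, 1, 5]
  return sliced.reshape({10});                      // rewind the splits: [10]
}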
13 changes: 10 additions & 3 deletions csrc/multidevice/utils.h
@@ -123,9 +123,16 @@ int64_t requestedNumberOfDevices(Fusion*);
void unshard(Fusion*);
void unshard(TensorView*);

// Returns the index of the a sharded axis if none return -1.
// TODO: Assumes no merges/splits on sharded axis.
int64_t getShardedAxis(TensorView*);
// Returns the index of the sharded logical axis that produces the allocation
// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
// type, returns -1.
//
// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by
// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and
// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in
// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's
// extent if that IterDomain is sharded.
int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);
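A small worked example of the getShardedLogicalAxis contract described above (values and variables are illustrative, not from this commit): for a TensorView whose logical domain is [i{8}] and whose allocation domain is the outer split [iDIDx{2}, i{4}], the local at::Tensor has sizes {4}.

// Hypothetical usage; `tv` and the local at::Tensor `local` are assumed to exist.
int64_t axis = getShardedLogicalAxis(tv, ParallelType::DIDx);          // 0
std::vector<int64_t> full = unshardedSizes(tv, local.sizes());         // {8} == {4 * mesh size}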
16 changes: 8 additions & 8 deletions tests/cpp/multidevice.cpp
@@ -128,7 +128,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
return tensor;
}
NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv);
return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh());
return shardTensor(
tensor,
getShardedLogicalAxis(tv, ParallelType::DIDx),
tv->getDeviceMesh());
}

at::Tensor MultiDeviceTest::shardTensor(
@@ -144,13 +147,10 @@ at::Tensor MultiDeviceTest::shardTensor(
auto stride = extent / nslices;
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
// Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx
// axis in front representing the sharded extent of the tensor.
if (stride > 1) {
slice = slice.unsqueeze(0);
}
return slice;
// The following slicing is problematic when DID is on an inner split (cf.
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
// it's enforced by getShardedLogicalAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

} // namespace nvfuser
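For concreteness, a sketch of what the slicing above computes (illustrative values only): with an unsharded extent of 8 along axis 0 and nslices = 2, the stride is 4, so device 0 receives elements [0, 4) and device 1 receives [4, 8).

// Illustrative only: axis = 0, extent = 8, nslices = 2  =>  stride = 4.
at::Tensor t = at::arange(8);
at::Tensor shard0 = t.slice(/*dim=*/0, 0 * 4, 1 * 4).contiguous();  // [0, 1, 2, 3]
at::Tensor shard1 = t.slice(/*dim=*/0, 1 * 4, 2 * 4).contiguous();  // [4, 5, 6, 7]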
67 changes: 67 additions & 0 deletions tests/cpp/test_multidevice_lower_communication.cpp
@@ -202,6 +202,73 @@ TEST_F(LowerCollectiveTest, Allgather) {
EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor));
}

TEST_F(LowerCollectiveTest, Allgather_LoopSplit) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto num_devices = communicator_->size();
auto mesh = DeviceMesh::createForNumDevices(num_devices);

TensorView* in = makeContigTensor(1);
in->setDeviceMesh(mesh);
TensorView* out = set(in);
fusion->addInput(in);
fusion->addOutput(out);

in->split(0, num_devices, /*inner_split=*/false);
in->axis(0)->parallelize(ParallelType::DIDx);
in->setAllocationDomain(in->getLoopDomain(), true);

out->split(0, num_devices, /*inner_split=*/false);
out->setAllocationDomain(out->getLoopDomain(), true);

at::Tensor unsharded_tensor =
at::randn({num_devices * kTensorSize}, at::kFloat);
at::Tensor in_tensor =
shardTensor(unsharded_tensor, in).to(communicator_->device());

FusionExecutorCache fec(std::move(fusion));
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
assertIsCompiledToHostIrContainer(fec);

EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
}

// This currently fails because getShardingChanges reads only root/logical:
// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77.
// Will try to fix this in a follow-up PR and re-enable the test.
TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto num_devices = communicator_->size();
auto mesh = DeviceMesh::createForNumDevices(num_devices);

TensorView* in = makeContigTensor(2);
in->setDeviceMesh(mesh);
TensorView* out = set(in);
fusion->addInput(in);
fusion->addOutput(out);

in->split(1, num_devices, /*inner_split=*/false);
in->axis(1)->parallelize(ParallelType::DIDx);
in->setAllocationDomain(in->getLoopDomain(), true);

out->split(1, num_devices, /*inner_split=*/false);
out->setAllocationDomain(out->getLoopDomain(), true);

at::Tensor unsharded_tensor =
at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3});
at::Tensor in_tensor =
shardTensor(unsharded_tensor, in).to(communicator_->device());

FusionExecutorCache fec(std::move(fusion));
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
assertIsCompiledToHostIrContainer(fec);

EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
}

TEST_F(LowerCollectiveTest, Broadcast) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());