Allgather with DID loop split #3284

Merged: 23 commits, Dec 9, 2024

Changes from all commits
4 changes: 2 additions & 2 deletions csrc/multidevice/communication.cpp
@@ -328,8 +328,8 @@ c10::intrusive_ptr<c10d::Work> postAllgather(
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);
assertBufferCount(splits, communication->team().size());
auto splits =
at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize({input_tensor}, splits);

// allgather primitive in c10d induces extra buffering time to copy out the
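As context for this hunk, a minimal standalone sketch (assumed team size and buffer shape, not part of the diff) of why at::tensor_split is the better fit once each rank owns more than one slice of dim 0: at::split with split_size 1 yields one buffer per row, which only matches the team size when the sharded extent per rank is exactly 1, while at::tensor_split always yields exactly team_size buffers.

#include <ATen/ATen.h>
#include <vector>

// Standalone sketch, not the PR's code: 2 ranks, 2 rows per rank.
void tensorSplitSketch() {
  const int64_t team_size = 2;
  at::Tensor output = at::arange(8, at::kFloat).view({4, 2});

  std::vector<at::Tensor> per_row =
      at::split(output, /*split_size=*/1, /*dim=*/0);    // 4 buffers
  std::vector<at::Tensor> per_rank =
      at::tensor_split(output, team_size, /*dim=*/0);    // 2 buffers, each [2, 2]
}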
5 changes: 5 additions & 0 deletions csrc/multidevice/communication.h
@@ -90,6 +90,11 @@ class Communication : public Expr {
return attribute<Team>(1);
}

// A convenience helper so the user doesn't need to convert size_t to int64_t.
int64_t team_size() const {
return static_cast<int64_t>(team().size());
}

DeviceIdxType root() const {
return attribute<DeviceIdxType>(2);
}
2 changes: 1 addition & 1 deletion csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
std::vector<Communication*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
auto reduction_axis = output_tv->getReductionAxis().value();
auto scattered_axis = getShardedAxis(output_tv);
auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx);
// The output tensor is sharded on scattered_axis and needs to be mapped
// back onto the input. The input has a reduced axis, so the scattered axis
// is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The
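The rest of lowerToReduceScatter is collapsed in this view. As a rough sketch only (hypothetical helper name, not the PR's code), one way the adjustment described in the comment could look, using its [DIDx(i0), i1] -> [r0, DIDx(i1)] example where reduction_axis is 0 and the output's sharded tensor axis is 0 but the matching input axis is 1:

// Hypothetical sketch: shift the output's sharded tensor axis past the reduced
// axis so it lines up with the input, which still materializes that axis.
int64_t adjustScatteredAxisForInput(
    int64_t scattered_axis,
    int64_t reduction_axis) {
  if (scattered_axis >= reduction_axis) {
    scattered_axis++;
  }
  return scattered_axis;
}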
188 changes: 121 additions & 67 deletions csrc/multidevice/utils.cpp
@@ -121,48 +121,133 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
namespace {
// Collect device-parallel IterDomains in `domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId(
const std::vector<IterDomain*>& domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* id : domain) {
const ParallelType parallel_type = id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
NVF_ERROR(
    parallel_type_to_id.try_emplace(parallel_type, id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(domain));
}
return parallel_type_to_id;
}

std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis(
const std::vector<IterDomain*>& domain) {
std::unordered_map<IterDomain*, int64_t> id_to_axis;
int64_t axis = 0;
for (auto* id : domain) {
// Reduction IterDomains are not materialized as an at::Tensor axis.
if (id->isReduction()) {
continue;
}
id_to_axis[id] = axis;
axis++;
}
return id_to_axis;
}

} // namespace

int64_t getShardedLogicalAxis(
const TensorView* tv,
const ParallelType parallel_type) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
if (alloc_id == nullptr) {
return -1;
}

std::unordered_map<IterDomain*, int64_t> logical_id_to_axis =
mapIterDomainToTensorAxis(tv->getLogicalDomain());
IterDomain* id = alloc_id;
while (logical_id_to_axis.count(id) == 0) {
Expr* def = id->definition();
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
def != nullptr,
"Failed to find a non-reduction logical IterDomain that produces ",
alloc_id);
if (auto* split = dynamic_cast<Split*>(def)) {
// Returning just which tensor axis is sharded isn't sufficient to let
// shardTensor, a user of this function, know how to shard the tensor.
// For example,
//
// t = makeContigConcreteTensor({6});
// t->split(0, 2, /*inner_split=*/true);
// t->axis(-1)->parallelize(DIDx);
// // [i{3}, iDIDx{2}]
//
// and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the
// stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3,
// 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3,
// 4, 5], assuming the axis is sharded outermost.
//
// One potential way to solve the general problem is to replay and rewind
// the splits on the at::Tensor. For example,
//
// t = makeContigConcreteTensor({30});
// t->split(0, 5);
// t->split(0, 3);
// t->axis(0)->parallelize(Host);
// t->axis(1)->parallelize(DIDx);
// // [iHost{2}, iDIDx{3}, i{5}]
//
// Given an unsharded at::Tensor of shape [30], we'll first replay the
// splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we
// `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then,
// we rewind the splits (and therefore apply merging) using
// `torch.reshape` to get a sharded tensor of shape [10].
NVF_ERROR(
split->outer() == id,
"Currently, we don't support DID on inner splits: ",
split);
id = split->in();
} else if (auto* merge = dynamic_cast<Merge*>(def)) {
// For example,
//
// t = makeContigTensor(2);
// t->merge(0, 1);
// t->axis(0)->parallelize(DIDx);
//
// When `unshardedSizes` is given a local tensor of shape [1, 1], it's
// unclear whether the global shape is [1, D], [D, 1], or even [2, D/2], etc.
NVF_THROW(
"Failed to attribute the sharding to a single tensor axis and therefore bailed out: ",
merge);
} else {
NVF_THROW(
"Unexpected transforms from logical to a DID-parallel allocation IterDomain: ",
def);
}
}

return logical_id_to_axis.at(id);
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();
for (ParallelType parallel_type : kParallelTypeDIDs) {
const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type);
if (sharded_axis == -1) {
continue;
}
unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type);
}
return unsharded_sizes;
}

@@ -174,27 +174,6 @@ int64_t numDeviceDims(const TensorView* tv) {
}

namespace {
// Collect device-parallel IterDomains in `loop_domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId(
const std::vector<IterDomain*>& loop_domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* loop_id : loop_domain) {
const ParallelType parallel_type = loop_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

NVF_ERROR(
parallel_type_to_id.try_emplace(parallel_type, loop_id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(loop_domain));
}
return parallel_type_to_id;
}

std::vector<IterDomain*> getInputsInTargetDomain(
IterDomain* loop_id,
@@ -294,9 +358,9 @@ bool haveDifferentShardings(
// 3. Check if the two loop IterDomains are almost-exactly mapped in the
// IdModel.
std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
mapParallelTypeToId(producer->getLoopDomain());
mapDeviceParallelTypeToId(producer->getLoopDomain());
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
mapParallelTypeToId(consumer->getLoopDomain());
mapDeviceParallelTypeToId(consumer->getLoopDomain());

for (const auto parallel_type : kParallelTypeDIDs) {
IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type);
Expand Down Expand Up @@ -502,16 +566,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) {
return ret;
}

int64_t getShardedAxis(TensorView* tv) {
auto ids = TensorDomain::noReductions(tv->getLogicalDomain());
for (size_t i = 0; i < ids.size(); ++i) {
if (ids[i]->getParallelType() == ParallelType::DIDx) {
return static_cast<int64_t>(i);
}
}
return -1;
}

void reorderDIDToFront(TensorView* tv) {
// new position to old position
std::unordered_map<int64_t, int64_t> order_map;
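As a standalone sketch of the "replay and rewind" idea described in the comment above (illustrative only; the PR intentionally leaves this unimplemented), for the [30] -> [iHost{2}, iDIDx{3}, i{5}] example with a contiguous unsharded tensor:

#include <ATen/ATen.h>

// Illustration with a made-up name: replay the splits with view, slice out
// this rank's DIDx shard, then rewind the splits with reshape.
at::Tensor shardByReplaySketch(const at::Tensor& unsharded, int64_t didx_rank) {
  at::Tensor replayed = unsharded.view({2, 3, 5});  // replay: [30] -> [2, 3, 5]
  at::Tensor sliced =
      replayed.slice(/*dim=*/1, didx_rank, didx_rank + 1);  // [2, 1, 5]
  return sliced.reshape({10});  // rewind: [2, 1, 5] -> [10]
}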
13 changes: 10 additions & 3 deletions csrc/multidevice/utils.h
@@ -123,9 +123,16 @@ int64_t requestedNumberOfDevices(Fusion*);
void unshard(Fusion*);
void unshard(TensorView*);

// Returns the index of the a sharded axis if none return -1.
// TODO: Assumes no merges/splits on sharded axis.
int64_t getShardedAxis(TensorView*);
// Returns the index of the sharded logical axis that produces the allocation
// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
// type, returns -1.
//
// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by
// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and
// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in
// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's
// extent if that IterDomain is sharded.
int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);
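A small worked example of that size relationship, with assumed extents and in the style of the examples in csrc/multidevice/utils.cpp, on a mesh of 2 devices:

TensorView* t = makeContigConcreteTensor({6});     // logical: [i{6}]
t->split(0, /*factor=*/2, /*inner_split=*/false);  // loop:    [i{2}, i{3}]
t->axis(0)->parallelize(ParallelType::DIDx);       //          [iDIDx{2}, i{3}]
t->setAllocationDomain(t->getLoopDomain(), true);

// getShardedLogicalAxis(t, ParallelType::DIDx) returns 0: iDIDx{2} traces back
// through the outer Split to logical axis 0. Each device's at::Tensor then has
// sizes [3], a factor of the logical extent 6, and unshardedSizes multiplies
// that size by the mesh size to recover [6].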
16 changes: 8 additions & 8 deletions tests/cpp/multidevice.cpp
@@ -128,7 +128,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
return tensor;
}
NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv);
return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh());
return shardTensor(
tensor,
getShardedLogicalAxis(tv, ParallelType::DIDx),
tv->getDeviceMesh());
}

at::Tensor MultiDeviceTest::shardTensor(
Expand All @@ -144,13 +147,10 @@ at::Tensor MultiDeviceTest::shardTensor(
auto stride = extent / nslices;
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
// Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx
// axis in front representing the sharded extent of the tensor.
if (stride > 1) {
slice = slice.unsqueeze(0);
}
return slice;
// The following slicing is problematic when DID is on an inner split (cf.
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that; it is
// enforced by getShardedLogicalAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

} // namespace nvfuser
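A standalone sketch (assumed two devices; not part of the test helper) contrasting the contiguous slice used above, which matches an outer split, with the strided slice an inner DID split would need:

#include <ATen/ATen.h>

// Sketch only: i is this device's index among D devices.
void shardingSketch(int64_t i) {
  const int64_t D = 2;                       // assumed device count
  at::Tensor t = at::arange(6, at::kFloat);  // unsharded [0, 1, 2, 3, 4, 5]
  const int64_t stride = t.numel() / D;

  // Outer split (what shardTensor does): device 0 -> [0, 1, 2], device 1 -> [3, 4, 5].
  at::Tensor outer_shard =
      t.slice(/*dim=*/0, i * stride, (i + 1) * stride).contiguous();

  // An inner split would instead need a step of D:
  // device 0 -> [0, 2, 4], device 1 -> [1, 3, 5].
  at::Tensor inner_shard =
      t.slice(/*dim=*/0, /*start=*/i, /*end=*/c10::nullopt, /*step=*/D)
          .contiguous();
}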
67 changes: 67 additions & 0 deletions tests/cpp/test_multidevice_lower_communication.cpp
@@ -202,6 +202,73 @@ TEST_F(LowerCollectiveTest, Allgather) {
EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor));
}

TEST_F(LowerCollectiveTest, Allgather_LoopSplit) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto num_devices = communicator_->size();
auto mesh = DeviceMesh::createForNumDevices(num_devices);

TensorView* in = makeContigTensor(1);
in->setDeviceMesh(mesh);
TensorView* out = set(in);
fusion->addInput(in);
fusion->addOutput(out);

in->split(0, num_devices, /*inner_split=*/false);
in->axis(0)->parallelize(ParallelType::DIDx);
in->setAllocationDomain(in->getLoopDomain(), true);

out->split(0, num_devices, /*inner_split=*/false);
out->setAllocationDomain(out->getLoopDomain(), true);

at::Tensor unsharded_tensor =
at::randn({num_devices * kTensorSize}, at::kFloat);
at::Tensor in_tensor =
shardTensor(unsharded_tensor, in).to(communicator_->device());

samnordmann (Collaborator) commented on Dec 3, 2024:

Suggested change:

    std::vector<int64_t> ref_in_tensor_shape = {kTensorSize};
    EXPECT_EQ(in_tensor.sizes(), ref_in_tensor_shape);

samnordmann (Collaborator) commented on Dec 3, 2024:

I don't understand how shardTensor can be correct here if it never replays the split backwards... But I might be missing something.

wujingyue (Author) replied:

Thanks for the review! I think there are two problems with the PR as is:

  1. shardTensor may slice out the wrong numbers. For example, if an inner split is DID'ed, the slicing needs to be strided per the outer split.
  2. nvFuser doesn't error out when an Allgather is not along the outermost allocated dimension. This used to be guaranteed by ReorderShardedAxisPass via its isInnerResharding check. However, getShardingChanges, one of its dependencies, hasn't been updated to read loop/allocation:

    auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);

wujingyue (Author) replied:

Re the suggested change: I manually checked that the shape is as expected. I added some extra unit tests for shardTensor alone, so we don't have to verify it here.

wujingyue (Author) replied:

I made a couple of changes to address the problems mentioned in #3284 (comment):

  1. 7cf2384. It's overkill but will probably be OK for quite some time. I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension. I agree that to properly support inner splits we'll need to "replay the split backwards". It's not a trivial change anyhow, so I'll postpone it to a separate PR.
  2. I wrote "Harden assertBuffersHaveSameSize to check shapes" (#3531) to harden runtime checks for allgather and added one more allgather test (Allgather_LoopSplit_Noncontiguous) to this PR. These extra checks will fire on the most common limitations until ReorderShardedAxisPass is properly fixed, which will take several decent-size PRs.

wujingyue (Author) replied:

> I had a hard time finding a concrete use case that has to mix DID and host ID within one logical dimension.

In fact, there is one:

    // A has shape (S, sharded(D), M/(S*D), K)

So I'll try to file a feature request after this PR.

FusionExecutorCache fec(std::move(fusion));
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
assertIsCompiledToHostIrContainer(fec);

EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
}

// This currently fails because getShardingChanges reads only root/logical:
// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77.
// Will try to fix this in a follow-up PR and re-enable the test.
TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto num_devices = communicator_->size();
auto mesh = DeviceMesh::createForNumDevices(num_devices);

TensorView* in = makeContigTensor(2);
in->setDeviceMesh(mesh);
TensorView* out = set(in);
fusion->addInput(in);
fusion->addOutput(out);

in->split(1, num_devices, /*inner_split=*/false);
in->axis(1)->parallelize(ParallelType::DIDx);
in->setAllocationDomain(in->getLoopDomain(), true);

out->split(1, num_devices, /*inner_split=*/false);
out->setAllocationDomain(out->getLoopDomain(), true);

at::Tensor unsharded_tensor =
at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3});
at::Tensor in_tensor =
shardTensor(unsharded_tensor, in).to(communicator_->device());

FusionExecutorCache fec(std::move(fusion));
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
assertIsCompiledToHostIrContainer(fec);

EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
Collaborator commented:

Why not use validate here?

Collaborator commented:

I noticed allgather's lowering was not changed... I'm a bit surprised it didn't need any modifications for inputs with DID loop split! I might have missed a few earlier PRs, though.

Collaborator replied:

> Why not use validate here?

Since validate allows for (small) differences, when two tensors are supposed to be exactly the same, the simpler check, i.e. at::equal, is preferable.

wujingyue (Author) replied:

> I'm a bit surprised it didn't need any modifications for inputs with DID loop split!

Whether we call lowerToAllgather depends on the I/O meshes and on whether the I/O is sharded:

    lowerToAllgather(input_tv, output_tv, comms);

isSharded has been reading the allocation domain since #3444.

That being said, I think this PR as is is a bit too permissive and may lower a set to Allgather without properly checking its allocation domain. For example,

    auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);

reads root and logical and needs to be updated. I'll try to fix that.

wujingyue (Author) replied:

> That being said, I think this PR as is is a bit too permissive and may lower a set to Allgather without properly checking its allocation domain.

I tried to address this in #3284 (comment).

}

TEST_F(LowerCollectiveTest, Broadcast) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
Expand Down