Allgather with DID loop split #3284

Merged · 23 commits · Dec 9, 2024

Changes from 3 commits
2 changes: 1 addition & 1 deletion csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
     std::vector<Communication*>& comms) {
   const DeviceMesh& mesh = input_tv->getDeviceMesh();
   auto reduction_axis = output_tv->getReductionAxis().value();
-  auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx);
+  auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx);
   // The output tensor is sharded on scattered_axis and needs to be mapped
   // back onto the input. The input has a reduced axis, so the scattered axis
   // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The
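The axis adjustment that the comment above describes can be sketched in a few lines. This is a hypothetical illustration, not nvFuser code: `input_scattered_axis` is a made-up helper, and it assumes (per one plausible reading of the comment) that reduction axes don't occupy a tensor axis, so any scattered axis at or past the reduction slot shifts by one when mapped back onto the input.

```python
# Hypothetical sketch of mapping a reduce-scatter output's scattered axis
# back onto the input, stepping over the axis the reduction consumed.
def input_scattered_axis(scattered_axis: int, reduction_axis: int) -> int:
    # Axes at or past the reduction slot shift by one on the input side.
    if scattered_axis >= reduction_axis:
        return scattered_axis + 1
    return scattered_axis

# Ex. from the comment: input [DIDx(i0), i1] -> output [r0, DIDx(i1)].
# The output's scattered tensor axis is 0 (r0 maps to no tensor axis) and
# the reduction sits at position 0, so the corresponding input axis is 1.
print(input_scattered_axis(0, 0))  # 1
```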
36 changes: 33 additions & 3 deletions csrc/multidevice/utils.cpp
@@ -161,7 +161,9 @@ std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis(

} // namespace

-int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
+int64_t getShardedLogicalAxis(
+    const TensorView* tv,
+    const ParallelType parallel_type) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
@@ -179,7 +181,35 @@ int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
"Failed to find a non-reduction logical IterDomain that produces ",
alloc_id);
if (auto* split = dynamic_cast<Split*>(def)) {
-    // FIXME: comment
+    // Returning just which tensor axis is sharded isn't sufficient to let
+    // shardTensor, a user of this function, know how to shard the tensor.
+    // For example,
+    //
+    //   t = makeContigConcreteTensor({6});
+    //   t->split(0, 2, /*inner_split=*/true);
+    //   t->axis(-1)->parallelize(DIDx);
+    //   // [i{3}, iDIDx{2}]
+    //
+    // and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the
+    // stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3,
+    // 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3,
+    // 4, 5], assuming the axis is sharded outermost.
+    //
+    // One potential way to solve the general problem is to replay and rewind
+    // the splits on the at::Tensor. For example,
+    //
+    //   t = makeContigConcreteTensor({30});
+    //   t->split(0, 5);
+    //   t->split(0, 3);
+    //   t->axis(0)->parallelize(Host);
+    //   t->axis(1)->parallelize(DIDx);
+    //   // [iHost{2}, iDIDx{3}, i{5}]
+    //
+    // Given an unsharded at::Tensor of shape [30], we'll first replay the
+    // splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we
+    // `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then,
+    // we rewind the splits (and therefore apply merging) using
+    // `torch.reshape` to get a sharded tensor of shape [10].
NVF_ERROR(
split->outer() == id,
"Currently, we don't support DID on inner splits: ",
@@ -212,7 +242,7 @@ std::vector<int64_t> unshardedSizes(
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();
for (ParallelType parallel_type : kParallelTypeDIDs) {
-    const int64_t sharded_axis = getShardedAxis(tv, parallel_type);
+    const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type);
if (sharded_axis == -1) {
continue;
}
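The replay-and-rewind idea in the new comment can be simulated with plain Python lists standing in for `at::Tensor`. This is a sketch, not the PR's implementation: `shard_by_replay` is a made-up helper that reshapes (replays the splits), keeps this rank's slice of the DID axis, and flattens (rewinds the splits) in one index computation.

```python
import math

# Hypothetical sketch: shard a flat tensor given the shape produced by
# replaying its splits and the axis that DIDx parallelizes.
def shard_by_replay(flat, shape, did_axis, rank):
    # Row-major strides of the split shape, e.g. [3, 2] -> [2, 1].
    strides = [math.prod(shape[i + 1:]) for i in range(len(shape))]
    # Keep exactly the elements whose coordinate along did_axis == rank;
    # the surviving elements, in order, are the rewound (merged) shard.
    return [
        x for j, x in enumerate(flat)
        if (j // strides[did_axis]) % shape[did_axis] == rank
    ]

t = list(range(6))
# split(0, 2, /*inner_split=*/true) -> [i{3}, iDIDx{2}]: DIDx on axis 1.
print(shard_by_replay(t, [3, 2], 1, 0))  # [0, 2, 4]
print(shard_by_replay(t, [3, 2], 1, 1))  # [1, 3, 5]
# Naive outermost slicing would instead give [0, 1, 2] and [3, 4, 5].
```

On the comment's second example, an unsharded `[30]` tensor viewed as `[2, 3, 5]` with DIDx on axis 1 yields a shard of 10 elements per device, matching the `view` → `slice` → `reshape` walkthrough.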
8 changes: 7 additions & 1 deletion csrc/multidevice/utils.h
@@ -126,7 +126,13 @@ void unshard(TensorView*);
// Returns the index of the sharded logical axis that produces the allocation
// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
// type, returns -1.
-int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type);
+//
+// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by
+// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and
+// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in
+// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's
+// extent if that IterDomain is sharded.
+int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);
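The size relationship stated in the new header comment has a direct consequence for `unshardedSizes`: along the sharded axis, multiplying the `at::Tensor` size by the device-mesh extent recovers the logical IterDomain's extent. A hedged sketch (`unsharded_sizes` here mirrors the C++ function only in spirit, with the mesh extent passed in explicitly):

```python
# Hypothetical sketch: recover logical sizes from sharded at::Tensor sizes.
def unsharded_sizes(sizes, sharded_axis, num_devices):
    out = list(sizes)
    if sharded_axis != -1:  # -1 means tv isn't sharded on this parallel type
        # The sharded size is a factor of the logical extent; scale it back up.
        out[sharded_axis] *= num_devices
    return out

# Shard of shape [2, 1, 5] on a 3-device DIDx mesh -> logical [2, 3, 5].
print(unsharded_sizes([2, 1, 5], 1, 3))  # [2, 3, 5]
```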
6 changes: 4 additions & 2 deletions tests/cpp/multidevice.cpp
@@ -129,7 +129,9 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
}
NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv);
return shardTensor(
-      tensor, getShardedAxis(tv, ParallelType::DIDx), tv->getDeviceMesh());
+      tensor,
+      getShardedLogicalAxis(tv, ParallelType::DIDx),
+      tv->getDeviceMesh());
}

at::Tensor MultiDeviceTest::shardTensor(
@@ -147,7 +149,7 @@ at::Tensor MultiDeviceTest::shardTensor(
i = (i < 0) ? 0 : i;
// The following slicing is problematic when DID is on an inner split (cf.
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
-  // it's enforced by getShardedAxis.
+  // it's enforced by getShardedLogicalAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

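The `tensor.slice(axis, i * stride, (i + 1) * stride)` call above is the outermost-sharding assumption the comments keep referring to: device `i` takes the `i`-th contiguous chunk along the axis. A minimal Python sketch (`shard_outermost` is a made-up stand-in for the slicing, over a 1-D list):

```python
# Hypothetical sketch of shardTensor's slicing, assuming the sharded axis is
# split outermost: device `rank` takes the rank-th contiguous chunk.
def shard_outermost(data, num_devices, rank):
    stride = len(data) // num_devices
    return data[rank * stride : (rank + 1) * stride]

# For the inner-split example [i{3}, iDIDx{2}], this contiguous slicing is
# what produces the wrong shards, hence the NVF_ERROR on inner splits.
print(shard_outermost(list(range(6)), 2, 0))  # [0, 1, 2]
print(shard_outermost(list(range(6)), 2, 1))  # [3, 4, 5]
```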