Try to reuse getShardedAxis
wujingyue committed Nov 27, 2024
1 parent fe0cec6 commit f0fbaab
Showing 5 changed files with 76 additions and 87 deletions.
2 changes: 1 addition & 1 deletion csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
std::vector<Communication*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
auto reduction_axis = output_tv->getReductionAxis().value();
auto scattered_axis = getShardedAxis(output_tv);
auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx);
// The output tensor is sharded on scattered_axis and needs to be mapped
// back onto the input. The input has a reduced axis, so the scattered axis
// is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The
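For reference, getShardedAxis indexes the output's logical domain with reduction IterDomains skipped, so the result still has to be shifted past the output's reduction axis before it can index the input. A minimal sketch of the adjustment the comment above describes, assumed for illustration rather than taken from the collapsed remainder of the hunk:

  // If the reduction axis precedes (or equals) the scattered axis in the
  // output, the matching input axis sits one position later, because the
  // reduced dimension is still materialized in the input.
  if (static_cast<int64_t>(reduction_axis) <= scattered_axis) {
    scattered_axis++;
  }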
134 changes: 66 additions & 68 deletions csrc/multidevice/utils.cpp
@@ -121,48 +121,77 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
namespace {
// Collect device-parallel IterDomains in `domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId(
const std::vector<IterDomain*>& domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* id : domain) {
const ParallelType parallel_type = id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
parallel_type_to_id.try_emplace(parallel_type, id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(domain));
}
return parallel_type_to_id;
}
} // namespace

int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
if (alloc_id == nullptr) {
return -1;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]);
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
return std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();
for (ParallelType parallel_type : kParallelTypeDIDs) {
const int64_t sharded_axis = getShardedAxis(tv, parallel_type);
if (sharded_axis == -1) {
continue;
}
unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type);
}
return unsharded_sizes;
}

@@ -174,27 +203,6 @@ int64_t numDeviceDims(const TensorView* tv) {
}

namespace {
// Collect device-parallel IterDomains in `loop_domain` and return them as a
// ParallelType-to-IterDomain map.
std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId(
const std::vector<IterDomain*>& loop_domain) {
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
parallel_type_to_id.reserve(kParallelTypeDIDs.size());
for (IterDomain* loop_id : loop_domain) {
const ParallelType parallel_type = loop_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

NVF_ERROR(
parallel_type_to_id.try_emplace(parallel_type, loop_id).second,
"Found multiple loop IterDomains with the same parallel type (",
parallel_type,
"): ",
toDelimitedString(loop_domain));
}
return parallel_type_to_id;
}

std::vector<IterDomain*> getInputsInTargetDomain(
IterDomain* loop_id,
@@ -294,9 +302,9 @@ bool haveDifferentShardings(
// 3. Check if the two loop IterDomains are almost-exactly mapped in the
// IdModel.
std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
mapParallelTypeToId(producer->getLoopDomain());
mapDeviceParallelTypeToId(producer->getLoopDomain());
std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
mapParallelTypeToId(consumer->getLoopDomain());
mapDeviceParallelTypeToId(consumer->getLoopDomain());

for (const auto parallel_type : kParallelTypeDIDs) {
IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type);
Expand Down Expand Up @@ -502,16 +510,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) {
return ret;
}

int64_t getShardedAxis(TensorView* tv) {
auto ids = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
for (const auto i : c10::irange(ids.size())) {
if (ids[i]->getParallelType() == ParallelType::DIDx) {
return static_cast<int64_t>(i);
}
}
return -1;
}

void reorderDIDToFront(TensorView* tv) {
// new position to old position
std::unordered_map<int64_t, int64_t> order_map;
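To make the refactor concrete, here is a rough usage sketch of getShardedAxis and unshardedSizes as defined above; the domain, mesh size, and sizes are assumed for illustration and do not come from the diff:

  // Suppose tv's logical/allocation domain is [DIDx(i0), i1] on a mesh of 4
  // devices. getShardedAxis finds the DIDx IterDomain and counts the
  // non-reduction logical IterDomains before it:
  int64_t axis = getShardedAxis(tv, ParallelType::DIDx);  // 0
  // unshardedSizes then multiplies that axis by the mesh size, mapping the
  // per-device sizes back onto the unsharded ones:
  std::vector<int64_t> global = unshardedSizes(tv, {2, 8});  // {8, 8}
  // Reduction IterDomains are skipped when counting, so a domain like
  // [r0, DIDx(i1)] also yields axis 0, matching the at::Tensor shape.

A TensorView that isn't parallelized on DIDx simply returns -1 from getShardedAxis, and unshardedSizes leaves its sizes untouched.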
6 changes: 3 additions & 3 deletions csrc/multidevice/utils.h
@@ -123,9 +123,9 @@ int64_t requestedNumberOfDevices(Fusion*);
void unshard(Fusion*);
void unshard(TensorView*);

// Returns the index of the a sharded axis if none return -1.
// TODO: Assumes no merges/splits on sharded axis.
int64_t getShardedAxis(TensorView*);
// Returns the index of the sharded logical axis corresponding to
// `parallel_type`. If `tv` isn't sharded on the parallel type, returns -1.
int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);
11 changes: 3 additions & 8 deletions tests/cpp/multidevice.cpp
@@ -128,7 +128,8 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
return tensor;
}
NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv);
return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh());
return shardTensor(
tensor, getShardedAxis(tv, ParallelType::DIDx), tv->getDeviceMesh());
}

at::Tensor MultiDeviceTest::shardTensor(
@@ -144,13 +145,7 @@ at::Tensor MultiDeviceTest::shardTensor(
auto stride = extent / nslices;
// TODO: returning slice 0 temporarily when device is not in the mesh.
i = (i < 0) ? 0 : i;
auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
// Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx
// axis in front representing the sharded extent of the tensor.
if (stride > 1) {
slice = slice.unsqueeze(0);
}
return slice;
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

} // namespace nvfuser
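The slicing arithmetic in shardTensor is easy to follow with a worked example; the extent and mesh size below are assumed for illustration:

  // Worked example (assumed values): a tensor of extent 8 sharded over a
  // mesh of 4 devices along axis 0.
  const int64_t extent = 8;
  const int64_t nslices = 4;
  const int64_t stride = extent / nslices;  // 2
  // Device i now holds tensor.slice(0, i * stride, (i + 1) * stride), i.e.
  // rows [2*i, 2*i + 2); with the unsqueeze(0) workaround removed, each
  // shard keeps the sliced shape instead of gaining a leading size-1 DIDx
  // dimension.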
10 changes: 3 additions & 7 deletions tests/cpp/test_multidevice_lower_communication.cpp
@@ -202,7 +202,7 @@ TEST_F(LowerCollectiveTest, Allgather) {
EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor));
}

TEST_F(LowerCollectiveTest, Allgather_SplitLoop) {
TEST_F(LowerCollectiveTest, Allgather_LoopSplit) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -224,12 +224,8 @@ TEST_F(LowerCollectiveTest, Allgather_SplitLoop) {

at::Tensor unsharded_tensor =
at::randn({num_devices * kTensorSize}, at::kFloat);
at::Tensor in_tensor = unsharded_tensor
.slice(
0,
communicator_->deviceId() * kTensorSize,
(communicator_->deviceId() + 1) * kTensorSize)
.to(communicator_->device());
at::Tensor in_tensor =
shardTensor(unsharded_tensor, in).to(communicator_->device());

FusionExecutorCache fec(std::move(fusion));
at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
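For this loop-split allgather, the new shardTensor call should compute the same per-device slice the test previously built by hand; a sketch of the expected equivalence, assuming a 1-D input split evenly across the mesh:

  // Roughly what shardTensor(unsharded_tensor, in) yields for this test
  // (illustrative; the helper also handles devices outside the mesh):
  at::Tensor expected = unsharded_tensor
      .slice(
          0,
          communicator_->deviceId() * kTensorSize,
          (communicator_->deviceId() + 1) * kTensorSize)
      .contiguous();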
