diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 522b755b96d..af122ee6e3d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -328,8 +328,8 @@ c10::intrusive_ptr postAllgather( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(splits, communication->team().size()); + auto splits = + at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c104b36d3..8631a1a04e5 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -90,6 +90,11 @@ class Communication : public Expr { return attribute(1); } + // A convenience helper so the user doesn't need to convert size_t to int64_t. + int64_t team_size() const { + return static_cast(team().size()); + } + DeviceIdxType root() const { return attribute(2); } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index c8068b5a113..4b878ac7376 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -196,7 +196,7 @@ void lowerToReduceScatter( std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedAxis(output_tv); + auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped // back onto the input. The input has an reduced axis, so the scattered axis // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 24b7e582104..54f1303bc16 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,48 +121,133 @@ bool isSharded(const TensorView* tv) { return is_sharded; } -std::vector unshardedSizes( - const TensorView* tv, - c10::IntArrayRef sizes) { - std::vector unsharded_sizes = sizes.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); +namespace { +// Collect device-parallel IterDomains in `domain` and return them as a +// ParallelType-to-IterDomain map. +std::unordered_map mapDeviceParallelTypeToId( + const std::vector& domain) { + std::unordered_map parallel_type_to_id; + parallel_type_to_id.reserve(kParallelTypeDIDs.size()); + for (IterDomain* id : domain) { + const ParallelType parallel_type = id->getParallelType(); if (!isParallelTypeDeviceDim(parallel_type)) { continue; } - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); NVF_ERROR( - !inputs.empty(), - "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); - NVF_ERROR( - inputs.size() == 1, - "Failed to find the single logical input to ", - alloc_id, - ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", - toDelimitedString(inputs)); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); + parallel_type_to_id.try_emplace(parallel_type, id).second, + "Found multiple loop IterDomains with the same parallel type (", + parallel_type, + "): ", + toDelimitedString(domain)); + } + return parallel_type_to_id; +} + +std::unordered_map mapIterDomainToTensorAxis( + const std::vector& domain) { + std::unordered_map id_to_axis; + int64_t axis = 0; + for (auto* id : domain) { + // Reduction IterDomains are not materialized as an at::Tensor axis. + if (id->isReduction()) { + continue; + } + id_to_axis[id] = axis; + axis++; + } + return id_to_axis; +} + +} // namespace + +int64_t getShardedLogicalAxis( + const TensorView* tv, + const ParallelType parallel_type) { + std::unordered_map parallel_type_to_id = + mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); + IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); + if (alloc_id == nullptr) { + return -1; + } + + std::unordered_map logical_id_to_axis = + mapIterDomainToTensorAxis(tv->getLogicalDomain()); + IterDomain* id = alloc_id; + while (logical_id_to_axis.count(id) == 0) { + Expr* def = id->definition(); NVF_ERROR( - iter != tv->getLogicalDomain().end(), - "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", - inputs[0]); - - // Count the number of non-reduction IterDomains before `iter`. Reduction - // IterDomains are not materialized in the at::Tensor's shape. - const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + def != nullptr, + "Failed to find a non-reduction logical IterDomain that produces ", + alloc_id); + if (auto* split = dynamic_cast(def)) { + // Returning just which tensor axis is sharded isn't sufficient to let + // shardTensor, a user of this function, know how to shard the tensor. + // For example, + // + // t = makeContigConcreteTensor({6}); + // t->split(0, 2, /*inner_split=*/true); + // t->axis(-1)->parallelize(DIDx); + // // [i{3}, iDIDx{2}] + // + // and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the + // stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3, + // 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3, + // 4, 5], assuming the axis is sharded outermost. + // + // One potential way to solve the general problem is to replay and rewind + // the splits on the at::Tensor. For example, + // + // t = makeContigConcreteTensor({30}); + // t->split(0, 5); + // t->split(0, 3); + // t->axis(0)->parallelize(Host); + // t->axis(1)->parallelize(DIDx); + // // [iHost{2}, iDIDx{3}, i{5}] + // + // Given an unsharded at::Tensor of shape [30], we'll first replay the + // splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we + // `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then, + // we rewind the splits (and therefore apply merging) using + // `torch.reshape` to get a sharded tensor of shape [10]. + NVF_ERROR( + split->outer() == id, + "Currently, we don't support DID on inner splits: ", + split); + id = split->in(); + } else if (auto* merge = dynamic_cast(def)) { + // For example, + // + // t = makeContigTensor(2); + // t->merge(0, 1); + // t->axis(0)->parallelize(DIDx); + // + // When `unshardedSizes` is given a local tensor of shape [1, 1], it's + // unclear the global shape is [1, D] or [D, 1] or even [2, D/2], etc. + NVF_THROW( + "Failed to attribute the sharding to a single tensor axis and therefore bailed out: ", + merge); + } else { + NVF_THROW( + "Unexpected transforms from logical to a DID-parallel allocation IterDomain: ", + def); + } } + return logical_id_to_axis.at(id); +} + +std::vector unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector unsharded_sizes = sizes.vec(); + for (ParallelType parallel_type : kParallelTypeDIDs) { + const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type); + if (sharded_axis == -1) { + continue; + } + unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type); + } return unsharded_sizes; } @@ -174,27 +259,6 @@ int64_t numDeviceDims(const TensorView* tv) { } namespace { -// Collect device-parallel IterDomains in `loop_domain` and return them as a -// ParallelType-to-IterDomain map. -std::unordered_map mapParallelTypeToId( - const std::vector& loop_domain) { - std::unordered_map parallel_type_to_id; - parallel_type_to_id.reserve(kParallelTypeDIDs.size()); - for (IterDomain* loop_id : loop_domain) { - const ParallelType parallel_type = loop_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - NVF_ERROR( - parallel_type_to_id.try_emplace(parallel_type, loop_id).second, - "Found multiple loop IterDomains with the same parallel type (", - parallel_type, - "): ", - toDelimitedString(loop_domain)); - } - return parallel_type_to_id; -} std::vector getInputsInTargetDomain( IterDomain* loop_id, @@ -294,9 +358,9 @@ bool haveDifferentShardings( // 3. Check if the two loop IterDomains are almost-exactly mapped in the // IdModel. std::unordered_map p_parallel_type_to_id = - mapParallelTypeToId(producer->getLoopDomain()); + mapDeviceParallelTypeToId(producer->getLoopDomain()); std::unordered_map c_parallel_type_to_id = - mapParallelTypeToId(consumer->getLoopDomain()); + mapDeviceParallelTypeToId(consumer->getLoopDomain()); for (const auto parallel_type : kParallelTypeDIDs) { IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type); @@ -502,16 +566,6 @@ std::set involvedDevices(Expr* expr) { return ret; } -int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getLogicalDomain()); - for (size_t i = 0; i < ids.size(); ++i) { - if (ids[i]->getParallelType() == ParallelType::DIDx) { - return static_cast(i); - } - } - return -1; -} - void reorderDIDToFront(TensorView* tv) { // new position to old position std::unordered_map order_map; diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 5be2e11bd15..ef88fbdcf80 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -123,9 +123,16 @@ int64_t requestedNumberOfDevices(Fusion*); void unshard(Fusion*); void unshard(TensorView*); -// Returns the index of the a sharded axis if none return -1. -// TODO: Assumes no merges/splits on sharded axis. -int64_t getShardedAxis(TensorView*); +// Returns the index of the sharded logical axis that produces the allocation +// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel +// type, returns -1. +// +// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by +// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and +// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in +// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's +// extent if that IterDomain is sharded. +int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. void reorderDIDToFront(TensorView*); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index bab5cdccc5e..22897dc5311 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -128,7 +128,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) { return tensor; } NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv); - return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh()); + return shardTensor( + tensor, + getShardedLogicalAxis(tv, ParallelType::DIDx), + tv->getDeviceMesh()); } at::Tensor MultiDeviceTest::shardTensor( @@ -144,13 +147,10 @@ at::Tensor MultiDeviceTest::shardTensor( auto stride = extent / nslices; // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; - auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); - // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx - // axis in front representing the sharded extent of the tensor. - if (stride > 1) { - slice = slice.unsqueeze(0); - } - return slice; + // The following slicing is problematic when DID is on an inner split (cf. + // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and + // it's enforced by getShardedLogicalAxis. + return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 643b5b2220d..d1f06d80e1d 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -202,6 +202,73 @@ TEST_F(LowerCollectiveTest, Allgather) { EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor)); } +TEST_F(LowerCollectiveTest, Allgather_LoopSplit) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(0, num_devices, /*inner_split=*/false); + in->axis(0)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(0, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::randn({num_devices * kTensorSize}, at::kFloat); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + +// This currently fails due to getShardingChanges reads root/logical only: +// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77. +// Will try to fix this in a follow-up PR and reenable the test. +TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(2); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(1, num_devices, /*inner_split=*/false); + in->axis(1)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(1, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3}); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + TEST_F(LowerCollectiveTest, Broadcast) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 3adac90bc5e..aaa5d3a3218 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -491,4 +491,52 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) { INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool()); +TEST_F(MultiDeviceTest, ShardTensor_OuterSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({2, d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(1, d, /*inner_split=*/false); + tv->axis(1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(2 * d * 3).view({2, d * 3}); + at::Tensor sharded = shardTensor(unsharded, tv); + + EXPECT_THAT(sharded.sizes(), ElementsAre(2, 3)); + at::Tensor expected = unsharded.view({2, d, 3}).index( + {torch::indexing::Slice(), + communicator_->deviceId(), + torch::indexing::Slice()}); + EXPECT_TRUE(at::equal(sharded, expected)); +} + +TEST_F(MultiDeviceTest, ShardTensor_InnerSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(0, d, /*inner_split=*/true); + tv->axis(-1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(d * 3); + EXPECT_THAT( + [&]() { shardTensor(unsharded, tv); }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("DID on inner splits"))); +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2ef33dcdf8f..0f39ae6f6e5 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -720,14 +720,14 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { std::vector inputs = { x, - shardTensor(w0, 0, mesh), - shardTensor(b0, 0, mesh), - shardTensor(w1, 1, mesh), + shardTensor(w0, 0, mesh).unsqueeze(0), + shardTensor(b0, 0, mesh).unsqueeze(0), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -801,17 +801,17 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { auto mask_ = reference_outs[4]; std::vector inputs = { - shardTensor(x_, 0, mesh), - shardTensor(w0_, 0, mesh), - shardTensor(b0_, 0, mesh), - shardTensor(w1_, 1, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), + shardTensor(w0_, 0, mesh).unsqueeze(0), + shardTensor(b0_, 0, mesh).unsqueeze(0), + shardTensor(w1_, 1, mesh).unsqueeze(0), b1_}; std::vector expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -866,12 +866,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -929,17 +929,17 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { at::manual_seed(getATenRandomSeed()); auto reference_outs = reference_mha(x, w0, b0, w1, b1); std::vector inputs = { - shardTensor(x, 0, mesh), + shardTensor(x, 0, mesh).unsqueeze(0), shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1003,16 +1003,16 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { grad_, x_, mask_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), - shardTensor(linear0_, 1, mesh)}; + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), + shardTensor(linear0_, 1, mesh).unsqueeze(0)}; std::vector expected_outputs = { outs[0], // dropout grad - shardTensor(outs[1], 1, mesh), // linear1 weight grad + shardTensor(outs[1], 1, mesh).unsqueeze(0), // linear1 weight grad outs[2], // linear1 bias grad - shardTensor(outs[3], 1, mesh), // gelu grad - shardTensor(outs[4], 0, mesh), // linear0 weight grad - shardTensor(outs[5], 0, mesh), // linear0 bias grad + shardTensor(outs[3], 1, mesh).unsqueeze(0), // gelu grad + shardTensor(outs[4], 0, mesh).unsqueeze(0), // linear0 weight grad + shardTensor(outs[5], 0, mesh).unsqueeze(0), // linear0 bias grad outs[6]}; // linear0 grad x FusionExecutorCache executor_cache(std::move(fusion)); @@ -1094,22 +1094,23 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { std::vector inputs = { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), grad, mask, - shardTensor(reference_outs[0], 1, mesh), // sdpa.output - shardTensor(reference_outs[1], 1, mesh), // sdpa.log_sumexp + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), // sdpa.output + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), // sdpa.log_sumexp reference_outs[2], // sdpa.seed reference_outs[3], // sdpa.offset - shardTensor(reference_outs[13], 1, mesh) // linear0 + shardTensor(reference_outs[13], 1, mesh).unsqueeze(0) // linear0 }; std::vector expected_outputs = { reference_outs[4], // dropout grad - shardTensor(reference_outs[5], 1, mesh), // linear1 weight grad + shardTensor(reference_outs[5], 1, mesh) + .unsqueeze(0), // linear1 weight grad reference_outs[6], // linear1 bias grad - shardTensor(reference_outs[7], 1, mesh), // q grad - shardTensor(reference_outs[8], 1, mesh), // k grad - shardTensor(reference_outs[9], 1, mesh), // v grad + shardTensor(reference_outs[7], 1, mesh).unsqueeze(0), // q grad + shardTensor(reference_outs[8], 1, mesh).unsqueeze(0), // k grad + shardTensor(reference_outs[9], 1, mesh).unsqueeze(0), // v grad shardTensor(reference_outs[10].view({3, E, E}), 1, mesh) .view({1, 3 * E / D, E}), // linear0 weight grad shardTensor(reference_outs[11].view({3, E}), 1, mesh) @@ -1234,26 +1235,26 @@ TEST_P(DistributedTransformerTest, Forward_SP) { auto at_out = (resid0_ + mlp_out_).to(at_dtype); std::vector inputs = { - shardTensor(x_, 0, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), ln0_w_, ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector expected_outputs = { - shardTensor(ln0_out_, 0, mesh), - shardTensor(mha_out_, 0, mesh), - shardTensor(ln1_out_, 0, mesh), - shardTensor(mlp_out_, 0, mesh), - shardTensor(at_out, 0, mesh)}; + shardTensor(ln0_out_, 0, mesh).unsqueeze(0), + shardTensor(mha_out_, 0, mesh).unsqueeze(0), + shardTensor(ln1_out_, 0, mesh).unsqueeze(0), + shardTensor(mlp_out_, 0, mesh).unsqueeze(0), + shardTensor(at_out, 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1367,13 +1368,13 @@ TEST_P(DistributedTransformerTest, Forward) { ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector expected_outputs = { @@ -1620,13 +1621,16 @@ TEST_P(DistributedTransformerTest, Backward) { auto dx_ = (ln0_x_grad_ + resid1_grad_).to(at_dtype); auto expected_outputs = { - shardTensor(mlp_grads_[1], 1, mesh), // mlp_linear1_weight_grad + shardTensor(mlp_grads_[1], 1, mesh) + .unsqueeze(0), // mlp_linear1_weight_grad mlp_grads_[2], // mlp_linear1_bias_grad - shardTensor(mlp_grads_[4], 0, mesh), // mlp_linear0_weight_grad - shardTensor(mlp_grads_[5], 0, mesh), // mlp_linear0_bias_grad + shardTensor(mlp_grads_[4], 0, mesh) + .unsqueeze(0), // mlp_linear0_weight_grad + shardTensor(mlp_grads_[5], 0, mesh).unsqueeze(0), // mlp_linear0_bias_grad ln1_w_grad_, ln1_b_grad_, - shardTensor(mha_grads_[5], 1, mesh), // mha linear1 weight grad + shardTensor(mha_grads_[5], 1, mesh) + .unsqueeze(0), // mha linear1 weight grad mha_grads_[6], // mha linear1 bias grad shardTensor( mha_grads_[10].view({3, E, E}), 1, mesh) // failing starting here @@ -1641,13 +1645,13 @@ TEST_P(DistributedTransformerTest, Backward) { x_, grad_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(mha_w1_, 1, mesh), - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_out_[4], // mlp dropout mask mha_out_[4], // mha dropout mask - shardTensor(mha_grads_[0], 1, mesh), // sdpa output - shardTensor(mha_grads_[1], 1, mesh), // sdpa logsum_exp + shardTensor(mha_grads_[0], 1, mesh).unsqueeze(0), // sdpa output + shardTensor(mha_grads_[1], 1, mesh).unsqueeze(0), // sdpa logsum_exp mha_grads_[2], // sdpa seed mha_grads_[3], // sdpa offset ln1_w_, @@ -1658,9 +1662,9 @@ TEST_P(DistributedTransformerTest, Backward) { ln0_b_, ln0_mean_, ln0_rstd_, - shardTensor(mha_out_[0], 1, mesh), // mha linear0 + shardTensor(mha_out_[0], 1, mesh).unsqueeze(0), // mha linear0 mha_out_[2].to(at::kFloat), // mha linear1 - shardTensor(mlp_out_[0], 1, mesh) // mlp linear1 + shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; FusionExecutorCache executor_cache(std::move(fusion));