diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 1b2554cdabb..63b3f073b75 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -201,6 +201,7 @@ HostIrEvaluator::HostIrEvaluator(
       {container_->getDefaultStream(),
        c10::cuda::getDefaultCUDAStream(
            static_cast<c10::DeviceIndex>(device_index))});
+  expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
 }
 
 std::vector<at::Tensor> HostIrEvaluator::runWithInput(
diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h
index a51dc32aed4..70d8030c640 100644
--- a/csrc/host_ir/executor.h
+++ b/csrc/host_ir/executor.h
@@ -74,6 +74,9 @@ struct HostIrEvaluatorParams {
   // Experimental: whether to cache the fusion executor. WAR: avoids recompilation
   // but implicitly assumes that the input shapes don't change over iterations.
   bool cache_fusion_executor = false;
+  // Number of additional CUDA streams to use at runtime for comm+compute
+  // pipelining.
+  int64_t number_of_streams = 4;
 };
 
 class HostIrEvaluator final : public OptOutDispatch {
diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h
index 82d67d6f4cc..5d339b64386 100644
--- a/csrc/host_ir/host_ir.h
+++ b/csrc/host_ir/host_ir.h
@@ -208,6 +208,8 @@ class Wait : public Expr {
   }
 };
 
+// Makes the current stream wait on the given stream. Non-blocking from the
+// host's point of view.
 class Synchronize : public Expr {
  public:
   using Expr::Expr;
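Note: as the comment added to Synchronize says, the wait is enforced between streams on the device; the host thread is not blocked. For reference, the same "current stream waits on another stream" behavior can be sketched outside nvFuser with the c10 CUDA stream API and a CUDA event (streamWaitStream is a hypothetical helper, not part of this PR):

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Make `current` wait until all work already enqueued on `producer` is done.
// Only a device-side dependency is recorded; the host returns immediately.
void streamWaitStream(
    c10::cuda::CUDAStream current,
    c10::cuda::CUDAStream producer) {
  cudaEvent_t event = nullptr;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, producer.stream());
  cudaStreamWaitEvent(current.stream(), event, /*flags=*/0);
  // Safe to destroy here: CUDA defers the actual release until the event completes.
  cudaEventDestroy(event);
}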
diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index 8e97b958a9a..38e7f9ac45c 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
 
 std::vector<Expr*> HostIrLower::lower(Expr* c) {
   FusionGuard fg(c->fusion());
+  if (c->isA<MatmulOp>()) {
+    return lowerToCollectiveBasedPipelinedGemmComm(c);
+  }
+
   std::vector<Expr*> comms;
   NVF_ERROR(
       c->inputs().size() == 1 && c->input(0)->isA<TensorView>() &&
@@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) {
     return false;
   }
   if (expr->isA<ReductionOp>()) {
+    if (!isInnerResharding(expr)) {
+      return false;
+    }
     auto in = expr->as<ReductionOp>()->in()->as<TensorView>();
     auto out = expr->as<ReductionOp>()->out()->as<TensorView>();
     // get the reduced axis
@@ -328,10 +336,124 @@ bool HostIrLower::canLower(Expr* expr) {
         PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
     auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
     return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
-  } else {
-    return expr->isA<LoadStoreOp>() &&
-        (expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set);
+  } else if (expr->isA<LoadStoreOp>()) {
+    return isInnerResharding(expr) &&
+        expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
+  } else if (expr->isA<MatmulOp>()) {
+    // For now, we only support c = matmul(a, b) where b and c are fully
+    // replicated and a is sharded on axis 1.
+    auto* matmul = expr->as<MatmulOp>();
+    return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
+        matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
+        getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1;
   }
+  return false;
+}
+
+std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
+    Expr* expr) {
+  auto matmul = expr->as<MatmulOp>();
+  NVF_ERROR(matmul != nullptr, "Expected a MatmulOp, got ", expr);
+  TensorView* tva = matmul->inA();
+  TensorView* tvb = matmul->inB();
+  TensorView* tvc = matmul->out();
+  NVF_ERROR(
+      !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
+  NVF_ERROR(
+      !isSharded(tvc),
+      "The output ",
+      matmul->out(),
+      " is expected to not be sharded");
+  const int64_t sharded_axis_index =
+      getShardedLogicalAxis(tva, ParallelType::DIDx);
+  IterDomain* stream_axis = tva->axis(0);
+  NVF_ERROR(
+      stream_axis->getParallelType() == ParallelType::Serial &&
+          sharded_axis_index == 1,
+      "The operand A ",
+      tva,
+      " is expected to be sharded on dimension 1");
+
+  auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();
+
+  auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
+  hir::Stream* original_stream = get_current_stream->stream();
+
+  TensorView* tva_allgathered =
+      ops::newValLike(tva, tva->dtype())->as<TensorView>();
+  tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
+  tva_allgathered->setMemoryType(MemoryType::Global);
+  auto* allocate_tva_allgathered =
+      IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);
+
+  tvc->setMemoryType(MemoryType::Global);
+  auto* allocate_tvc =
+      IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);
+
+  auto* j =
+      IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
+  auto* start = hic->zeroVal();
+  auto* stop = stream_axis->extent();
+  auto* step = hic->oneVal();
+  auto* for_loop = IrBuilder::create<ForLoop>(
+      stream_axis,
+      /*index=*/j,
+      start,
+      stop,
+      step,
+      /*vectorize=*/false,
+      /*vectorize_shift=*/nullptr,
+      /*unroll_required=*/false,
+      CircularBufferLoopStage::NotApplicable,
+      /*circular_buffer_loop_stage_depth=*/0);
+
+  auto* number_of_streams =
+      IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
+  auto* stream_index = mod(j, number_of_streams);
+  auto* stream = IrBuilder::create<hir::Stream>(stream_index);
+  auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);
+
+  TensorView* tva_j = select(tva, 0, j);
+  TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
+  TensorView* tvc_j = select(tvc, 0, j);
+
+  NVF_ERROR(
+      tva->hasDeviceMesh(),
+      "The matmul's input ",
+      tva,
+      " is expected to have a DeviceMesh");
+  for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) {
+    tv->setDeviceMesh(tva->getDeviceMesh());
+  }
+
+  auto* communication = IrBuilder::create<Communication>(
+      CommunicationType::Allgather,
+      /*out=*/tva_allgathered_j,
+      /*in=*/tva_j,
+      /*team=*/tva->getDeviceMesh().vector());
+  auto* wait = IrBuilder::create<hir::Wait>(communication);
+
+  auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);
+
+  auto* set_back_original_stream =
+      IrBuilder::create<hir::SetCurrentStream>(original_stream);
+  auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);
+
+  std::vector<Expr*> loop_body = {
+      set_stream,
+      tva_j->definition(),
+      tva_allgathered_j->definition(),
+      communication,
+      wait,
+      tvc_j->definition(),
+      mm,
+      set_back_original_stream,
+      sync_stream};
+  for (Expr* expr : loop_body) {
+    for_loop->body().push_back(expr);
+  }
+
+  return {
+      get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
 }
 
 std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
@@ -396,21 +518,19 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
           "Communication segments must contain only one Expr");
       for (auto* expr :
            HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
+        hic->pushBackTopLevelExprs(expr);
         // Allocate the recv buffers of communications
-        NVF_ERROR(
-            expr->isA<Communication>(),
-            "Expected a Communication but got ",
-            expr);
-        auto* communication = expr->as<Communication>();
-        TensorView* tv = communication->out();
-        if (tv->getDeviceMesh().has(my_device_index)) {
-          auto* allocate =
-              IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
-          hic->pushBackTopLevelExprs(allocate);
+        if (expr->isA<Communication>()) {
+          auto* communication = expr->as<Communication>();
+          TensorView* tv = communication->out();
+          if (tv->getDeviceMesh().has(my_device_index)) {
+            auto* allocate =
+                IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
+            hic->pushBackTopLevelExprs(allocate);
+          }
+          auto wait = IrBuilder::create<hir::Wait>(communication);
+          hic->pushBackTopLevelExprs(wait);
         }
-        hic->pushBackTopLevelExprs(communication);
-        auto wait = IrBuilder::create<hir::Wait>(communication);
-        hic->pushBackTopLevelExprs(wait);
       }
     } else {
       auto host_unit = IrBuilder::create<hir::HostUnit>(
diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h
index 6a1d44247d2..c99e8a8e76b 100644
--- a/csrc/host_ir/lower.h
+++ b/csrc/host_ir/lower.h
@@ -24,6 +24,9 @@ class HostIrLower {
   static std::unique_ptr<hir::HostIrContainer> lower(
       std::unique_ptr<Fusion> fusion,
       int64_t my_device_index);
+
+ private:
+  static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
 };
 
 } // namespace nvfuser
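For readers unfamiliar with the host IR above, the following is a minimal eager-mode sketch (not nvFuser code) of the schedule that lowerToCollectiveBasedPipelinedGemmComm emits: for each slice j, switch to stream j % numberOfStreams, allgather A[j], matmul the gathered slice with B into C[j], then make the original stream wait on that worker stream. allgather_into is a hypothetical helper standing in for the Communication expr; everything else uses the ATen/c10 CUDA APIs.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <vector>

// Hypothetical stand-in for the Communication expr: posts an allgather of
// `src` (one shard per rank) into `dst` on the current CUDA stream.
void allgather_into(at::Tensor dst, const at::Tensor& src);

at::Tensor pipelined_ag_matmul(
    const at::Tensor& a, // [S, 1, M/(S*D), K], local shard of axis 1
    const at::Tensor& b, // [K, N], replicated on every rank
    int64_t d, // number of devices
    int64_t num_streams) {
  const int64_t s = a.size(0);
  auto a_allgathered = at::empty({s, d, a.size(2), a.size(3)}, a.options());
  auto c = at::empty({s, d, a.size(2), b.size(1)}, a.options());

  auto main_stream = c10::cuda::getCurrentCUDAStream();
  std::vector<c10::cuda::CUDAStream> streams;
  for (int64_t i = 0; i < num_streams; ++i) {
    streams.push_back(c10::cuda::getStreamFromPool());
  }

  for (int64_t j = 0; j < s; ++j) {
    // SetCurrentStream: work for slice j goes to stream j % num_streams.
    c10::cuda::CUDAStreamGuard guard(streams[j % num_streams]);
    allgather_into(a_allgathered[j], a[j]); // communication for slice j
    auto c_j = c[j];
    at::matmul_out(c_j, a_allgathered[j], b); // compute for slice j
    // Synchronize: the original stream waits on this slice's stream,
    // without blocking the host.
    at::cuda::CUDAEvent done;
    done.record(streams[j % num_streams]);
    done.block(main_stream);
  }
  return c;
}

Because consecutive slices are placed on different streams, the allgather for slice j+1 can run concurrently with the matmul for slice j, which is the communication/compute overlap this lowering targets.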
diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp
index f6359cb424e..bf68f6aa9c5 100644
--- a/csrc/preseg_passes/reorder_sharded_axis.cpp
+++ b/csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -25,7 +25,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
   const std::vector<Expr*>& exprs = fusion->exprs();
   for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
     Expr* expr = *it;
-    if (!isResharding(expr)) {
+    if (HostIrLower::canLower(expr)) {
       continue;
     }
     NVF_ERROR(
diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp
index 27e9477aede..ceeef81fa1f 100644
--- a/tests/cpp/test_multidevice_host_ir.cpp
+++ b/tests/cpp/test_multidevice_host_ir.cpp
@@ -5,6 +5,7 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
+#include <cuda_profiler_api.h>
 #include
 #include
 #include
@@ -349,6 +350,60 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) {
   EXPECT_TRUE(torch::allclose(ref_output, outputs.back()));
 }
 
+using OverlapDistributedMatmulTest = MultiDeviceTest;
+
+TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
+  constexpr int64_t M = 32768;
+  constexpr int64_t K = 32768;
+  constexpr int64_t N = 1024;
+  constexpr int64_t S = 8;
+  const int64_t D = communicator_->size();
+  ASSERT_EQ(M % (D * S), 0);
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  TensorView* a = makeContigTensor(4); // [S, DIDx(D), M/(S*D), K]
+  TensorView* b = makeContigTensor(2); // [K, N]
+  TensorView* c = matmul(a, b); // [S, D, M/(S*D), N]
+
+  fusion->addInput(a);
+  fusion->addInput(b);
+  fusion->addOutput(c);
+
+  auto mesh = DeviceMesh::createForNumDevices(D);
+  a->setDeviceMesh(mesh);
+  b->setDeviceMesh(mesh);
+  c->setDeviceMesh(mesh);
+
+  a->axis(1)->parallelize(ParallelType::DIDx);
+
+  MultiDeviceExecutor executor(std::move(fusion), *communicator_);
+
+  auto tensor_options =
+      at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
+  at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options);
+  at::Tensor ta = ta_unsharded.slice(
+      1, communicator_->deviceId(), communicator_->deviceId() + 1);
+  at::Tensor tb = at::randn({K, N}, tensor_options);
+  at::Tensor tc_ref = at::matmul(ta_unsharded, tb);
+
+  std::vector<c10::IValue> inputs = {ta, tb};
+  at::Tensor tc;
+
+  constexpr int64_t number_of_iterations = 20;
+  constexpr int64_t number_of_warmup_iterations = 5;
+  for (const auto& i : c10::irange(number_of_iterations)) {
+    if (i == number_of_warmup_iterations) {
+      cudaProfilerStart();
+    }
+    tc = executor.runWithInput(inputs).at(0);
+  }
+  cudaProfilerStop();
+
+  EXPECT_TRUE(torch::allclose(tc_ref, tc));
+}
+
 } // namespace hir
 
 } // namespace nvfuser
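A quick way to see why the per-slice decomposition used by the lowering and exercised by AG_matmul is numerically equivalent to the unsharded matmul: at::matmul broadcasts B over the two leading axes of A, so computing C one S-slice at a time reproduces the full result. A small self-contained ATen check (CPU, independent of the PR), with toy sizes chosen only for illustration:

#include <ATen/ATen.h>

int main() {
  constexpr int64_t S = 2, D = 4, M = 32, K = 16, N = 8;
  auto a = at::randn({S, D, M / (S * D), K});
  auto b = at::randn({K, N});
  auto c_ref = at::matmul(a, b); // broadcasts to [S, D, M/(S*D), N]
  auto c = at::empty_like(c_ref);
  for (int64_t j = 0; j < S; ++j) {
    // One "stream iteration": the matmul of a single S-slice.
    c.select(0, j).copy_(at::matmul(a.select(0, j), b));
  }
  TORCH_CHECK(at::allclose(c, c_ref));
  return 0;
}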