diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 5a3c06e6891..8c8e514419e 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -200,6 +200,7 @@ HostIrEvaluator::HostIrEvaluator( {container_->getDefaultStream(), c10::cuda::getDefaultCUDAStream( static_cast(device_index))}); + expr_evaluator_.bind("numberOfStreams", params_.number_of_streams); } std::vector HostIrEvaluator::runWithInput( diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 29d9eb5dd8f..e47624ffa6d 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -74,6 +74,7 @@ struct HostIrEvaluatorParams { // Experimental: whether to cache fusion executor. WAR: avoid recompilation // but implicitely assumes that the input shape don't change over iterations bool cache_fusion_executor = false; + int64_t number_of_streams = 4; }; class HostIrEvaluator final : public OptOutDispatch { diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 8e97b958a9a..46addba4b17 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -235,6 +236,10 @@ void lowerToReduceScatter( std::vector HostIrLower::lower(Expr* c) { FusionGuard fg(c->fusion()); + if (c->isA()) { + return lowerToCollectiveBasedPipelinedGemmComm(c); + } + std::vector comms; NVF_ERROR( c->inputs().size() == 1 && c->input(0)->isA() && @@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) { return false; } if (expr->isA()) { + if (!isInnerResharding(expr)) { + return false; + } auto in = expr->as()->in()->as(); auto out = expr->as()->out()->as(); // get the reduced axis @@ -328,10 +336,147 @@ bool HostIrLower::canLower(Expr* expr) { PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); - } else { - return expr->isA() && - (expr->as()->opType() == LoadStoreOpType::Set); + } else if (expr->isA()) { + return isInnerResharding(expr) && + expr->as()->opType() == LoadStoreOpType::Set; + } else if (expr->as()) { + // For now we only support c = matmul(a,b) when b,c are fully replicated and + // a is sharded on axis 1 + auto* matmul = expr->as(); + return !isSharded(matmul->inB()) && !isSharded(matmul->out()) && + matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial && + getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1; } + return false; +} + +std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( + Expr* expr) { + auto matmul = expr->as(); + NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr); + TensorView* tva = matmul->inA(); + TensorView* tvb = matmul->inB(); + TensorView* tvc = matmul->out(); + NVF_ERROR( + !isSharded(tvb), "The B operand ", tvb, " is expected to be sharded"); + NVF_ERROR( + !isSharded(tvc), + "The output ", + matmul->out(), + " is expected to be sharded"); + const int64_t sharded_axis_index = + getShardedLogicalAxis(tva, ParallelType::DIDx); + IterDomain* stream_axis = tva->axis(0); + NVF_ERROR( + stream_axis->getParallelType() == ParallelType::Serial && + sharded_axis_index == 1, + "The operand A ", + tva, + " is expected to be sharded on the dimension 1"); + + auto hic = FusionGuard::getCurFusion()->as(); + + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + + TensorView* tva_allgathered = + ops::newValLike(tva, tva->dtype())->as(); + tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial); + tva_allgathered->setMemoryType(MemoryType::Global); + auto* allocate_tva_allgathered = + IrBuilder::create(tva_allgathered, MemoryType::Global); + + tvc->setMemoryType(MemoryType::Global); + auto* allocate_tvc = + IrBuilder::create(tvc, MemoryType::Global); + + auto* j = + IrBuilder::create(DataType::Index); // running index of the for-loop + auto* start = hic->zeroVal(); + auto* stop = stream_axis->extent(); + auto* step = hic->oneVal(); + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/j, + start, + stop, + step, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(j, number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + + TensorView* tva_j = select(tva, 0, j); + TensorView* tva_j_unsqueezed = tva_j; // unsqueeze(tva_j, 0); + TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); + TensorView* tvc_j = select(tvc, 0, j); + + // [TAG: adding articifial outputs] + // The following line is artificial but necessary to make tva_j_unsqueeze a + // consumer of tva_j. + // + // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all + // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a + // bit special among all transitive consumers of `j`. It doesn't use `j` + // directly but uses `tva_j` which is a TensorView. TensorView's uses are + // built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses + // only fix TensorViews that can reach outputs. Therefore, we add + // tva_j_unsqueezed as an output. Other TensorViews don't need this + // treatmenet because they are direct users of `j`, a scalar whose uses are + // built eagerly upon registration. + // + // We could have added `tvc_j` instead as an output, which transitively + // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a Select + // and a MatmulOp, and StmtSort::getExprs only traverse via the first + // registered definition (i.e. the Select). This sounds like a bug -- I wonder + // how nvFuser resets the TensorView uses of a kir::Kernel, also non-SSA. + hic->addOutput(tva_j_unsqueezed); + + NVF_ERROR( + tva->hasDeviceMesh(), + "The matmul's input ", + tva, + "is expected to have a DeviceMesh"); + for (auto tv : {tva_j, tva_allgathered_j, tva_j_unsqueezed, tvc_j}) { + tv->setDeviceMesh(tva->getDeviceMesh()); + } + + auto* communication = IrBuilder::create( + CommunicationType::Allgather, + /*out=*/tva_allgathered_j, + /*in=*/tva_j_unsqueezed, + /*team=*/tva->getDeviceMesh().vector()); + auto* wait = IrBuilder::create(communication); + + auto* mm = IrBuilder::create(tvc_j, tva_allgathered_j, tvb); + + auto* set_back_original_stream = + IrBuilder::create(original_stream); + auto* sync_stream = IrBuilder::create(stream); + + std::vector loop_body = { + set_stream, + tva_j->definition(), + tva_j_unsqueezed->definition(), + tva_allgathered_j->definition(), + communication, + wait, + tvc_j->definition(), + mm, + set_back_original_stream, + sync_stream}; + for (Expr* expr : loop_body) { + for_loop->body().push_back(expr); + } + + return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop}; } std::unique_ptr HostIrLower::lower( @@ -396,21 +541,19 @@ std::unique_ptr HostIrLower::lower( "Communication segments must contain only one Expr"); for (auto* expr : HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) { + hic->pushBackTopLevelExprs(expr); // Allocate the recv buffers of communications - NVF_ERROR( - expr->isA(), - "Expected a Communication but got ", - expr); - auto* communication = expr->as(); - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(my_device_index)) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - hic->pushBackTopLevelExprs(allocate); + if (expr->isA()) { + auto* communication = expr->as(); + TensorView* tv = communication->out(); + if (tv->getDeviceMesh().has(my_device_index)) { + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + hic->pushBackTopLevelExprs(allocate); + } + auto wait = IrBuilder::create(communication); + hic->pushBackTopLevelExprs(wait); } - hic->pushBackTopLevelExprs(communication); - auto wait = IrBuilder::create(communication); - hic->pushBackTopLevelExprs(wait); } } else { auto host_unit = IrBuilder::create( diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index 6a1d44247d2..c99e8a8e76b 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -24,6 +24,9 @@ class HostIrLower { static std::unique_ptr lower( std::unique_ptr fusion, int64_t my_device_index); + + private: + static std::vector lowerToCollectiveBasedPipelinedGemmComm(Expr* expr); }; } // namespace nvfuser diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 963b80812d3..7c6bdc5d3fe 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -24,7 +24,7 @@ MultiDeviceExecutor::MultiDeviceExecutor( std::unique_ptr fusion, Communicator& comm, hir::HostIrEvaluatorParams params) - : comm_(comm) { + : comm_(comm), number_of_outputs_(fusion->outputs().size()) { std::unique_ptr hic = HostIrLower::lower(std::move(fusion), comm.deviceId()); // Create the HostIrEvaluator representing the host program @@ -52,7 +52,9 @@ std::vector MultiDeviceExecutor::runWithInput( inputs.at(input_idx); } - return host_ir_executor_->runWithInput(val_to_IValue); + auto outputs = host_ir_executor_->runWithInput(val_to_IValue); + return std::vector( + outputs.end() - number_of_outputs_, outputs.end()); } std::ostream& MultiDeviceExecutor::print(std::ostream& os) { diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index 7cad0388b18..bc9d5d974b3 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -102,6 +102,12 @@ class MultiDeviceExecutor { Communicator& comm_; // holds the HostIrEvaluator used for execution std::unique_ptr host_ir_executor_; + // Store the number of outputs before it possibly gets artificially modified + // by HostIr::lower. This is undesirable but required for now. For more + // details, search for the comment in host_ir/lower.cpp tagged with "[TAG: + // adding articifial outputs]" + // TODO: fix + int64_t number_of_outputs_; }; } // namespace nvfuser diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index bc541623310..6dafe4aa819 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -947,8 +948,12 @@ TensorView* broadcast( .iter_type(IterType::Broadcast) .build()); } else { - out_domain.push_back( - IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); + auto inp_id = inp_domain[iinp]; + auto out_id = IterDomainBuilder(inp_id).resetSchedulingParams().build(); + if (inp_id->isDeviceDim()) { + out_id->parallelize(inp_id->getParallelType()); + } + out_domain.push_back(out_id); iinp++; } ibdim++; diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index f6359cb424e..bf68f6aa9c5 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -25,7 +25,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) { Expr* expr = *it; - if (!isResharding(expr)) { + if (HostIrLower::canLower(expr)) { continue; } NVF_ERROR( diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 27e9477aede..f8683effb55 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -349,6 +349,50 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { EXPECT_TRUE(torch::allclose(ref_output, outputs.back())); } +using OverlapDistributedMatmulTest = MultiDeviceTest; + +TEST_F(OverlapDistributedMatmulTest, AG_matmul) { + constexpr int64_t M = 1024; + constexpr int64_t K = 256; + constexpr int64_t N = 512; + constexpr int64_t S = 8; + const int64_t D = communicator_->size(); + ASSERT_EQ(M % (D * S), 0); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*d), K] + TensorView* b = makeContigTensor(2); //[K, N] + TensorView* c = matmul(a, b); //[S, D, M/(S*D), N] + + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + + auto mesh = DeviceMesh::createForNumDevices(D); + a->setDeviceMesh(mesh); + b->setDeviceMesh(mesh); + c->setDeviceMesh(mesh); + + a->axis(1)->parallelize(ParallelType::DIDx); + + MultiDeviceExecutor executor(std::move(fusion), *communicator_); + + auto tensor_options = + at::TensorOptions().dtype(at::kFloat).device(communicator_->device()); + at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options); + at::Tensor ta = ta_unsharded.slice( + 1, communicator_->deviceId(), communicator_->deviceId() + 1); + at::Tensor tb = at::randn({K, N}, tensor_options); + at::Tensor tc_ref = at::matmul(ta_unsharded, tb); + + std::vector inputs = {ta, tb}; + auto tc = executor.runWithInput(inputs).at(0); + + EXPECT_TRUE(torch::allclose(tc_ref, tc)); +} + } // namespace hir } // namespace nvfuser