From 38721febab76f48e65558e4410653a19c49c397a Mon Sep 17 00:00:00 2001
From: snordmann
Date: Tue, 17 Dec 2024 23:29:15 -0800
Subject: [PATCH 01/40] Host IR: add GetCurrentStream

---
 csrc/dispatch.h             |  1 +
 csrc/host_ir/executor.cpp   |  8 ++++++++
 csrc/host_ir/executor.h     |  1 +
 csrc/host_ir/host_ir.cpp    | 26 ++++++++++++++++++++++++++
 csrc/host_ir/host_ir.h      | 25 +++++++++++++++++++++++++
 tests/cpp/test_host_irs.cpp | 20 ++++++++++++++++++++
 6 files changed, 81 insertions(+)

diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 4fe0f86cc5f..77b650b88dc 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -146,6 +146,7 @@ class Val;
   f(HostUnit); \
   f(PostOnStream); \
   f(SetCurrentStream); \
+  f(GetCurrentStream); \
   f(Wait); \
   f(Synchronize); \
   f(StartCoalescing); \
diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 8b8ce484946..2c283cb9610 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -273,6 +273,14 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
   setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
 }
 
+void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
+  c10::DeviceIndex my_device_index =
+      communicator_ ? communicator_->deviceId() : 0;
+  streams_.insert(
+      {get_current_stream->stream(),
+       c10::cuda::getCurrentCUDAStream(my_device_index)});
+}
+
 void HostIrEvaluator::handle(Synchronize* synchronize) {
   getCUDAStream(synchronize->stream()).synchronize();
 }
diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h
index 7e3932c6b1a..29d9eb5dd8f 100644
--- a/csrc/host_ir/executor.h
+++ b/csrc/host_ir/executor.h
@@ -112,6 +112,7 @@ class HostIrEvaluator final : public OptOutDispatch {
  private:
   using OptOutDispatch::handle;
   void handle(SetCurrentStream* set_current_stream) override;
+  void handle(GetCurrentStream* get_current_stream) override;
   void handle(Synchronize* synchronize) override;
   void handle(PostOnStream* post_ir) override;
   void handle(Communication* communication) override;
diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp
index 492b2b22aab..24ceda7daba 100644
--- a/csrc/host_ir/host_ir.cpp
+++ b/csrc/host_ir/host_ir.cpp
@@ -179,6 +179,32 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
   return false;
 }
 
+GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
+  NVF_ERROR(passkey.ir_container_ != nullptr);
+  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
+  auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
+  addAttribute(stream);
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)
+
+std::string GetCurrentStream::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
+                          << std::endl;
+  return ss.str();
+}
+
+// TODO: implement better?
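For reference, a minimal standalone sketch (not part of the patch; it uses only real c10 stream calls) of the capture-and-restore pattern these two nodes give a host program: GetCurrentStream captures the ambient stream into a Stream handle, and a later SetCurrentStream on that handle restores it.

#include <c10/cuda/CUDAStream.h>

void captureAndRestoreSketch() {
  constexpr c10::DeviceIndex kDeviceIndex = 0;

  // GetCurrentStream: remember whatever stream the caller was using.
  c10::cuda::CUDAStream original =
      c10::cuda::getCurrentCUDAStream(kDeviceIndex);

  // SetCurrentStream(other_stream): later kernels go to a pool stream.
  c10::cuda::setCurrentCUDAStream(
      c10::cuda::getStreamFromPool(/*isHighPriority=*/false, kDeviceIndex));

  // ... enqueue kernels here; they run on the pool stream ...

  // SetCurrentStream(current_stream): restore the captured stream.
  c10::cuda::setCurrentCUDAStream(original);
}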
+std::string GetCurrentStream::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "Cannot be printed inline");
+}
+
+// TODO: implement
+bool GetCurrentStream::sameAs(const Statement* other) const {
+  return false;
+}
+
 Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
     : Expr(passkey, {}, {}, {expr}) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h
index 587ffc43638..bed7d5893a7 100644
--- a/csrc/host_ir/host_ir.h
+++ b/csrc/host_ir/host_ir.h
@@ -161,6 +161,31 @@ class SetCurrentStream : public Expr {
   }
 };
 
+class GetCurrentStream : public Expr {
+ public:
+  using Expr::Expr;
+  GetCurrentStream(IrBuilderPasskey passkey);
+
+  GetCurrentStream(const GetCurrentStream& other) = delete;
+  GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
+  GetCurrentStream(GetCurrentStream&& other) = delete;
+  GetCurrentStream& operator=(GetCurrentStream&& other) = delete;
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+  const char* getOpString() const override {
+    return "hir::GetCurrentStream";
+  }
+
+  bool sameAs(const Statement* other) const override;
+
+  Stream* stream() const {
+    return attributes_.at(0)->as<Stream>();
+  }
+};
+
 class Wait : public Expr {
  public:
   using Expr::Expr;
diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp
index 64aa2a0564b..e97550309e1 100644
--- a/tests/cpp/test_host_irs.cpp
+++ b/tests/cpp/test_host_irs.cpp
@@ -513,6 +513,26 @@ TEST_F(StreamTest, HostIrDefaultStream) {
       c10::cuda::getDefaultCUDAStream(0), c10::cuda::getCurrentCUDAStream(0));
 }
 
+TEST_F(StreamTest, HostIrGetCurrentStream) {
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+  auto get_stream = IrBuilder::create<GetCurrentStream>();
+  auto current_stream = get_stream->stream();
+  auto other_stream = IrBuilder::create<Stream>();
+  hic->pushBackTopLevelExprs(get_stream);
+  hic->pushBackTopLevelExprs(
+      IrBuilder::create<SetCurrentStream>(other_stream));
+  hic->pushBackTopLevelExprs(
+      IrBuilder::create<SetCurrentStream>(current_stream));
+
+  auto cuda_stream = c10::cuda::getStreamFromPool();
+  setCurrentCUDAStream(cuda_stream);
+
+  HostIrEvaluator hie(std::move(hic));
+  hie.runWithInput({});
+
+  EXPECT_EQ(cuda_stream, c10::cuda::getCurrentCUDAStream(0));
+}
+
 TEST_F(StreamTest, ByIndex) {
   constexpr int64_t kStreamIndex1 = 2;
   constexpr int64_t kStreamIndex2 = 3;
From c4ca266bd2e3c35c48a99897a3e30f917294fcce Mon Sep 17 00:00:00 2001
From: snordmann
Date: Wed, 18 Dec 2024 00:46:25 -0800
Subject: [PATCH 02/40] lint

---
 csrc/host_ir/executor.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 2c283cb9610..5a3c06e6891 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -274,8 +274,9 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
 }
 
 void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
-  c10::DeviceIndex my_device_index =
-      communicator_ ? communicator_->deviceId() : 0;
+  auto my_device_index = communicator_
+      ?
static_cast(communicator_->deviceId()) + : 0; streams_.insert( {get_current_stream->stream(), c10::cuda::getCurrentCUDAStream(my_device_index)}); From b517c2b577c7b873a97e26b7545e5cf5842b0c8c Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 17 Dec 2024 23:42:29 -0800 Subject: [PATCH 03/40] lower to collective base pipeline AG+GEMM --- csrc/host_ir/executor.cpp | 1 + csrc/host_ir/executor.h | 1 + csrc/host_ir/lower.cpp | 175 ++++++++++++++++++-- csrc/host_ir/lower.h | 3 + csrc/multidevice/executor.cpp | 6 +- csrc/multidevice/executor.h | 6 + csrc/ops/alias.cpp | 9 +- csrc/preseg_passes/reorder_sharded_axis.cpp | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 44 +++++ 9 files changed, 226 insertions(+), 21 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 5a3c06e6891..8c8e514419e 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -200,6 +200,7 @@ HostIrEvaluator::HostIrEvaluator( {container_->getDefaultStream(), c10::cuda::getDefaultCUDAStream( static_cast(device_index))}); + expr_evaluator_.bind("numberOfStreams", params_.number_of_streams); } std::vector HostIrEvaluator::runWithInput( diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 29d9eb5dd8f..e47624ffa6d 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -74,6 +74,7 @@ struct HostIrEvaluatorParams { // Experimental: whether to cache fusion executor. WAR: avoid recompilation // but implicitely assumes that the input shape don't change over iterations bool cache_fusion_executor = false; + int64_t number_of_streams = 4; }; class HostIrEvaluator final : public OptOutDispatch { diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 8e97b958a9a..46addba4b17 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -235,6 +236,10 @@ void lowerToReduceScatter( std::vector HostIrLower::lower(Expr* c) { FusionGuard fg(c->fusion()); + if (c->isA()) { + return lowerToCollectiveBasedPipelinedGemmComm(c); + } + std::vector comms; NVF_ERROR( c->inputs().size() == 1 && c->input(0)->isA() && @@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) { return false; } if (expr->isA()) { + if (!isInnerResharding(expr)) { + return false; + } auto in = expr->as()->in()->as(); auto out = expr->as()->out()->as(); // get the reduced axis @@ -328,10 +336,147 @@ bool HostIrLower::canLower(Expr* expr) { PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); - } else { - return expr->isA() && - (expr->as()->opType() == LoadStoreOpType::Set); + } else if (expr->isA()) { + return isInnerResharding(expr) && + expr->as()->opType() == LoadStoreOpType::Set; + } else if (expr->as()) { + // For now we only support c = matmul(a,b) when b,c are fully replicated and + // a is sharded on axis 1 + auto* matmul = expr->as(); + return !isSharded(matmul->inB()) && !isSharded(matmul->out()) && + matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial && + getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1; } + return false; +} + +std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( + Expr* expr) { + auto matmul = expr->as(); + NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr); + TensorView* tva = matmul->inA(); + TensorView* tvb = matmul->inB(); + TensorView* tvc = matmul->out(); + NVF_ERROR( + 
!isSharded(tvb), "The B operand ", tvb, " is expected to be sharded"); + NVF_ERROR( + !isSharded(tvc), + "The output ", + matmul->out(), + " is expected to be sharded"); + const int64_t sharded_axis_index = + getShardedLogicalAxis(tva, ParallelType::DIDx); + IterDomain* stream_axis = tva->axis(0); + NVF_ERROR( + stream_axis->getParallelType() == ParallelType::Serial && + sharded_axis_index == 1, + "The operand A ", + tva, + " is expected to be sharded on the dimension 1"); + + auto hic = FusionGuard::getCurFusion()->as(); + + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + + TensorView* tva_allgathered = + ops::newValLike(tva, tva->dtype())->as(); + tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial); + tva_allgathered->setMemoryType(MemoryType::Global); + auto* allocate_tva_allgathered = + IrBuilder::create(tva_allgathered, MemoryType::Global); + + tvc->setMemoryType(MemoryType::Global); + auto* allocate_tvc = + IrBuilder::create(tvc, MemoryType::Global); + + auto* j = + IrBuilder::create(DataType::Index); // running index of the for-loop + auto* start = hic->zeroVal(); + auto* stop = stream_axis->extent(); + auto* step = hic->oneVal(); + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/j, + start, + stop, + step, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(j, number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + + TensorView* tva_j = select(tva, 0, j); + TensorView* tva_j_unsqueezed = tva_j; // unsqueeze(tva_j, 0); + TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); + TensorView* tvc_j = select(tvc, 0, j); + + // [TAG: adding articifial outputs] + // The following line is artificial but necessary to make tva_j_unsqueeze a + // consumer of tva_j. + // + // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all + // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a + // bit special among all transitive consumers of `j`. It doesn't use `j` + // directly but uses `tva_j` which is a TensorView. TensorView's uses are + // built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses + // only fix TensorViews that can reach outputs. Therefore, we add + // tva_j_unsqueezed as an output. Other TensorViews don't need this + // treatmenet because they are direct users of `j`, a scalar whose uses are + // built eagerly upon registration. + // + // We could have added `tvc_j` instead as an output, which transitively + // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a Select + // and a MatmulOp, and StmtSort::getExprs only traverse via the first + // registered definition (i.e. the Select). This sounds like a bug -- I wonder + // how nvFuser resets the TensorView uses of a kir::Kernel, also non-SSA. 
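To make the schedule this function emits easier to follow, here is a standalone ATen sketch of the same pipeline. It is a sketch under stated assumptions, not the lowered IR itself: allgather_chunk() is a hypothetical stand-in for the Communication + Wait pair, and the small stream pool plays the role of j % numberOfStreams.

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

#include <functional>
#include <vector>

// For every index j of the stream axis: switch to stream j % num_streams,
// allgather A's j-th slice, then matmul it with the replicated B into C's
// j-th slice, so one chunk's communication overlaps other chunks' compute.
void pipelinedAgGemmSketch(
    const at::Tensor& a_local, // [S, 1, M/(S*D), K], this rank's shard of A
    const at::Tensor& b, // [K, N], replicated
    at::Tensor& c, // [S, D, M/(S*D), N], preallocated output
    int64_t num_streams,
    const std::function<void(const at::Tensor&, at::Tensor&)>&
        allgather_chunk) {
  const int64_t s = a_local.size(0);
  at::Tensor a_gathered = at::empty(
      {s, c.size(1), a_local.size(2), a_local.size(3)}, a_local.options());

  std::vector<c10::cuda::CUDAStream> streams;
  for (int64_t i = 0; i < num_streams; ++i) {
    streams.push_back(c10::cuda::getStreamFromPool());
  }

  c10::cuda::CUDAStream original = c10::cuda::getCurrentCUDAStream();
  for (int64_t j = 0; j < s; ++j) {
    c10::cuda::setCurrentCUDAStream(streams[j % num_streams]);
    at::Tensor a_gathered_j = a_gathered.select(0, j);
    allgather_chunk(a_local.select(0, j), a_gathered_j);
    at::Tensor c_j = c.select(0, j);
    at::matmul_out(c_j, a_gathered_j, b); // [D, M/(S*D), N] per chunk
  }
  c10::cuda::setCurrentCUDAStream(original);
}

The real lowering additionally captures the ambient stream with GetCurrentStream, switches back to it inside the loop, and emits a Synchronize per chunk stream so the surrounding program only resumes once every chunk has drained.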
+ hic->addOutput(tva_j_unsqueezed); + + NVF_ERROR( + tva->hasDeviceMesh(), + "The matmul's input ", + tva, + "is expected to have a DeviceMesh"); + for (auto tv : {tva_j, tva_allgathered_j, tva_j_unsqueezed, tvc_j}) { + tv->setDeviceMesh(tva->getDeviceMesh()); + } + + auto* communication = IrBuilder::create( + CommunicationType::Allgather, + /*out=*/tva_allgathered_j, + /*in=*/tva_j_unsqueezed, + /*team=*/tva->getDeviceMesh().vector()); + auto* wait = IrBuilder::create(communication); + + auto* mm = IrBuilder::create(tvc_j, tva_allgathered_j, tvb); + + auto* set_back_original_stream = + IrBuilder::create(original_stream); + auto* sync_stream = IrBuilder::create(stream); + + std::vector loop_body = { + set_stream, + tva_j->definition(), + tva_j_unsqueezed->definition(), + tva_allgathered_j->definition(), + communication, + wait, + tvc_j->definition(), + mm, + set_back_original_stream, + sync_stream}; + for (Expr* expr : loop_body) { + for_loop->body().push_back(expr); + } + + return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop}; } std::unique_ptr HostIrLower::lower( @@ -396,21 +541,19 @@ std::unique_ptr HostIrLower::lower( "Communication segments must contain only one Expr"); for (auto* expr : HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) { + hic->pushBackTopLevelExprs(expr); // Allocate the recv buffers of communications - NVF_ERROR( - expr->isA(), - "Expected a Communication but got ", - expr); - auto* communication = expr->as(); - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(my_device_index)) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - hic->pushBackTopLevelExprs(allocate); + if (expr->isA()) { + auto* communication = expr->as(); + TensorView* tv = communication->out(); + if (tv->getDeviceMesh().has(my_device_index)) { + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + hic->pushBackTopLevelExprs(allocate); + } + auto wait = IrBuilder::create(communication); + hic->pushBackTopLevelExprs(wait); } - hic->pushBackTopLevelExprs(communication); - auto wait = IrBuilder::create(communication); - hic->pushBackTopLevelExprs(wait); } } else { auto host_unit = IrBuilder::create( diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index 6a1d44247d2..c99e8a8e76b 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -24,6 +24,9 @@ class HostIrLower { static std::unique_ptr lower( std::unique_ptr fusion, int64_t my_device_index); + + private: + static std::vector lowerToCollectiveBasedPipelinedGemmComm(Expr* expr); }; } // namespace nvfuser diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 963b80812d3..7c6bdc5d3fe 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -24,7 +24,7 @@ MultiDeviceExecutor::MultiDeviceExecutor( std::unique_ptr fusion, Communicator& comm, hir::HostIrEvaluatorParams params) - : comm_(comm) { + : comm_(comm), number_of_outputs_(fusion->outputs().size()) { std::unique_ptr hic = HostIrLower::lower(std::move(fusion), comm.deviceId()); // Create the HostIrEvaluator representing the host program @@ -52,7 +52,9 @@ std::vector MultiDeviceExecutor::runWithInput( inputs.at(input_idx); } - return host_ir_executor_->runWithInput(val_to_IValue); + auto outputs = host_ir_executor_->runWithInput(val_to_IValue); + return std::vector( + outputs.end() - number_of_outputs_, outputs.end()); } std::ostream& MultiDeviceExecutor::print(std::ostream& os) { diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h 
index 7cad0388b18..bc9d5d974b3 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -102,6 +102,12 @@ class MultiDeviceExecutor { Communicator& comm_; // holds the HostIrEvaluator used for execution std::unique_ptr host_ir_executor_; + // Store the number of outputs before it possibly gets artificially modified + // by HostIr::lower. This is undesirable but required for now. For more + // details, search for the comment in host_ir/lower.cpp tagged with "[TAG: + // adding articifial outputs]" + // TODO: fix + int64_t number_of_outputs_; }; } // namespace nvfuser diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index bc541623310..6dafe4aa819 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -947,8 +948,12 @@ TensorView* broadcast( .iter_type(IterType::Broadcast) .build()); } else { - out_domain.push_back( - IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); + auto inp_id = inp_domain[iinp]; + auto out_id = IterDomainBuilder(inp_id).resetSchedulingParams().build(); + if (inp_id->isDeviceDim()) { + out_id->parallelize(inp_id->getParallelType()); + } + out_domain.push_back(out_id); iinp++; } ibdim++; diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index f6359cb424e..bf68f6aa9c5 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -25,7 +25,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) { Expr* expr = *it; - if (!isResharding(expr)) { + if (HostIrLower::canLower(expr)) { continue; } NVF_ERROR( diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 27e9477aede..f8683effb55 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -349,6 +349,50 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { EXPECT_TRUE(torch::allclose(ref_output, outputs.back())); } +using OverlapDistributedMatmulTest = MultiDeviceTest; + +TEST_F(OverlapDistributedMatmulTest, AG_matmul) { + constexpr int64_t M = 1024; + constexpr int64_t K = 256; + constexpr int64_t N = 512; + constexpr int64_t S = 8; + const int64_t D = communicator_->size(); + ASSERT_EQ(M % (D * S), 0); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*d), K] + TensorView* b = makeContigTensor(2); //[K, N] + TensorView* c = matmul(a, b); //[S, D, M/(S*D), N] + + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + + auto mesh = DeviceMesh::createForNumDevices(D); + a->setDeviceMesh(mesh); + b->setDeviceMesh(mesh); + c->setDeviceMesh(mesh); + + a->axis(1)->parallelize(ParallelType::DIDx); + + MultiDeviceExecutor executor(std::move(fusion), *communicator_); + + auto tensor_options = + at::TensorOptions().dtype(at::kFloat).device(communicator_->device()); + at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options); + at::Tensor ta = ta_unsharded.slice( + 1, communicator_->deviceId(), communicator_->deviceId() + 1); + at::Tensor tb = at::randn({K, N}, tensor_options); + at::Tensor tc_ref = at::matmul(ta_unsharded, tb); + + std::vector inputs = {ta, tb}; + auto tc = executor.runWithInput(inputs).at(0); + + EXPECT_TRUE(torch::allclose(tc_ref, tc)); +} + } // namespace hir } // namespace 
nvfuser From 92ab927e5f8dd85d21b8aa2feb279eef4bf5af94 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 00:48:23 -0800 Subject: [PATCH 04/40] lint --- csrc/multidevice/executor.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 7c6bdc5d3fe..82957a54f7f 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -24,7 +24,8 @@ MultiDeviceExecutor::MultiDeviceExecutor( std::unique_ptr fusion, Communicator& comm, hir::HostIrEvaluatorParams params) - : comm_(comm), number_of_outputs_(fusion->outputs().size()) { + : comm_(comm), + number_of_outputs_(static_cast(fusion->outputs().size())) { std::unique_ptr hic = HostIrLower::lower(std::move(fusion), comm.deviceId()); // Create the HostIrEvaluator representing the host program From ed4440ad471a15c803843af97137d7b92f10bdfa Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 01:00:18 -0800 Subject: [PATCH 05/40] lint --- csrc/host_ir/executor.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 5a3c06e6891..4ab752bda44 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -274,9 +274,8 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { } void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) { - auto my_device_index = communicator_ - ? static_cast(communicator_->deviceId()) - : 0; + auto my_device_index = static_cast( + communicator_ ? communicator_->deviceId() : 0); streams_.insert( {get_current_stream->stream(), c10::cuda::getCurrentCUDAStream(my_device_index)}); From ef8f00c7b13b8a54789cc929536d1da4b0fe2c76 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 02:49:41 -0800 Subject: [PATCH 06/40] update with non blocking stream synchronization --- csrc/host_ir/executor.cpp | 28 ++++++++++++++++++++------ csrc/host_ir/executor.h | 2 ++ csrc/host_ir/host_ir.h | 2 ++ csrc/multidevice/communicator.cpp | 3 +++ tests/cpp/test_multidevice_host_ir.cpp | 19 +++++++++++++---- 5 files changed, 44 insertions(+), 10 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 8c8e514419e..64fb4aea587 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator( HostIrEvaluatorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params) { + params_(params), + my_device_index_(communicator_ ? communicator_->deviceId() : 0) { const DeviceIdxType device_index = (communicator_ != nullptr && communicator_->is_available()) ? communicator_->deviceId() @@ -216,6 +217,12 @@ std::vector HostIrEvaluator::runWithInput( dispatch(expr); } + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_)) + .synchronize(); + for (auto event : events_) { + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); + } // Collect global outputs return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_); } @@ -275,16 +282,25 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { } void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) { - auto my_device_index = communicator_ - ? 
static_cast(communicator_->deviceId()) - : 0; streams_.insert( {get_current_stream->stream(), - c10::cuda::getCurrentCUDAStream(my_device_index)}); + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_))}); } void HostIrEvaluator::handle(Synchronize* synchronize) { - getCUDAStream(synchronize->stream()).synchronize(); + cudaStream_t current_stream = + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_)) + .stream(); + cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream(); + + cudaEvent_t event; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault)); } void HostIrEvaluator::handle(PostOnStream* post_ir) { diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index e47624ffa6d..1b476b126bf 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -140,6 +140,8 @@ class HostIrEvaluator final : public OptOutDispatch { using StreamKey = std::variant; std::unordered_map streams_; std::unordered_map> works_; + const int64_t my_device_index_; + std::vector events_; }; } // namespace hir diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bed7d5893a7..09503eff3b5 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -211,6 +211,8 @@ class Wait : public Expr { } }; +// Makes the current stream wait on the given stream. Non-blocking from the host +// point of view. class Synchronize : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 8197ea224f4..6cf1a499bb9 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include @@ -196,6 +197,8 @@ Communicator::Communicator( return; } + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank_)); + #ifdef NVFUSER_DISTRIBUTED c10d::TCPStoreOptions store_opts; { diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index f8683effb55..ceeef81fa1f 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include @@ -352,9 +353,9 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { - constexpr int64_t M = 1024; - constexpr int64_t K = 256; - constexpr int64_t N = 512; + constexpr int64_t M = 32768; + constexpr int64_t K = 32768; + constexpr int64_t N = 1024; constexpr int64_t S = 8; const int64_t D = communicator_->size(); ASSERT_EQ(M % (D * S), 0); @@ -388,7 +389,17 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { at::Tensor tc_ref = at::matmul(ta_unsharded, tb); std::vector inputs = {ta, tb}; - auto tc = executor.runWithInput(inputs).at(0); + at::Tensor tc; + + constexpr int64_t number_of_iterations = 20; + constexpr int64_t number_of_warmup_iterations = 5; + for (const auto& i : c10::irange(number_of_iterations)) { + if (i == number_of_warmup_iterations) { + cudaProfilerStart(); + } + tc = executor.runWithInput(inputs).at(0); + } + cudaProfilerStop(); EXPECT_TRUE(torch::allclose(tc_ref, tc)); } From 36fd2be2a5bec3b7fa6578b0d3e0a6a7814f53d9 Mon Sep 17 00:00:00 2001 From: 
snordmann Date: Wed, 18 Dec 2024 02:58:40 -0800 Subject: [PATCH 07/40] make stream synchronization non blocking --- csrc/host_ir/executor.cpp | 22 ++++++++++++++++++++-- csrc/host_ir/executor.h | 2 ++ csrc/multidevice/communicator.cpp | 3 +++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 8b8ce484946..92a6a245fb0 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator( HostIrEvaluatorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params) { + params_(params), + my_device_index_(communicator_ ? communicator_->deviceId() : 0) { const DeviceIdxType device_index = (communicator_ != nullptr && communicator_->is_available()) ? communicator_->deviceId() @@ -215,6 +216,12 @@ std::vector HostIrEvaluator::runWithInput( dispatch(expr); } + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_)) + .synchronize(); + for (auto event : events_) { + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); + } // Collect global outputs return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_); } @@ -274,7 +281,18 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { } void HostIrEvaluator::handle(Synchronize* synchronize) { - getCUDAStream(synchronize->stream()).synchronize(); + cudaStream_t current_stream = + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_)) + .stream(); + cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream(); + + cudaEvent_t event; + NVFUSER_CUDA_RT_SAFE_CALL( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); + NVFUSER_CUDA_RT_SAFE_CALL( + cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault)); } void HostIrEvaluator::handle(PostOnStream* post_ir) { diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 7e3932c6b1a..a9cd8eeac69 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -138,6 +138,8 @@ class HostIrEvaluator final : public OptOutDispatch { using StreamKey = std::variant; std::unordered_map streams_; std::unordered_map> works_; + const int64_t my_device_index_; + std::vector events_; }; } // namespace hir diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 8197ea224f4..6cf1a499bb9 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include @@ -196,6 +197,8 @@ Communicator::Communicator( return; } + NVFUSER_CUDA_RT_SAFE_CALL(cudaSetDevice(local_rank_)); + #ifdef NVFUSER_DISTRIBUTED c10d::TCPStoreOptions store_opts; { From 1e9f1d090dfadb99fb3265dcdee2cd7a75910e50 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 03:34:31 -0800 Subject: [PATCH 08/40] lint --- csrc/host_ir/executor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 92a6a245fb0..f64f7679766 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -287,7 +287,7 @@ void HostIrEvaluator::handle(Synchronize* synchronize) { .stream(); cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream(); - cudaEvent_t event; + cudaEvent_t event = {}; NVFUSER_CUDA_RT_SAFE_CALL( cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); 
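// As context for the event-based handler added here, the same fork-free
// ordering trick in plain CUDA runtime calls -- a self-contained sketch,
// with error-code checks elided for brevity:

#include <cuda_runtime.h>

// Make `consumer` wait, on the device, for all work currently enqueued on
// `producer`; the host never blocks.
void streamWaitStreamSketch(cudaStream_t consumer, cudaStream_t producer) {
  cudaEvent_t event = nullptr;
  // Timing is disabled: the event exists purely for ordering.
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  // Mark the producer's current frontier ...
  cudaEventRecord(event, producer);
  // ... and enqueue a device-side wait for it on the consumer stream.
  cudaStreamWaitEvent(consumer, event, cudaEventWaitDefault);
  // Destroying here is legal even though the wait may not have run yet:
  // CUDA defers the teardown until pending operations release the event,
  // which is the observation a later patch in this series relies on.
  cudaEventDestroy(event);
}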
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); From af06de406f3082355fdaf6131ef831908d4fd248 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 04:00:33 -0800 Subject: [PATCH 09/40] add event to events_ container --- csrc/host_ir/executor.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index f64f7679766..0a4752b7804 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -293,6 +293,7 @@ void HostIrEvaluator::handle(Synchronize* synchronize) { NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); NVFUSER_CUDA_RT_SAFE_CALL( cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault)); + events_.push_back(event); } void HostIrEvaluator::handle(PostOnStream* post_ir) { From 5e166a0277e1eabd2d6e40918e2f3cca3582fe89 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 18 Dec 2024 05:31:37 -0800 Subject: [PATCH 10/40] destroy event async at create site --- csrc/host_ir/executor.cpp | 8 +------- csrc/host_ir/executor.h | 1 - 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 0a4752b7804..69b5b9c704d 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -216,12 +216,6 @@ std::vector HostIrEvaluator::runWithInput( dispatch(expr); } - c10::cuda::getCurrentCUDAStream( - static_cast(my_device_index_)) - .synchronize(); - for (auto event : events_) { - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); - } // Collect global outputs return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_); } @@ -293,7 +287,7 @@ void HostIrEvaluator::handle(Synchronize* synchronize) { NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync)); NVFUSER_CUDA_RT_SAFE_CALL( cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault)); - events_.push_back(event); + NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); } void HostIrEvaluator::handle(PostOnStream* post_ir) { diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index a9cd8eeac69..6f9070b810a 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -139,7 +139,6 @@ class HostIrEvaluator final : public OptOutDispatch { std::unordered_map streams_; std::unordered_map> works_; const int64_t my_device_index_; - std::vector events_; }; } // namespace hir From 741202b4456ab5593ea54d72873bcacd8227a808 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 03:04:32 -0800 Subject: [PATCH 11/40] minor review --- csrc/host_ir/executor.cpp | 8 ++++---- csrc/host_ir/executor.h | 1 + csrc/host_ir/host_ir.cpp | 10 ---------- csrc/host_ir/host_ir.h | 3 --- 4 files changed, 5 insertions(+), 17 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 4ab752bda44..e216e2390af 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator( HostIrEvaluatorParams params) : container_(std::move(container)), communicator_(communicator), - params_(params) { + params_(params), + my_device_index_(communicator_ ? communicator_->deviceId() : 0) { const DeviceIdxType device_index = (communicator_ != nullptr && communicator_->is_available()) ? communicator_->deviceId() @@ -274,11 +275,10 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) { } void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) { - auto my_device_index = static_cast( - communicator_ ? 
communicator_->deviceId() : 0); streams_.insert( {get_current_stream->stream(), - c10::cuda::getCurrentCUDAStream(my_device_index)}); + c10::cuda::getCurrentCUDAStream( + static_cast(my_device_index_))}); } void HostIrEvaluator::handle(Synchronize* synchronize) { diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 29d9eb5dd8f..a51dc32aed4 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -139,6 +139,7 @@ class HostIrEvaluator final : public OptOutDispatch { using StreamKey = std::variant; std::unordered_map streams_; std::unordered_map> works_; + const int64_t my_device_index_; }; } // namespace hir diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 24ceda7daba..49b33f59823 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -195,16 +195,6 @@ std::string GetCurrentStream::toString(int indent_size) const { return ss.str(); } -// TODO: implement better ? -std::string GetCurrentStream::toInlineString(int indent_size) const { - NVF_CHECK(false, "Cannot be printed inline"); -} - -// TODO: implement -bool GetCurrentStream::sameAs(const Statement* other) const { - return false; -} - Wait::Wait(IrBuilderPasskey passkey, Expr* expr) : Expr(passkey, {}, {}, {expr}) { NVF_ERROR(passkey.ir_container_ != nullptr); diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bed7d5893a7..82d67d6f4cc 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -174,13 +174,10 @@ class GetCurrentStream : public Expr { NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { return "hir::GetCurrentStream"; } - bool sameAs(const Statement* other) const override; - Stream* stream() const { return attributes_.at(0)->as(); } From 0374604263f41c45264f638739572533f242c354 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 06:25:30 -0800 Subject: [PATCH 12/40] fix merge --- csrc/host_ir/executor.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 6641d2aa0de..63b3f073b75 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -217,12 +217,6 @@ std::vector HostIrEvaluator::runWithInput( dispatch(expr); } - c10::cuda::getCurrentCUDAStream( - static_cast(my_device_index_)) - .synchronize(); - for (auto event : events_) { - NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event)); - } // Collect global outputs return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_); } From 5e07ad85f37c589516d4b0d0e4c6316a4ff146ed Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 06:31:33 -0800 Subject: [PATCH 13/40] minor review --- csrc/host_ir/lower.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 46addba4b17..3bef51dd81d 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -358,12 +358,12 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( TensorView* tvb = matmul->inB(); TensorView* tvc = matmul->out(); NVF_ERROR( - !isSharded(tvb), "The B operand ", tvb, " is expected to be sharded"); + !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded"); NVF_ERROR( !isSharded(tvc), "The output ", matmul->out(), - " is expected to be sharded"); + " is expected to not be sharded"); const int64_t sharded_axis_index = getShardedLogicalAxis(tva, ParallelType::DIDx); 
IterDomain* stream_axis = tva->axis(0); From b546dcec9da8b05f5b211810bff485433fe5ea08 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 06:39:10 -0800 Subject: [PATCH 14/40] remove now unnecessary trick of adding artifical outputs --- csrc/host_ir/executor.h | 1 + csrc/host_ir/lower.cpp | 27 ++------------------------- csrc/multidevice/executor.cpp | 7 ++----- csrc/multidevice/executor.h | 6 ------ 4 files changed, 5 insertions(+), 36 deletions(-) diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index c6214b083c2..6efbbf6854e 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -74,6 +74,7 @@ struct HostIrEvaluatorParams { // Experimental: whether to cache fusion executor. WAR: avoid recompilation // but implicitely assumes that the input shape don't change over iterations bool cache_fusion_executor = false; + // number of additional cuda streams to use at runtime for comm+compute pipelining int64_t number_of_streams = 4; }; diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 3bef51dd81d..38e7f9ac45c 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -414,44 +414,22 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( auto* set_stream = IrBuilder::create(stream); TensorView* tva_j = select(tva, 0, j); - TensorView* tva_j_unsqueezed = tva_j; // unsqueeze(tva_j, 0); TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); TensorView* tvc_j = select(tvc, 0, j); - // [TAG: adding articifial outputs] - // The following line is artificial but necessary to make tva_j_unsqueeze a - // consumer of tva_j. - // - // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all - // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a - // bit special among all transitive consumers of `j`. It doesn't use `j` - // directly but uses `tva_j` which is a TensorView. TensorView's uses are - // built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses - // only fix TensorViews that can reach outputs. Therefore, we add - // tva_j_unsqueezed as an output. Other TensorViews don't need this - // treatmenet because they are direct users of `j`, a scalar whose uses are - // built eagerly upon registration. - // - // We could have added `tvc_j` instead as an output, which transitively - // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a Select - // and a MatmulOp, and StmtSort::getExprs only traverse via the first - // registered definition (i.e. the Select). This sounds like a bug -- I wonder - // how nvFuser resets the TensorView uses of a kir::Kernel, also non-SSA. 
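The comment block being deleted above concerns Val::uses(). As intuition for why a stale use-list broke the for-loop evaluation, here is a hypothetical miniature of the transitive-consumer walk; MiniVal and MiniExpr are invented for illustration and are not nvFuser types.

#include <unordered_set>
#include <vector>

struct MiniExpr;
struct MiniVal {
  std::vector<MiniExpr*> uses; // expressions consuming this value
};
struct MiniExpr {
  std::vector<MiniVal*> outputs;
};

// Starting from a loop index, collect every expression transitively
// reachable through use edges.
std::unordered_set<MiniExpr*> transitiveConsumers(MiniVal* loop_index) {
  std::unordered_set<MiniExpr*> visited;
  std::vector<MiniVal*> worklist = {loop_index};
  while (!worklist.empty()) {
    MiniVal* v = worklist.back();
    worklist.pop_back();
    for (MiniExpr* e : v->uses) { // a stale `uses` list silently drops edges
      if (visited.insert(e).second) {
        for (MiniVal* out : e->outputs) {
          worklist.push_back(out);
        }
      }
    }
  }
  return visited;
}

If uses() is rebuilt lazily and only for tensors that reach an output, a tensor reaching no output contributes no edges to this walk and its consumers are skipped; that is what the artificial output worked around, and with the unsqueeze gone the workaround can be removed.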
- hic->addOutput(tva_j_unsqueezed); - NVF_ERROR( tva->hasDeviceMesh(), "The matmul's input ", tva, "is expected to have a DeviceMesh"); - for (auto tv : {tva_j, tva_allgathered_j, tva_j_unsqueezed, tvc_j}) { + for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) { tv->setDeviceMesh(tva->getDeviceMesh()); } auto* communication = IrBuilder::create( CommunicationType::Allgather, /*out=*/tva_allgathered_j, - /*in=*/tva_j_unsqueezed, + /*in=*/tva_j, /*team=*/tva->getDeviceMesh().vector()); auto* wait = IrBuilder::create(communication); @@ -464,7 +442,6 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( std::vector loop_body = { set_stream, tva_j->definition(), - tva_j_unsqueezed->definition(), tva_allgathered_j->definition(), communication, wait, diff --git a/csrc/multidevice/executor.cpp b/csrc/multidevice/executor.cpp index 82957a54f7f..963b80812d3 100644 --- a/csrc/multidevice/executor.cpp +++ b/csrc/multidevice/executor.cpp @@ -24,8 +24,7 @@ MultiDeviceExecutor::MultiDeviceExecutor( std::unique_ptr fusion, Communicator& comm, hir::HostIrEvaluatorParams params) - : comm_(comm), - number_of_outputs_(static_cast(fusion->outputs().size())) { + : comm_(comm) { std::unique_ptr hic = HostIrLower::lower(std::move(fusion), comm.deviceId()); // Create the HostIrEvaluator representing the host program @@ -53,9 +52,7 @@ std::vector MultiDeviceExecutor::runWithInput( inputs.at(input_idx); } - auto outputs = host_ir_executor_->runWithInput(val_to_IValue); - return std::vector( - outputs.end() - number_of_outputs_, outputs.end()); + return host_ir_executor_->runWithInput(val_to_IValue); } std::ostream& MultiDeviceExecutor::print(std::ostream& os) { diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index bc9d5d974b3..7cad0388b18 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -102,12 +102,6 @@ class MultiDeviceExecutor { Communicator& comm_; // holds the HostIrEvaluator used for execution std::unique_ptr host_ir_executor_; - // Store the number of outputs before it possibly gets artificially modified - // by HostIr::lower. This is undesirable but required for now. For more - // details, search for the comment in host_ir/lower.cpp tagged with "[TAG: - // adding articifial outputs]" - // TODO: fix - int64_t number_of_outputs_; }; } // namespace nvfuser From 8e8b24724868e0ce074e72e2aa4f52ff131928dd Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 06:41:44 -0800 Subject: [PATCH 15/40] lint --- csrc/host_ir/executor.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 6efbbf6854e..70d8030c640 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -74,7 +74,8 @@ struct HostIrEvaluatorParams { // Experimental: whether to cache fusion executor. 
WAR: avoid recompilation // but implicitely assumes that the input shape don't change over iterations bool cache_fusion_executor = false; - // number of additional cuda streams to use at runtime for comm+compute pipelining + // number of additional cuda streams to use at runtime for comm+compute + // pipelining int64_t number_of_streams = 4; }; From d5b42c24b88adc167de4e04e457bc5f6d9170a28 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 23 Dec 2024 06:46:50 -0800 Subject: [PATCH 16/40] remove now unnecessary patch on broadcast --- csrc/ops/alias.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 6dafe4aa819..bc541623310 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -948,12 +947,8 @@ TensorView* broadcast( .iter_type(IterType::Broadcast) .build()); } else { - auto inp_id = inp_domain[iinp]; - auto out_id = IterDomainBuilder(inp_id).resetSchedulingParams().build(); - if (inp_id->isDeviceDim()) { - out_id->parallelize(inp_id->getParallelType()); - } - out_domain.push_back(out_id); + out_domain.push_back( + IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build()); iinp++; } ibdim++; From 4191ecf157402c87e75790b9d5e2639dc5eafefa Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 3 Jan 2025 03:20:21 -0800 Subject: [PATCH 17/40] fix typo in canLower --- csrc/host_ir/lower.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 38e7f9ac45c..4c11dffd252 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -315,7 +315,7 @@ bool HostIrLower::canLower(Expr* expr) { return false; } if (expr->isA()) { - if (!isInnerResharding(expr)) { + if (isInnerResharding(expr)) { return false; } auto in = expr->as()->in()->as(); @@ -337,7 +337,7 @@ bool HostIrLower::canLower(Expr* expr) { auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); } else if (expr->isA()) { - return isInnerResharding(expr) && + return !isInnerResharding(expr) && expr->as()->opType() == LoadStoreOpType::Set; } else if (expr->as()) { // For now we only support c = matmul(a,b) when b,c are fully replicated and From dfc33f2dd08083d8f2e4fea59b2c1a32c69c6889 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 3 Jan 2025 03:20:29 -0800 Subject: [PATCH 18/40] add Stream parallelType --- csrc/host_ir/lower.cpp | 3 ++- csrc/type.cpp | 2 ++ csrc/type.h | 1 + tests/cpp/test_multidevice_host_ir.cpp | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 4c11dffd252..a6a8fb00fe9 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -345,7 +345,8 @@ bool HostIrLower::canLower(Expr* expr) { auto* matmul = expr->as(); return !isSharded(matmul->inB()) && !isSharded(matmul->out()) && matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial && - getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1; + getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 && + matmul->out()->axis(0)->getParallelType() == ParallelType::Stream; } return false; } diff --git a/csrc/type.cpp b/csrc/type.cpp index ab087361a1d..36819404ae9 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -712,6 +712,8 @@ static const char* parallel_type2string(ParallelType t) { return "threadIdx.y"; case ParallelType::TIDx: return "threadIdx.x"; + case 
ParallelType::Stream: + return "Stream"; case ParallelType::Vectorize: return "V"; case ParallelType::MisalignedVectorize: diff --git a/csrc/type.h b/csrc/type.h index 89cebe8763b..265f1a939ee 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -672,6 +672,7 @@ enum class ParallelType { TIDz, TIDy, TIDx, + Stream, Vectorize, MisalignedVectorize, Unroll, diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index ceeef81fa1f..81ced52b530 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -377,6 +377,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { c->setDeviceMesh(mesh); a->axis(1)->parallelize(ParallelType::DIDx); + c->axis(0)->parallelize(ParallelType::Stream); MultiDeviceExecutor executor(std::move(fusion), *communicator_); From 4e1ecb9f8ecae3a8ae9da6ab17daa7940d3c6d26 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 3 Jan 2025 03:28:17 -0800 Subject: [PATCH 19/40] minor reviews --- csrc/host_ir/lower.cpp | 15 +++++++-------- tests/cpp/test_multidevice_host_ir.cpp | 5 ++++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index a6a8fb00fe9..250432ce136 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -314,12 +314,12 @@ bool HostIrLower::canLower(Expr* expr) { if (!ir_utils::isTvOp(expr)) { return false; } - if (expr->isA()) { + if (auto* reduction = dynamic_cast(expr)) { if (isInnerResharding(expr)) { return false; } - auto in = expr->as()->in()->as(); - auto out = expr->as()->out()->as(); + auto in = reduction->in()->as(); + auto out = reduction->out()->as(); // get the reduced axis std::vector reduction_axis; std::copy_if( @@ -336,13 +336,12 @@ bool HostIrLower::canLower(Expr* expr) { PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); - } else if (expr->isA()) { - return !isInnerResharding(expr) && - expr->as()->opType() == LoadStoreOpType::Set; - } else if (expr->as()) { + } else if (auto* ldst = dynamic_cast(expr)) { + return !isInnerResharding(ldst) && + ldst->as()->opType() == LoadStoreOpType::Set; + } else if (auto* matmul = dynamic_cast(expr)) { // For now we only support c = matmul(a,b) when b,c are fully replicated and // a is sharded on axis 1 - auto* matmul = expr->as(); return !isSharded(matmul->inB()) && !isSharded(matmul->out()) && matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial && getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 && diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 81ced52b530..518adef3022 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -358,7 +358,10 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { constexpr int64_t N = 1024; constexpr int64_t S = 8; const int64_t D = communicator_->size(); - ASSERT_EQ(M % (D * S), 0); + if (M % (D * S) != 0) { + GTEST_SKIP() << "M must be a multiple of D * S, but got M = " << M + << ", D = " << D << ", S = " << S; + } auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From a5c70b85e6ce617410bb1cf684f94d91c75ed2a6 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 3 Jan 2025 09:44:07 -0800 Subject: [PATCH 20/40] fix bug: allocate dst buffer before posting communication --- csrc/host_ir/lower.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 250432ce136..87d750a7850 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -518,7 +518,6 @@ std::unique_ptr HostIrLower::lower( "Communication segments must contain only one Expr"); for (auto* expr : HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) { - hic->pushBackTopLevelExprs(expr); // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); @@ -528,7 +527,10 @@ std::unique_ptr HostIrLower::lower( IrBuilder::create(tv, MemoryType::Global); hic->pushBackTopLevelExprs(allocate); } - auto wait = IrBuilder::create(communication); + } + hic->pushBackTopLevelExprs(expr); + if (expr->isA()) { + auto wait = IrBuilder::create(expr->as()); hic->pushBackTopLevelExprs(wait); } } From f66e97a37e7df89805d1bececd78e61fb4572edf Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 03:32:36 -0800 Subject: [PATCH 21/40] change order of presegpass --- csrc/host_ir/lower.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 87d750a7850..84815866dc6 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -463,10 +463,10 @@ std::unique_ptr HostIrLower::lower( // Note: passes run before PreSegmenter optimization passes. preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - preseg_passes::OptimizationPass< - preseg_passes::InsertReshardingsPass>::runPass(fusion.get()); preseg_passes::OptimizationPass< preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get()); + preseg_passes::OptimizationPass< + preseg_passes::InsertReshardingsPass>::runPass(fusion.get()); preseg_passes::OptimizationPass< preseg_passes::MakeReshardingContiguousPass>::runPass(fusion.get()); From 09ccfa5b46f8c584d4120f518a86c95464558b80 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 04:25:40 -0800 Subject: [PATCH 22/40] minor comment --- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 518adef3022..d34922a68e2 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -366,7 +366,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*d), K] + TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*D), K] TensorView* b = makeContigTensor(2); //[K, N] TensorView* c = matmul(a, b); //[S, D, M/(S*D), N] From 8287679a2a24576a17f54a43981ad93ef1ccf170 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 09:48:24 -0800 Subject: [PATCH 23/40] fix MultiDeviceReductionTest.UnshardedInput_ShardedOutput/ tests --- csrc/multidevice/utils.cpp | 8 +++++--- csrc/scheduler/no_op.cpp | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 847557bfa3a..8bc30e667b6 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -100,7 +100,7 @@ std::pair, std::vector> getShardingChanges bool isSharded(const TensorView* tv) { bool is_sharded = false; for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - if (!alloc_id->isDeviceDim()) { + if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) { continue; } @@ -160,7 +160,7 @@ int64_t getShardedLogicalAxis( std::unordered_map 
parallel_type_to_id = mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); - if (alloc_id == nullptr) { + if (alloc_id == nullptr || alloc_id->isReduction()) { return -1; } @@ -417,7 +417,9 @@ bool haveDifferentShardings( .strictAreMapped(a, b); }; - if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model)) { + if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model) + || (p_loop_id != nullptr && c_loop_id != nullptr && p_loop_id->isReduction() != c_loop_id->isReduction()) + ) { return true; } } diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp index a7eb6e2de1f..73174cc432b 100644 --- a/csrc/scheduler/no_op.cpp +++ b/csrc/scheduler/no_op.cpp @@ -49,7 +49,7 @@ bool NoOpScheduler::canScheduleCompileTime(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); if (exprs.size() == 1 && isResharding(exprs[0]) && HostIrLower::canLower(exprs[0])) { - return true; + return true; // b? } if (allOutputsArePointerArithmetics(fusion)) { From 301e54d5851f4420bd39c03e9933fd66562881ac Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 10:12:24 -0800 Subject: [PATCH 24/40] bypass ReorderShardedAxisPass if multiple IO. fix DistributedTransformerTest --- csrc/preseg_passes/reorder_sharded_axis.cpp | 11 +++-------- tests/cpp/test_multidevice_matmul.cpp | 4 ++-- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index bf68f6aa9c5..cde09fa18b2 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -28,18 +28,13 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { if (HostIrLower::canLower(expr)) { continue; } + if (expr->outputs().size() > 1 || expr->inputs().size() > 1) { + continue; + } NVF_ERROR( ir_utils::isTvOp(expr), "Non-tv op is not supported: ", expr->toString()); - NVF_ERROR( - expr->outputs().size() == 1, - "Resharding operations can only have one output: ", - expr->toString()); - NVF_ERROR( - expr->inputs().size() == 1, - "Resharding operations can have only one input: ", - expr->toString()); auto* output = expr->outputs().at(0)->as(); auto* input = expr->inputs().at(0)->as(); auto [shard_additions, shard_deletions] = getShardingChanges(input, output); diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index a4e323a553d..dc0c4a50371 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -415,8 +415,8 @@ TEST_F(DistributedMatmulTest, AnnotateWeightOnly) { // x is of shape [2, 3] and replicated. // w is of shape [3, D*5] and column-wise sharded. // y is expected to have shape [2, D*5] and to be also column-wise sharded. 
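As a quick sanity check of the column-wise sharding described in the comments above, a small ATen sketch; D and the rank are assumed values, and the sizes match the ones this patch switches the test to.

#include <ATen/ATen.h>

void annotateWeightOnlyShapeSketch() {
  const int64_t D = 2; // assumed mesh size
  at::Tensor x = at::randn({64, 32}); // replicated on every device
  at::Tensor w = at::randn({D, 32, 128}); // dim 0 is the DIDx-sharded axis
  at::Tensor w_shard = w.select(0, /*rank=*/0); // this rank's [32, 128] block
  at::Tensor y_shard = at::matmul(x, w_shard); // [64, 128] column block of y
  TORCH_CHECK(y_shard.size(0) == 64 && y_shard.size(1) == 128);
}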
- auto x_tensor = at::randn({2, 3}, tensor_options); - auto w_tensor = at::randn({mesh.size(), 3, 5}, tensor_options); + auto x_tensor = at::randn({64, 32}, tensor_options); + auto w_tensor = at::randn({mesh.size(), 32, 128}, tensor_options); auto sharded_w_tensor = shardTensor(w_tensor, w); FusionExecutorCache executor_cache(std::move(fusion)); From 0c9b65e59c4d938e453ed4b40bc1bffce8ba103c Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 10:17:08 -0800 Subject: [PATCH 25/40] change tensor size to loosen tolerance/error in DistributedMatmulTest.AnnotateWeightOnly --- tests/cpp/test_multidevice_matmul.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index dc0c4a50371..19a7d955db4 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -412,9 +412,9 @@ TEST_F(DistributedMatmulTest, AnnotateWeightOnly) { w->setDeviceMesh(mesh); w->axis(0)->parallelize(ParallelType::DIDx); - // x is of shape [2, 3] and replicated. - // w is of shape [3, D*5] and column-wise sharded. - // y is expected to have shape [2, D*5] and to be also column-wise sharded. + // x is of shape [64, 32] and replicated. + // w is of shape [32, D*128] and column-wise sharded. + // y is expected to have shape [64, D*128] and to be also column-wise sharded. auto x_tensor = at::randn({64, 32}, tensor_options); auto w_tensor = at::randn({mesh.size(), 32, 128}, tensor_options); auto sharded_w_tensor = shardTensor(w_tensor, w); From 3a0d827755ee09ff667034933f99659b0c12622a Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 10:18:16 -0800 Subject: [PATCH 26/40] lint --- csrc/multidevice/utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 8bc30e667b6..059f161bd5a 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -417,9 +417,9 @@ bool haveDifferentShardings( .strictAreMapped(a, b); }; - if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model) - || (p_loop_id != nullptr && c_loop_id != nullptr && p_loop_id->isReduction() != c_loop_id->isReduction()) - ) { + if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model) || + (p_loop_id != nullptr && c_loop_id != nullptr && + p_loop_id->isReduction() != c_loop_id->isReduction())) { return true; } } From decb0555cf81a79f0701775ddb841b2081dd806e Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 7 Jan 2025 10:22:07 -0800 Subject: [PATCH 27/40] minor comments --- tests/cpp/test_multidevice_host_ir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index d34922a68e2..7c9ebb9cc5d 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -395,10 +395,10 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { std::vector inputs = {ta, tb}; at::Tensor tc; - constexpr int64_t number_of_iterations = 20; - constexpr int64_t number_of_warmup_iterations = 5; - for (const auto& i : c10::irange(number_of_iterations)) { - if (i == number_of_warmup_iterations) { + constexpr int64_t kNumberOfIterations = 20; + constexpr int64_t kNumberOfWarmupIterations = 5; + for (auto i : c10::irange(kNumberOfIterations)) { + if (i == kNumberOfWarmupIterations) { cudaProfilerStart(); } tc = executor.runWithInput(inputs).at(0); From 11843d6f5381c21db73a8512b23a7c25358e039e Mon Sep 17 00:00:00 
From: snordmann
Date: Tue, 7 Jan 2025 10:29:34 -0800
Subject: [PATCH 28/40] typo

---
 csrc/scheduler/no_op.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp
index 73174cc432b..a7eb6e2de1f 100644
--- a/csrc/scheduler/no_op.cpp
+++ b/csrc/scheduler/no_op.cpp
@@ -49,7 +49,7 @@ bool NoOpScheduler::canScheduleCompileTime(Fusion* fusion) {
   const std::vector<Expr*>& exprs = fusion->exprs();
   if (exprs.size() == 1 && isResharding(exprs[0]) &&
       HostIrLower::canLower(exprs[0])) {
-    return true; // b?
+    return true;
   }
 
   if (allOutputsArePointerArithmetics(fusion)) {

From caf5d0b7b376f06bc820d8e8c35c13bba225799c Mon Sep 17 00:00:00 2001
From: snordmann
Date: Wed, 8 Jan 2025 10:17:31 -0800
Subject: [PATCH 29/40] increase tolerance

---
 tests/cpp/test_multidevice_host_ir.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp
index 7c9ebb9cc5d..ef05c4a45ac 100644
--- a/tests/cpp/test_multidevice_host_ir.cpp
+++ b/tests/cpp/test_multidevice_host_ir.cpp
@@ -405,7 +405,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
   }
   cudaProfilerStop();
 
-  EXPECT_TRUE(torch::allclose(tc_ref, tc));
+  EXPECT_TRUE(torch::allclose(tc_ref, tc, 1e-2, 1e-2));
 }
 
 } // namespace hir

From 5b6c7bdf41f941c24275b56d2c1eb761b7cc8afe Mon Sep 17 00:00:00 2001
From: snordmann
Date: Wed, 8 Jan 2025 10:18:25 -0800
Subject: [PATCH 30/40] still throws if two axes are DIDx, even if one is
 reduced

---
 csrc/multidevice/utils.cpp | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 059f161bd5a..c10aa5e4395 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -99,17 +99,23 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
+  bool is_reduction_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
+    if (!alloc_id->isDeviceDim()) {
       continue;
     }
 
     // Only one axis can be sharded on DIDx.
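[Editor's note: a sketch of the intended behavior, not part of the patch
series. The truth table assumes the two flags introduced in this hunk.]

    // [iDID{i0}, i1]       -> is_sharded = true            -> returns true
    // [rDID{i0}, i1]       -> is_reduction_sharded = true  -> returns false
    // [iDID{i0}, rDID{i1}] -> the second DIDx axis trips the NVF_ERROR below,
    //                         whichever of the two kinds it is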
     NVF_ERROR(
-        !is_sharded,
+        !is_sharded && !is_reduction_sharded,
         "Multiple IterDomains parallelized on DIDx in TensorView ",
         tv);
-    is_sharded = true;
+
+    if (alloc_id->isReduction()) {
+      is_reduction_sharded = true;
+    } else {
+      is_sharded = true;
+    }
   }
   return is_sharded;
 }

From 632aa1e1fbffeafb77415dc000cf733546c62074 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Wed, 8 Jan 2025 10:23:57 -0800
Subject: [PATCH 31/40] support multiple additions/deletions in
 isInnerResharding

---
 csrc/multidevice/utils.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index c10aa5e4395..9f46359eae3 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -468,9 +468,6 @@ bool isInnerResharding(Expr* expr) {
   for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
     auto [shard_additions, shard_deletions] =
         getShardingChanges(input, output);
-    NVF_ERROR(
-        shard_additions.size() + shard_deletions.size() <= 1,
-        "Resharding expr can only support one axis")
     if ((!shard_deletions.empty() &&
          allocationIndex(input, shard_deletions.at(0)) > 0) ||
         (!shard_additions.empty() &&

From 8f60b4584b2890ef7205092a18b107059f4b78f5 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 02:44:51 -0800
Subject: [PATCH 32/40] use randint and small sizes in
 DistributedMatmulTest.AnnotateWeightOnly

---
 tests/cpp/test_multidevice_matmul.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp
index 19a7d955db4..632ea15f051 100644
--- a/tests/cpp/test_multidevice_matmul.cpp
+++ b/tests/cpp/test_multidevice_matmul.cpp
@@ -412,11 +412,13 @@ TEST_F(DistributedMatmulTest, AnnotateWeightOnly) {
   w->setDeviceMesh(mesh);
   w->axis(0)->parallelize(ParallelType::DIDx);
 
-  // x is of shape [64, 32] and replicated.
-  // w is of shape [32, D*128] and column-wise sharded.
-  // y is expected to have shape [64, D*128] and to be also column-wise sharded.
-  auto x_tensor = at::randn({64, 32}, tensor_options);
-  auto w_tensor = at::randn({mesh.size(), 32, 128}, tensor_options);
+  // x is of shape [2, 3] and replicated.
+  // w is of shape [3, D*5] and column-wise sharded.
+  // y is expected to have shape [2, D*5] and to be also column-wise sharded.
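[Editor's note, not part of the patch series: a worked check of why randint
makes this comparison robust. With entries drawn from [0, 10) and an inner
dimension of 3, every partial product is at most 9 * 9 = 81 and every output
entry at most 3 * 81 = 243, all exactly representable in float, so the
distributed result can match the reference exactly rather than within a
loosened tolerance.]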
+  constexpr int64_t kLowerBound = 0;
+  constexpr int64_t kUpperBound = 10;
+  auto x_tensor = at::randint(kLowerBound, kUpperBound, {2, 3}, tensor_options);
+  auto w_tensor = at::randint(kLowerBound, kUpperBound, {mesh.size(), 3, 5}, tensor_options);
   auto sharded_w_tensor = shardTensor(w_tensor, w);
 
   FusionExecutorCache executor_cache(std::move(fusion));

From cdd9e46c877bc3efe391fc6c89ce8d47f3be9dc8 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:01:29 -0800
Subject: [PATCH 33/40] add bool option ignore_inner_resharding in canLower

---
 csrc/host_ir/lower.cpp                    | 6 +++---
 csrc/host_ir/lower.h                      | 2 +-
 csrc/preseg_passes/insert_reshardings.cpp | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index 84815866dc6..70f5f2794ac 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -307,7 +307,7 @@ std::vector<Expr*> HostIrLower::lower(Expr* c) {
   return comms;
 }
 
-bool HostIrLower::canLower(Expr* expr) {
+bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
   if (!isResharding(expr)) {
     return true;
   }
@@ -315,7 +315,7 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
     return false;
   }
   if (auto* reduction = dynamic_cast<ReductionOp*>(expr)) {
-    if (isInnerResharding(expr)) {
+    if (isInnerResharding(expr) && !ignore_inner_resharding) {
      return false;
    }
     auto in = reduction->in()->as<TensorView>();
@@ -337,7 +337,7 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
     auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
     return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
   } else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
-    return !isInnerResharding(ldst) &&
+    return (!isInnerResharding(ldst) || ignore_inner_resharding) &&
         ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
   } else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
     // For now we only support c = matmul(a,b) when b,c are fully replicated and
diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h
index c99e8a8e76b..e140a80d878 100644
--- a/csrc/host_ir/lower.h
+++ b/csrc/host_ir/lower.h
@@ -16,7 +16,7 @@ namespace nvfuser {
 
 class HostIrLower {
  public:
-  static bool canLower(Expr* expr);
+  static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
 
   // Lower a sharded Expr into a series of Communication.
   static std::vector<Expr*> lower(Expr* c);
diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp
index 9d62e0dc1a9..c0b6424a2ad 100644
--- a/csrc/preseg_passes/insert_reshardings.cpp
+++ b/csrc/preseg_passes/insert_reshardings.cpp
@@ -33,7 +33,7 @@ void insertReshardingsBefore(Fusion* fusion) {
   // Remove this after we refactor this as a pre-segmenter pass.
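[Editor's note: a sketch of the intended division of labor, not part of the
patch series; the rationale is inferred from this patch together with PATCH 38.]

    // InsertReshardingsPass may now treat inner-resharding exprs as lowerable:
    bool lowerable = HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true);
    // ReorderShardedAxisPass keeps calling canLower(expr) with the strict
    // default and later rewrites the allocation order so the sharded axis
    // becomes outermost, at which point the expr really can be lowered.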
   FusionGuard fg(fusion);
   for (Expr* expr : fusion->exprs()) {
-    if (HostIrLower::canLower(expr) || shouldReshardAfter(expr)) {
+    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) || shouldReshardAfter(expr)) {
       continue;
     }
 
@@ -85,7 +85,7 @@ void insertReshardingsAfter(Fusion* fusion) {
   auto exprs = fusion->exprs();
   for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
     Expr* expr = *it;
-    if (HostIrLower::canLower(expr) || !shouldReshardAfter(expr)) {
+    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) || !shouldReshardAfter(expr)) {
       continue;
     }

From 6e9fe35af20be86c92df92bacc6f1fe592bf33f3 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:04:40 -0800
Subject: [PATCH 34/40] de-DID-parallelize reduction axes in shardAllLike

---
 csrc/multidevice/utils.cpp                 | 41 ++++++++++++++++++++++
 csrc/preseg_passes/propagate_shardings.cpp | 41 ----------------------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 9f46359eae3..3ff264a55cc 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -487,6 +487,47 @@ void shardAllLike(TensorView* ref, std::vector<TensorView*> tvs) {
     scheduler_utils::parallelizeAllLike(
         ref, tvs, {ParallelType::DIDx, ParallelType::Serial});
   }
+
+  // parallelizeAllLike tries to DID-parallelize
+  // reduction dimensions. For example,
+  //
+  // [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
+  //
+  // becomes
+  //
+  // [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Pointwise) -> [i2]
+  //
+  // This implies that the reduction result only exists on the "home" device.
+  // `lower_communication` can't lower such a reduction today. lowerToReduce
+  // is closest but it uses the output device mesh to indicate the home device.
+  // Also, an extra broadcast will be needed to replicate the reduction result
+  // to all devices for the pointwise op.
+  //
+  // Therefore, instead, we remove the DID from reduction dimensions and
+  // therefore reset them to Serial. This way,
+  // the above becomes
+  //
+  // [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
+  //
+  // where the reduction will be lowered to an Allreduce.
+  //
+  // Alternatively, @naoyam proposed to represent an allreduce as a reduce
+  // followed by a broadcasting set.
+  //
+  // [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Set) [i2] -> (Pointwise)
+  // -> [i2]
+  //
+  // This will make the semantics similar to other parallel types and therefore
+  // we can better leverage existing parallelization utilities. We have yet to
+  // pursue this because of implementation difficulty -- `lower_communication`
+  // would need to match the reduce-set pattern.
+  for (TensorView* tv : tvs) {
+    for (IterDomain* id : tv->getLoopDomain()) {
+      if (id->isReduction() && id->isDeviceDim()) {
+        id->parallelize(ParallelType::Serial);
+      }
+    }
+  }
 }
 
 void shardBetween(
diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp
index 3f8e514184e..69ba5983060 100644
--- a/csrc/preseg_passes/propagate_shardings.cpp
+++ b/csrc/preseg_passes/propagate_shardings.cpp
@@ -98,47 +98,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
     shardAllLike(ref_input, outputs_without_mesh);
   }
 
-  // shardAllLike, which calls parallelAllLke, tries to DID-parallelize
-  // reduction dimensions. For example,
-  //
-  // [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
-  //
-  // becomes
-  //
-  // [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Pointwise) -> [i2]
-  //
-  // This implies that the reduction result only exists on the "home" device.
-  // `lower_communication` can't lower such a reduction today. lowerToReduce
-  // is closest but it uses the output device mesh to indicate the home device.
-  // Also, an extra broadcast will be needed to replicate the reduction result
-  // to all devices for the pointwise op.
-  //
-  // Therefore, instead, we remove the DID from reduction dimensions and
-  // therefore reset them to Serial. This way,
-  // the above becomes
-  //
-  // [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
-  //
-  // where the reduction will be lowered to an Allreduce.
-  //
-  // Alternatively, @naoyam proposed to represent an allreduce as a reduce
-  // followed by a broadcasting set.
-  //
-  // [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Set) [i2] -> (Pointwise)
-  // -> [i2]
-  //
-  // This will make the semantics similar to other parallel types and therefore
-  // we can better leverage existing parallelization utilities. We have yet to
-  // pursue this because of implementation difficulty -- `lower_communication`
-  // would need to match the reduce-set pattern.
-  for (TensorView* tv : fusion->allTvs()) {
-    for (IterDomain* id : tv->getLoopDomain()) {
-      if (id->isReduction() && id->isDeviceDim()) {
-        id->parallelize(ParallelType::Serial);
-      }
-    }
-  }
-
   // Back-propagate device meshes. This makes sure all TensorViews have a mesh
   // if any of them has one. This is needed in addition to the forward
   // propagation for ops that don't take any TensorView operands, e.g.,

From 428600d50e389db57c2eaf0dcc20a4508617f376 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:10:59 -0800
Subject: [PATCH 35/40] revert patch on isSharded

---
 csrc/multidevice/utils.cpp | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 3ff264a55cc..a0a0e0c3c90 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -99,7 +99,6 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
-  bool is_reduction_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
     if (!alloc_id->isDeviceDim()) {
       continue;
     }
 
     // Only one axis can be sharded on DIDx.
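[Editor's note, not part of the patch series: why this revert is safe. The
previous commit makes shardAllLike reset DIDx on reduction IterDomains back to
Serial, so by the time isSharded runs, no reduction axis should still carry
DIDx:]

    // after shardAllLike (PATCH 34):
    //   [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2]   // r{i1} is Serial again
    // so isSharded only ever sees at most one iteration DIDx axis.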
     NVF_ERROR(
-        !is_sharded && !is_reduction_sharded,
+        !is_sharded,
         "Multiple IterDomains parallelized on DIDx in TensorView ",
         tv);
-
-    if (alloc_id->isReduction()) {
-      is_reduction_sharded = true;
-    } else {
-      is_sharded = true;
-    }
+    is_sharded = true;
   }
   return is_sharded;
 }

From e70c00ffae63815d8fc835cf81a7c89c4286ff3a Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:13:58 -0800
Subject: [PATCH 36/40] revert patch on getShardedLogicalAxis and
 isInnerResharding

---
 csrc/multidevice/utils.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index a0a0e0c3c90..771bf19d6f8 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -160,7 +160,7 @@ int64_t getShardedLogicalAxis(
   std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
       mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
   IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
-  if (alloc_id == nullptr || alloc_id->isReduction()) {
+  if (alloc_id == nullptr) {
     return -1;
   }
 
@@ -417,9 +417,7 @@ bool haveDifferentShardings(
         .strictAreMapped(a, b);
   };
 
-  if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model) ||
-      (p_loop_id != nullptr && c_loop_id != nullptr &&
-       p_loop_id->isReduction() != c_loop_id->isReduction())) {
+  if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model)) {
     return true;
   }
 }
@@ -462,6 +460,9 @@ bool isInnerResharding(Expr* expr) {
   for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
     auto [shard_additions, shard_deletions] =
         getShardingChanges(input, output);
+    NVF_ERROR(
+        shard_additions.size() + shard_deletions.size() <= 1,
+        "Resharding expr can only support one axis")
     if ((!shard_deletions.empty() &&
          allocationIndex(input, shard_deletions.at(0)) > 0) ||
         (!shard_additions.empty() &&

From 0f93a437cb9d35f260b1121dbd41c26a3c282ef7 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:19:50 -0800
Subject: [PATCH 37/40] revert accepting multiple IO in ReorderShardedAxisPass

---
 csrc/preseg_passes/reorder_sharded_axis.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp
index cde09fa18b2..bf68f6aa9c5 100644
--- a/csrc/preseg_passes/reorder_sharded_axis.cpp
+++ b/csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -28,13 +28,18 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
     if (HostIrLower::canLower(expr)) {
       continue;
     }
-    if (expr->outputs().size() > 1 || expr->inputs().size() > 1) {
-      continue;
-    }
     NVF_ERROR(
         ir_utils::isTvOp(expr),
         "Non-tv op is not supported: ",
         expr->toString());
+    NVF_ERROR(
+        expr->outputs().size() == 1,
+        "Resharding operations can only have one output: ",
+        expr->toString());
+    NVF_ERROR(
+        expr->inputs().size() == 1,
+        "Resharding operations can have only one input: ",
+        expr->toString());
     auto* output = expr->outputs().at(0)->as<TensorView>();
     auto* input = expr->inputs().at(0)->as<TensorView>();
     auto [shard_additions, shard_deletions] = getShardingChanges(input, output);

From 6b11d3399443cf6611967a0439821564e02e7015 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:26:09 -0800
Subject: [PATCH 38/40] revert switching order of passes ReorderShardedAxisPass
 and InsertReshardingsPass

---
 csrc/host_ir/lower.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index 70f5f2794ac..8374388c4d5 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -463,10 +463,10 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
   // Note: passes run before PreSegmenter optimization passes.
   preseg_passes::OptimizationPass<
       preseg_passes::PropagateShardingsPass>::runPass(fusion.get());
-  preseg_passes::OptimizationPass<
-      preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get());
   preseg_passes::OptimizationPass<
       preseg_passes::InsertReshardingsPass>::runPass(fusion.get());
+  preseg_passes::OptimizationPass<
+      preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get());
   preseg_passes::OptimizationPass<
       preseg_passes::MakeReshardingContiguousPass>::runPass(fusion.get());

From 061955f410eaa9f6856e59825e8609c55861609c Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 03:26:17 -0800
Subject: [PATCH 39/40] lint

---
 csrc/preseg_passes/insert_reshardings.cpp | 6 ++++--
 tests/cpp/test_multidevice_matmul.cpp     | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp
index c0b6424a2ad..c634616d1d6 100644
--- a/csrc/preseg_passes/insert_reshardings.cpp
+++ b/csrc/preseg_passes/insert_reshardings.cpp
@@ -33,7 +33,8 @@ void insertReshardingsBefore(Fusion* fusion) {
   // Remove this after we refactor this as a pre-segmenter pass.
   FusionGuard fg(fusion);
   for (Expr* expr : fusion->exprs()) {
-    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) || shouldReshardAfter(expr)) {
+    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) ||
+        shouldReshardAfter(expr)) {
       continue;
     }
 
@@ -85,7 +86,8 @@ void insertReshardingsAfter(Fusion* fusion) {
   auto exprs = fusion->exprs();
   for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
     Expr* expr = *it;
-    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) || !shouldReshardAfter(expr)) {
+    if (HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true) ||
+        !shouldReshardAfter(expr)) {
       continue;
     }
 
diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp
index 632ea15f051..cbc920bb01f 100644
--- a/tests/cpp/test_multidevice_matmul.cpp
+++ b/tests/cpp/test_multidevice_matmul.cpp
@@ -418,7 +418,8 @@ TEST_F(DistributedMatmulTest, AnnotateWeightOnly) {
   constexpr int64_t kLowerBound = 0;
   constexpr int64_t kUpperBound = 10;
   auto x_tensor = at::randint(kLowerBound, kUpperBound, {2, 3}, tensor_options);
-  auto w_tensor = at::randint(kLowerBound, kUpperBound, {mesh.size(), 3, 5}, tensor_options);
+  auto w_tensor = at::randint(
+      kLowerBound, kUpperBound, {mesh.size(), 3, 5}, tensor_options);
   auto sharded_w_tensor = shardTensor(w_tensor, w);
 
   FusionExecutorCache executor_cache(std::move(fusion));

From 1a8390897a6304ff4fc6432e517192ef0e012f54 Mon Sep 17 00:00:00 2001
From: snordmann
Date: Thu, 9 Jan 2025 07:44:26 -0800
Subject: [PATCH 40/40] move ignore_inner_resharding to the LHS of the bool op
 to lazily evaluate the predicate

---
 csrc/host_ir/lower.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp
index 8374388c4d5..bf2c590d613 100644
--- a/csrc/host_ir/lower.cpp
+++ b/csrc/host_ir/lower.cpp
@@ -315,7 +315,7 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
     return false;
   }
   if (auto* reduction = dynamic_cast<ReductionOp*>(expr)) {
-    if (isInnerResharding(expr) && !ignore_inner_resharding) {
+    if (!ignore_inner_resharding && isInnerResharding(expr)) {
       return false;
     }
     auto in = reduction->in()->as<TensorView>();
@@ -337,7 +337,7 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
     auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
     return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
   } else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
-    return (!isInnerResharding(ldst) || ignore_inner_resharding) &&
+    return (ignore_inner_resharding || !isInnerResharding(ldst)) &&
         ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
   } else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
     // For now we only support c = matmul(a,b) when b,c are fully replicated and
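[Editor's note: an illustrative aside, not part of the patch series, on why
PATCH 40 swaps the operands. C++ '&&' and '||' evaluate left to right and
short-circuit, so with the cheap flag on the left the isInnerResharding()
traversal is skipped whenever the flag alone decides the outcome:]

    // minimal sketch of the short-circuit, with a hypothetical counter:
    int calls = 0;
    auto expensive = [&calls]() { ++calls; return true; };
    bool r = true || expensive();  // 'expensive' never runs; calls stays 0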