Lower distributed matmul to pipelined algorithm for fine-grained overlap #3606

Open · wants to merge 20 commits into base: main
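At a high level, this PR lowers a sharded MatmulOp into a stream-pipelined host loop that overlaps the allgather of each slice of A with the matmuls of the other slices. A rough sketch of the emitted host program, reconstructed from lowerToCollectiveBasedPipelinedGemmComm in csrc/host_ir/lower.cpp below (pseudocode, names are illustrative):

// c = matmul(a, b), with a sharded as [S, DIDx(D), M/(S*D), K] and b, c replicated
original_stream = GetCurrentStream()
allocate a_allgathered, c                        // global-memory buffers
for (j = 0; j < S; ++j) {
  SetCurrentStream(stream[j % numberOfStreams])  // round-robin over the stream pool
  a_j             = select(a, /*dim=*/0, j)
  a_allgathered_j = select(a_allgathered, /*dim=*/0, j)
  Allgather(a_allgathered_j <- a_j)              // communication for slice j
  Wait(allgather)                                // before the matmul reads the buffer
  c_j = select(c, /*dim=*/0, j)
  c_j = matmul(a_allgathered_j, b)               // compute for slice j
  SetCurrentStream(original_stream)
  Synchronize(stream[j % numberOfStreams])       // original stream waits; host does not block
}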
1 change: 1 addition & 0 deletions csrc/host_ir/executor.cpp
@@ -201,6 +201,7 @@ HostIrEvaluator::HostIrEvaluator(
{container_->getDefaultStream(),
c10::cuda::getDefaultCUDAStream(
static_cast<c10::DeviceIndex>(device_index))});
expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
}

std::vector<at::Tensor> HostIrEvaluator::runWithInput(
3 changes: 3 additions & 0 deletions csrc/host_ir/executor.h
@@ -74,6 +74,9 @@ struct HostIrEvaluatorParams {
// Experimental: whether to cache fusion executor. WAR: avoid recompilation
// but implicitly assumes that the input shapes don't change over iterations
bool cache_fusion_executor = false;
// Number of additional CUDA streams to use at runtime for comm+compute
// pipelining
int64_t number_of_streams = 4;
};

class HostIrEvaluator final : public OptOutDispatch {
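The new number_of_streams parameter is bound to the NamedScalar "numberOfStreams" by HostIrEvaluator (see the executor.cpp hunk above) and consumed as mod(j, numberOfStreams) in the lowered loop, so iterations round-robin over a small pool of extra streams. A standalone illustration of that mapping, assuming the default of 4 streams and a stream axis of extent 8 as in the added AG_matmul test:

#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t number_of_streams = 4; // default in HostIrEvaluatorParams
  constexpr int64_t S = 8;                 // stream-axis extent in the test
  for (int64_t j = 0; j < S; ++j) {
    // Same index computation as mod(j, numberOfStreams) in csrc/host_ir/lower.cpp.
    std::cout << "iteration " << j << " runs on stream " << j % number_of_streams << "\n";
  }
  return 0;
}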
2 changes: 2 additions & 0 deletions csrc/host_ir/host_ir.h
@@ -208,6 +208,8 @@ class Wait : public Expr {
}
};

// Makes the current stream wait on the given stream. Non-blocking from the
// host's point of view.
class Synchronize : public Expr {
public:
using Expr::Expr;
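For context, the semantics documented above (the current stream waits on the given stream without blocking the host) are what one would typically implement with a CUDA event. The actual handling of hir::Synchronize lives in the host IR evaluator and is not part of this diff, so the snippet below is only an assumed illustration of the pattern:

#include <cuda_runtime.h>

// Illustration only: make `current` wait for all work submitted to `other`
// so far, without blocking the calling host thread.
void streamWaitStream(cudaStream_t current, cudaStream_t other) {
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, other);          // capture the current tail of `other`
  cudaStreamWaitEvent(current, event, 0); // device-side dependency; host returns immediately
  cudaEventDestroy(event);                // resources are freed once the event completes
}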
152 changes: 136 additions & 16 deletions csrc/host_ir/lower.cpp
@@ -14,6 +14,7 @@
#include <multidevice/device_mesh.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
#include <ops/utils.h>
#include <preseg_passes/insert_reshardings.h>
#include <preseg_passes/make_resharding_contiguous.h>
#include <preseg_passes/propagate_shardings.h>
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
std::vector<Expr*> HostIrLower::lower(Expr* c) {
FusionGuard fg(c->fusion());

if (c->isA<MatmulOp>()) {
return lowerToCollectiveBasedPipelinedGemmComm(c);
}

std::vector<Expr*> comms;
NVF_ERROR(
c->inputs().size() == 1 && c->input(0)->isA<TensorView>() &&
@@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) {
return false;
}
if (expr->isA<ReductionOp>()) {
if (!isInnerResharding(expr)) {
return false;
}
auto in = expr->as<ReductionOp>()->in()->as<TensorView>();
auto out = expr->as<ReductionOp>()->out()->as<TensorView>();
// get the reduced axis
@@ -328,10 +336,124 @@ bool HostIrLower::canLower(Expr* expr) {
PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
} else {
return expr->isA<LoadStoreOp>() &&
(expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set);
} else if (expr->isA<LoadStoreOp>()) {
return isInnerResharding(expr) &&
expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
} else if (expr->isA<MatmulOp>()) {
// For now, we only support c = matmul(a, b) when b and c are fully
// replicated and a is sharded on axis 1
auto* matmul = expr->as<MatmulOp>();
return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1;
}
return false;
}

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
Expr* expr) {
auto matmul = expr->as<MatmulOp>();
NVF_ERROR(matmul != nullptr, "Expected a MatmulOp, got ", expr);
TensorView* tva = matmul->inA();
TensorView* tvb = matmul->inB();
TensorView* tvc = matmul->out();
NVF_ERROR(
!isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
NVF_ERROR(
!isSharded(tvc),
"The output ",
matmul->out(),
" is expected to not be sharded");
const int64_t sharded_axis_index =
getShardedLogicalAxis(tva, ParallelType::DIDx);
IterDomain* stream_axis = tva->axis(0);
NVF_ERROR(
stream_axis->getParallelType() == ParallelType::Serial &&
sharded_axis_index == 1,
"The operand A ",
tva,
" is expected to be sharded on the dimension 1");

auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();

auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();

TensorView* tva_allgathered =
ops::newValLike(tva, tva->dtype())->as<TensorView>();
tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
tva_allgathered->setMemoryType(MemoryType::Global);
auto* allocate_tva_allgathered =
IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);

tvc->setMemoryType(MemoryType::Global);
auto* allocate_tvc =
IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);

auto* j =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start = hic->zeroVal();
auto* stop = stream_axis->extent();
auto* step = hic->oneVal();
auto* for_loop = IrBuilder::create<ForLoop>(
stream_axis,
/*index=*/j,
start,
stop,
step,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);

auto* number_of_streams =
IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
auto* stream_index = mod(j, number_of_streams);
auto* stream = IrBuilder::create<hir::Stream>(stream_index);
auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);

TensorView* tva_j = select(tva, 0, j);
TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
TensorView* tvc_j = select(tvc, 0, j);

NVF_ERROR(
tva->hasDeviceMesh(),
"The matmul's input ",
tva,
"is expected to have a DeviceMesh");
for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) {
tv->setDeviceMesh(tva->getDeviceMesh());
}

auto* communication = IrBuilder::create<Communication>(
CommunicationType::Allgather,
/*out=*/tva_allgathered_j,
/*in=*/tva_j,
/*team=*/tva->getDeviceMesh().vector());
auto* wait = IrBuilder::create<hir::Wait>(communication);

auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);

auto* set_back_original_stream =
IrBuilder::create<hir::SetCurrentStream>(original_stream);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);

std::vector<Expr*> loop_body = {
set_stream,
tva_j->definition(),
tva_allgathered_j->definition(),
communication,
wait,
tvc_j->definition(),
mm,
set_back_original_stream,
sync_stream};
for (Expr* expr : loop_body) {
for_loop->body().push_back(expr);
}

return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
}

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
@@ -396,21 +518,19 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
"Communication segments must contain only one Expr");
for (auto* expr :
HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
hic->pushBackTopLevelExprs(expr);
// Allocate the recv buffers of communications
NVF_ERROR(
expr->isA<Communication>(),
"Expected a Communication but got ",
expr);
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
}
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
hic->pushBackTopLevelExprs(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
} else {
auto host_unit = IrBuilder::create<hir::HostUnit>(
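For intuition about the loop body above: each select(tv, 0, j) picks the j-th stream-axis slice, so every iteration produces one [D, M/(S*D), N] tile of the output from the allgathered [D, M/(S*D), K] slice of A. A minimal single-process ATen sketch of those shapes (sizes scaled down from the test's M = K = 32768, N = 1024; a_gathered_j here plays the role of tva_allgathered_j):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  constexpr int64_t S = 8, D = 8, M = 512, K = 256, N = 64; // scaled-down sizes
  at::Tensor a = at::randn({S, D, M / (S * D), K}); // unsharded view of operand A
  at::Tensor b = at::randn({K, N});
  at::Tensor a_gathered_j = a.select(/*dim=*/0, /*index=*/3); // [D, M/(S*D), K]
  at::Tensor c_j = at::matmul(a_gathered_j, b);               // [D, M/(S*D), N]
  std::cout << c_j.size(0) << " " << c_j.size(1) << " " << c_j.size(2) << "\n"; // 8 8 64
  return 0;
}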
3 changes: 3 additions & 0 deletions csrc/host_ir/lower.h
@@ -24,6 +24,9 @@ class HostIrLower {
static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
};

} // namespace nvfuser
2 changes: 1 addition & 1 deletion csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -25,7 +25,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
const std::vector<Expr*>& exprs = fusion->exprs();
for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
Expr* expr = *it;
if (!isResharding(expr)) {
if (HostIrLower::canLower(expr)) {
continue;
}
NVF_ERROR(
55 changes: 55 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <cuda_profiler_api.h>
#include <fusion.h>
#include <host_ir/container.h>
#include <host_ir/executor.h>
@@ -349,6 +350,60 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) {
EXPECT_TRUE(torch::allclose(ref_output, outputs.back()));
}

using OverlapDistributedMatmulTest = MultiDeviceTest;

TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
constexpr int64_t M = 32768;
constexpr int64_t K = 32768;
constexpr int64_t N = 1024;
constexpr int64_t S = 8;
const int64_t D = communicator_->size();
ASSERT_EQ(M % (D * S), 0);

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* a = makeContigTensor(4); // [S, DIDx(D), M/(S*D), K]
TensorView* b = makeContigTensor(2); // [K, N]
TensorView* c = matmul(a, b); // [S, D, M/(S*D), N]

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

auto mesh = DeviceMesh::createForNumDevices(D);
a->setDeviceMesh(mesh);
b->setDeviceMesh(mesh);
c->setDeviceMesh(mesh);

a->axis(1)->parallelize(ParallelType::DIDx);

MultiDeviceExecutor executor(std::move(fusion), *communicator_);

auto tensor_options =
at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options);
at::Tensor ta = ta_unsharded.slice(
1, communicator_->deviceId(), communicator_->deviceId() + 1);
at::Tensor tb = at::randn({K, N}, tensor_options);
at::Tensor tc_ref = at::matmul(ta_unsharded, tb);

std::vector<c10::IValue> inputs = {ta, tb};
at::Tensor tc;

constexpr int64_t number_of_iterations = 20;
constexpr int64_t number_of_warmup_iterations = 5;
for (const auto& i : c10::irange(number_of_iterations)) {
if (i == number_of_warmup_iterations) {
cudaProfilerStart();
}
tc = executor.runWithInput(inputs).at(0);
}
cudaProfilerStop();

EXPECT_TRUE(torch::allclose(tc_ref, tc));
}

} // namespace hir

} // namespace nvfuser