lower to collective-based pipelined AG+GEMM
samnordmann committed Dec 18, 2024
1 parent 38721fe commit bb867e8
Showing 9 changed files with 226 additions and 21 deletions.
1 change: 1 addition & 0 deletions csrc/host_ir/executor.cpp
@@ -200,6 +200,7 @@ HostIrEvaluator::HostIrEvaluator(
{container_->getDefaultStream(),
c10::cuda::getDefaultCUDAStream(
static_cast<c10::DeviceIndex>(device_index))});
expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
}

std::vector<at::Tensor> HostIrEvaluator::runWithInput(
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
@@ -74,6 +74,7 @@ struct HostIrEvaluatorParams {
// Experimental: whether to cache the fusion executor. WAR: avoids
// recompilation but implicitly assumes that the input shapes don't change
// over iterations.
bool cache_fusion_executor = false;
int64_t number_of_streams = 4;
};

class HostIrEvaluator final : public OptOutDispatch {
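Note: the new number_of_streams knob controls how many CUDA streams the collective-based pipelined lowering round-robins over. A minimal usage sketch (`fusion` and `communicator` stand for an already-built std::unique_ptr<Fusion> and a Communicator&, using the MultiDeviceExecutor constructor shown in csrc/multidevice/executor.cpp below):

  hir::HostIrEvaluatorParams params;
  params.number_of_streams = 8; // bound to the "numberOfStreams" NamedScalar
  MultiDeviceExecutor executor(std::move(fusion), communicator, params);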
175 changes: 159 additions & 16 deletions csrc/host_ir/lower.cpp
@@ -14,6 +14,7 @@
#include <multidevice/device_mesh.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
#include <ops/utils.h>
#include <preseg_passes/insert_reshardings.h>
#include <preseg_passes/make_resharding_contiguous.h>
#include <preseg_passes/propagate_shardings.h>
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
std::vector<Expr*> HostIrLower::lower(Expr* c) {
FusionGuard fg(c->fusion());

if (c->isA<MatmulOp>()) {
return lowerToCollectiveBasedPipelinedGemmComm(c);
}

std::vector<Expr*> comms;
NVF_ERROR(
c->inputs().size() == 1 && c->input(0)->isA<TensorView>() &&
@@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) {
return false;
}
if (expr->isA<ReductionOp>()) {
if (!isInnerResharding(expr)) {
return false;
}
auto in = expr->as<ReductionOp>()->in()->as<TensorView>();
auto out = expr->as<ReductionOp>()->out()->as<TensorView>();
// get the reduced axis
@@ -328,10 +336,147 @@ bool HostIrLower::canLower(Expr* expr) {
PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
-  } else {
-    return expr->isA<LoadStoreOp>() &&
-        (expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set);
} else if (expr->isA<LoadStoreOp>()) {
return isInnerResharding(expr) &&
expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
} else if (expr->isA<MatmulOp>()) {
// For now we only support c = matmul(a,b) when b and c are fully replicated
// and a is sharded on axis 1.
auto* matmul = expr->as<MatmulOp>();
return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1;
}
return false;
}
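
For illustration, a minimal fusion that satisfies this MatmulOp predicate (mirroring the AG_matmul test added below; `mesh` is assumed to be a DeviceMesh over all devices, e.g. DeviceMesh::createForNumDevices(D)):

  TensorView* a = makeContigTensor(4); // [S(stream), DIDx(D), M/(S*D), K]
  TensorView* b = makeContigTensor(2); // [K, N], fully replicated
  TensorView* c = matmul(a, b); // [S, D, M/(S*D), N], fully replicated
  a->setDeviceMesh(mesh);
  b->setDeviceMesh(mesh);
  c->setDeviceMesh(mesh);
  a->axis(1)->parallelize(ParallelType::DIDx); // axis 0 stays Serial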

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
Expr* expr) {
auto matmul = expr->as<MatmulOp>();
NVF_ERROR(matmul != nullptr, "Expected a MatmulOp, got ", expr);
TensorView* tva = matmul->inA();
TensorView* tvb = matmul->inB();
TensorView* tvc = matmul->out();
NVF_ERROR(
    !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
NVF_ERROR(
    !isSharded(tvc),
    "The output ",
    matmul->out(),
    " is expected to not be sharded");
const int64_t sharded_axis_index =
getShardedLogicalAxis(tva, ParallelType::DIDx);
IterDomain* stream_axis = tva->axis(0);
NVF_ERROR(
stream_axis->getParallelType() == ParallelType::Serial &&
sharded_axis_index == 1,
"The operand A ",
tva,
" is expected to be sharded on the dimension 1");

auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();

auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();

TensorView* tva_allgathered =
ops::newValLike(tva, tva->dtype())->as<TensorView>();
tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
tva_allgathered->setMemoryType(MemoryType::Global);
auto* allocate_tva_allgathered =
IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);

tvc->setMemoryType(MemoryType::Global);
auto* allocate_tvc =
IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);

auto* j =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start = hic->zeroVal();
auto* stop = stream_axis->extent();
auto* step = hic->oneVal();
auto* for_loop = IrBuilder::create<ForLoop>(
stream_axis,
/*index=*/j,
start,
stop,
step,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);

auto* number_of_streams =
IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
auto* stream_index = mod(j, number_of_streams);
auto* stream = IrBuilder::create<hir::Stream>(stream_index);
auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);

TensorView* tva_j = select(tva, 0, j);
TensorView* tva_j_unsqueezed = unsqueeze(tva_j, 0);
TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
TensorView* tvc_j = select(tvc, 0, j);

// [TAG: adding artificial outputs]
// The following line is artificial but necessary to make tva_j_unsqueezed a
// consumer of tva_j.
//
// HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all
// **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a
// bit special among all transitive consumers of `j`. It doesn't use `j`
// directly but uses `tva_j` which is a TensorView. TensorView's uses are
// built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses
only fixes TensorViews that can reach outputs. Therefore, we add
tva_j_unsqueezed as an output. Other TensorViews don't need this
treatment because they are direct users of `j`, a scalar whose uses are
// built eagerly upon registration.
//
// We could have added `tvc_j` instead as an output, which transitively
// consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a Select
and a MatmulOp, and StmtSort::getExprs only traverses via the first
registered definition (i.e. the Select). This sounds like a bug -- I wonder
// how nvFuser resets the TensorView uses of a kir::Kernel, also non-SSA.
hic->addOutput(tva_j_unsqueezed);

NVF_ERROR(
tva->hasDeviceMesh(),
"The matmul's input ",
tva,
"is expected to have a DeviceMesh");
for (auto tv : {tva_j, tva_allgathered_j, tva_j_unsqueezed, tvc_j}) {
tv->setDeviceMesh(tva->getDeviceMesh());
}

auto* communication = IrBuilder::create<Communication>(
CommunicationType::Allgather,
/*out=*/tva_allgathered_j,
/*in=*/tva_j_unsqueezed,
/*team=*/tva->getDeviceMesh().vector());
auto* wait = IrBuilder::create<hir::Wait>(communication);

auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);

auto* set_back_original_stream =
IrBuilder::create<hir::SetCurrentStream>(original_stream);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);

std::vector<Expr*> loop_body = {
set_stream,
tva_j->definition(),
tva_j_unsqueezed->definition(),
tva_allgathered_j->definition(),
communication,
wait,
tvc_j->definition(),
mm,
set_back_original_stream,
sync_stream};
for (Expr* expr : loop_body) {
for_loop->body().push_back(expr);
}

return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
}
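
For reference, the emitted top-level host program has roughly this shape (a pseudocode sketch, not literal host IR):

  // stream0 = getCurrentStream();
  // allocate tva_allgathered and tvc in global memory;
  // for (j = 0; j < stream_axis->extent(); ++j) {
  //   setCurrentStream(streams[j % numberOfStreams]);
  //   tva_j = select(tva, /*dim=*/0, j);
  //   tva_j_unsqueezed = unsqueeze(tva_j, 0);
  //   tva_allgathered_j = select(tva_allgathered, 0, j);
  //   allgather(tva_allgathered_j <- tva_j_unsqueezed); wait;
  //   tvc_j = select(tvc, 0, j);
  //   tvc_j = matmul(tva_allgathered_j, tvb);
  //   setCurrentStream(stream0);
  //   synchronize(streams[j % numberOfStreams]);
  // }

Because consecutive values of j map to different streams, the allgather for one slice can overlap with the matmuls of previously issued slices.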

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
@@ -396,21 +541,19 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
"Communication segments must contain only one Expr");
for (auto* expr :
HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
hic->pushBackTopLevelExprs(expr);
// Allocate the recv buffers of communications
NVF_ERROR(
expr->isA<Communication>(),
"Expected a Communication but got ",
expr);
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
}
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
hic->pushBackTopLevelExprs(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
} else {
auto host_unit = IrBuilder::create<hir::HostUnit>(
3 changes: 3 additions & 0 deletions csrc/host_ir/lower.h
@@ -24,6 +24,9 @@ class HostIrLower {
static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
};

} // namespace nvfuser
6 changes: 4 additions & 2 deletions csrc/multidevice/executor.cpp
@@ -24,7 +24,7 @@ MultiDeviceExecutor::MultiDeviceExecutor(
std::unique_ptr<Fusion> fusion,
Communicator& comm,
hir::HostIrEvaluatorParams params)
: comm_(comm) {
: comm_(comm), number_of_outputs_(fusion->outputs().size()) {
std::unique_ptr<hir::HostIrContainer> hic =
HostIrLower::lower(std::move(fusion), comm.deviceId());
// Create the HostIrEvaluator representing the host program
@@ -52,7 +52,9 @@ std::vector<at::Tensor> MultiDeviceExecutor::runWithInput(
inputs.at(input_idx);
}

return host_ir_executor_->runWithInput(val_to_IValue);
auto outputs = host_ir_executor_->runWithInput(val_to_IValue);
return std::vector<at::Tensor>(
outputs.end() - number_of_outputs_, outputs.end());
}

std::ostream& MultiDeviceExecutor::print(std::ostream& os) {
6 changes: 6 additions & 0 deletions csrc/multidevice/executor.h
@@ -102,6 +102,12 @@ class MultiDeviceExecutor {
Communicator& comm_;
// holds the HostIrEvaluator used for execution
std::unique_ptr<hir::HostIrEvaluator> host_ir_executor_;
// Store the number of outputs before it possibly gets artificially modified
// by HostIrLower::lower. This is undesirable but required for now. For more
// details, search for the comment in host_ir/lower.cpp tagged with "[TAG:
// adding artificial outputs]".
// TODO: fix
int64_t number_of_outputs_;
};

} // namespace nvfuser
9 changes: 7 additions & 2 deletions csrc/ops/alias.cpp
@@ -8,6 +8,7 @@
#include <expr_simplifier.h>
#include <ir/builder.h>
#include <ir/utils.h>
#include <multidevice/utils.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <ops/utils.h>
@@ -947,8 +948,12 @@ TensorView* broadcast(
.iter_type(IterType::Broadcast)
.build());
} else {
-    out_domain.push_back(
-        IterDomainBuilder(inp_domain[iinp]).resetSchedulingParams().build());
auto inp_id = inp_domain[iinp];
auto out_id = IterDomainBuilder(inp_id).resetSchedulingParams().build();
if (inp_id->isDeviceDim()) {
out_id->parallelize(inp_id->getParallelType());
}
out_domain.push_back(out_id);
iinp++;
}
ibdim++;
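A small sketch of the behavioral change (hypothetical shapes): previously the IterDomainBuilder(...).resetSchedulingParams() call reset the input IterDomain's parallel type, so broadcasting a DIDx-sharded tensor silently lost its sharding; with the explicit re-parallelization added here, the device dimension is carried over:

  TensorView* in = makeContigTensor(2); // [DIDx(D), K]
  in->axis(0)->parallelize(ParallelType::DIDx);
  TensorView* out = broadcast(in, {false, false, true}); // [DIDx(D), K, b1]
  // out->axis(0)->getParallelType() == ParallelType::DIDx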
2 changes: 1 addition & 1 deletion csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -25,7 +25,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
const std::vector<Expr*>& exprs = fusion->exprs();
for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
Expr* expr = *it;
-    if (!isResharding(expr)) {
if (HostIrLower::canLower(expr)) {
continue;
}
NVF_ERROR(
44 changes: 44 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -349,6 +349,50 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) {
EXPECT_TRUE(torch::allclose(ref_output, outputs.back()));
}

using OverlapDistributedMatmulTest = MultiDeviceTest;

TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
constexpr int64_t M = 1024;
constexpr int64_t K = 256;
constexpr int64_t N = 512;
constexpr int64_t S = 8;
const int64_t D = communicator_->size();
ASSERT_EQ(M % (D * S), 0);

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* a = makeContigTensor(4); // [S, DIDx(D), M/(S*D), K]
TensorView* b = makeContigTensor(2); // [K, N]
TensorView* c = matmul(a, b); // [S, D, M/(S*D), N]

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

auto mesh = DeviceMesh::createForNumDevices(D);
a->setDeviceMesh(mesh);
b->setDeviceMesh(mesh);
c->setDeviceMesh(mesh);

a->axis(1)->parallelize(ParallelType::DIDx);

MultiDeviceExecutor executor(std::move(fusion), *communicator_);

auto tensor_options =
at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options);
at::Tensor ta = ta_unsharded.slice(
1, communicator_->deviceId(), communicator_->deviceId() + 1);
at::Tensor tb = at::randn({K, N}, tensor_options);
at::Tensor tc_ref = at::matmul(ta_unsharded, tb);

std::vector<c10::IValue> inputs = {ta, tb};
auto tc = executor.runWithInput(inputs).at(0);

EXPECT_TRUE(torch::allclose(tc_ref, tc));
}

} // namespace hir

} // namespace nvfuser
