Lower distributed matmul to pipelined algorithm for fine-grained overlap #3606

Open · wants to merge 20 commits into base: main · Changes from 10 commits
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -146,6 +146,7 @@ class Val;
f(HostUnit); \
f(PostOnStream); \
f(SetCurrentStream); \
f(GetCurrentStream); \
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
31 changes: 29 additions & 2 deletions csrc/host_ir/executor.cpp
@@ -188,7 +188,8 @@ HostIrEvaluator::HostIrEvaluator(
HostIrEvaluatorParams params)
: container_(std::move(container)),
communicator_(communicator),
params_(params) {
params_(params),
my_device_index_(communicator_ ? communicator_->deviceId() : 0) {
const DeviceIdxType device_index =
(communicator_ != nullptr && communicator_->is_available())
? communicator_->deviceId()
@@ -200,6 +201,7 @@ HostIrEvaluator::HostIrEvaluator(
{container_->getDefaultStream(),
c10::cuda::getDefaultCUDAStream(
static_cast<c10::DeviceIndex>(device_index))});
expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
}

std::vector<at::Tensor> HostIrEvaluator::runWithInput(
@@ -215,6 +217,12 @@ std::vector<at::Tensor> HostIrEvaluator::runWithInput(
dispatch(expr);
}

c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))
.synchronize();
for (auto event : events_) {
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event));
}
// Collect global outputs
return getKnownTensorOrUndefined(container_->outputs(), expr_evaluator_);
}
@@ -273,8 +281,27 @@ void HostIrEvaluator::handle(SetCurrentStream* set_current_stream) {
setCurrentCUDAStream(getCUDAStream(set_current_stream->stream()));
}

void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
streams_.insert(
{get_current_stream->stream(),
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))});
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
getCUDAStream(synchronize->stream()).synchronize();
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))
.stream();
cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream();

cudaEvent_t event = {};
NVFUSER_CUDA_RT_SAFE_CALL(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventRecord(event, stream_to_sync));
NVFUSER_CUDA_RT_SAFE_CALL(
cudaStreamWaitEvent(current_stream, event, cudaEventWaitDefault));
NVFUSER_CUDA_RT_SAFE_CALL(cudaEventDestroy(event));
}

void HostIrEvaluator::handle(PostOnStream* post_ir) {
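The new Synchronize handler no longer blocks the host thread: it records an event on the stream to be waited on and makes the current stream wait on that event, so the ordering is enforced on the device while the host keeps issuing work. Below is a minimal standalone sketch of that pattern using the plain CUDA runtime API; the helper name is illustrative and error checking is omitted (nvFuser wraps these calls in NVFUSER_CUDA_RT_SAFE_CALL).

#include <cuda_runtime.h>

// Make `current` wait for all work already submitted to `producer`,
// without blocking the calling host thread.
void streamWaitOnStream(cudaStream_t current, cudaStream_t producer) {
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, producer); // capture the current tail of `producer`
  cudaStreamWaitEvent(current, event, cudaEventWaitDefault); // device-side wait
  cudaEventDestroy(event); // destruction is deferred until the event completes
}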
3 changes: 3 additions & 0 deletions csrc/host_ir/executor.h
@@ -74,6 +74,7 @@ struct HostIrEvaluatorParams {
// Experimental: whether to cache fusion executor. WAR: avoid recompilation
// but implicitly assumes that the input shapes don't change over iterations
bool cache_fusion_executor = false;
int64_t number_of_streams = 4;
};

class HostIrEvaluator final : public OptOutDispatch {
@@ -112,6 +113,7 @@ class HostIrEvaluator final : public OptOutDispatch {
private:
using OptOutDispatch::handle;
void handle(SetCurrentStream* set_current_stream) override;
void handle(GetCurrentStream* get_current_stream) override;
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
@@ -138,6 +140,7 @@
using StreamKey = std::variant<int64_t, Stream*>;
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_device_index_;
};

} // namespace hir
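The new number_of_streams parameter is consumed by the evaluator constructor shown above, which binds it to the "numberOfStreams" named scalar that the lowered pipelined loop references. A hedged usage sketch (the container and communicator are assumed to already exist; namespaces follow the nvfuser::hir convention used in this file):

hir::HostIrEvaluatorParams params;
params.number_of_streams = 8; // size of the round-robin stream pool used by the pipelined loop
hir::HostIrEvaluator evaluator(std::move(host_ir_container), &communicator, params);

At evaluation time, the expression j % numberOfStreams in the lowered for-loop then cycles over that many CUDA streams.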
26 changes: 26 additions & 0 deletions csrc/host_ir/host_ir.cpp
@@ -179,6 +179,32 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

GetCurrentStream::GetCurrentStream(IrBuilderPasskey passkey) : Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
auto stream = IrBuilder::createInContainer<Stream>(passkey.ir_container_);
addAttribute(stream);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(GetCurrentStream)

std::string GetCurrentStream::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "GetCurrentStream into " << stream()->toString()
<< std::endl;
return ss.str();
}

// TODO: implement better ?
std::string GetCurrentStream::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

// TODO: implement
bool GetCurrentStream::sameAs(const Statement* other) const {
return false;
}

Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
27 changes: 27 additions & 0 deletions csrc/host_ir/host_ir.h
@@ -161,6 +161,31 @@ class SetCurrentStream : public Expr {
}
};

class GetCurrentStream : public Expr {
public:
using Expr::Expr;
GetCurrentStream(IrBuilderPasskey passkey);

GetCurrentStream(const GetCurrentStream& other) = delete;
GetCurrentStream& operator=(const GetCurrentStream& other) = delete;
GetCurrentStream(GetCurrentStream&& other) = delete;
GetCurrentStream& operator=(GetCurrentStream&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::GetCurrentStream";
}

bool sameAs(const Statement* other) const override;

Stream* stream() const {
return attributes_.at(0)->as<Stream>();
}
};

class Wait : public Expr {
public:
using Expr::Expr;
@@ -186,6 +211,8 @@
}
};

// Makes the current stream wait on the given stream. Non-blocking from the host
// point of view.
class Synchronize : public Expr {
public:
using Expr::Expr;
175 changes: 159 additions & 16 deletions csrc/host_ir/lower.cpp
@@ -14,6 +14,7 @@
#include <multidevice/device_mesh.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
#include <ops/utils.h>
#include <preseg_passes/insert_reshardings.h>
#include <preseg_passes/make_resharding_contiguous.h>
#include <preseg_passes/propagate_shardings.h>
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
std::vector<Expr*> HostIrLower::lower(Expr* c) {
FusionGuard fg(c->fusion());

if (c->isA<MatmulOp>()) {
return lowerToCollectiveBasedPipelinedGemmComm(c);
}

std::vector<Expr*> comms;
NVF_ERROR(
c->inputs().size() == 1 && c->input(0)->isA<TensorView>() &&
@@ -310,6 +315,9 @@ bool HostIrLower::canLower(Expr* expr) {
return false;
}
if (expr->isA<ReductionOp>()) {
if (!isInnerResharding(expr)) {
return false;
}
auto in = expr->as<ReductionOp>()->in()->as<TensorView>();
auto out = expr->as<ReductionOp>()->out()->as<TensorView>();
// get the reduced axis
@@ -328,10 +336,147 @@
PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
} else {
return expr->isA<LoadStoreOp>() &&
(expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set);
} else if (expr->isA<LoadStoreOp>()) {
return isInnerResharding(expr) &&
expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
} else if (expr->isA<MatmulOp>()) {
// For now we only support c = matmul(a,b) when b,c are fully replicated and
// a is sharded on axis 1
auto* matmul = expr->as<MatmulOp>();
return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1;
}
return false;
}
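To illustrate the condition above, here is a hedged sketch of a sharding that this branch accepts. The shapes and names are illustrative: a with logical shape [S, D, M, K] (axis 1 of extent D is the device dimension), b with shape [K, N], and c = matmul(a, b) with shape [S, D, M, N] are assumed to already be built, and D is the number of devices.

// Hypothetical sharding satisfying the MatmulOp branch of canLower.
const DeviceMesh mesh = DeviceMesh::createForNumDevices(D);
for (TensorView* tv : {a, b, c}) {
  tv->setDeviceMesh(mesh);
}
a->axis(1)->parallelize(ParallelType::DIDx); // A is sharded on logical axis 1
// a->axis(0) stays ParallelType::Serial: it becomes the stream (pipeline) axis.
// b and c have no DIDx-parallel axis, i.e. they are fully replicated.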

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
Expr* expr) {
auto matmul = expr->as<MatmulOp>();
NVF_ERROR(matmul != nullptr, "Expected a MatmulOp, got ", expr);
TensorView* tva = matmul->inA();
TensorView* tvb = matmul->inB();
TensorView* tvc = matmul->out();
NVF_ERROR(
!isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
NVF_ERROR(
!isSharded(tvc),
"The output ",
matmul->out(),
" is expected to be sharded");
const int64_t sharded_axis_index =
getShardedLogicalAxis(tva, ParallelType::DIDx);
IterDomain* stream_axis = tva->axis(0);
NVF_ERROR(
stream_axis->getParallelType() == ParallelType::Serial &&
sharded_axis_index == 1,
"The operand A ",
tva,
" is expected to be sharded on the dimension 1");

auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();

auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();

TensorView* tva_allgathered =
ops::newValLike(tva, tva->dtype())->as<TensorView>();
tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
tva_allgathered->setMemoryType(MemoryType::Global);
auto* allocate_tva_allgathered =
IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);

tvc->setMemoryType(MemoryType::Global);
auto* allocate_tvc =
IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);

auto* j =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start = hic->zeroVal();
auto* stop = stream_axis->extent();
auto* step = hic->oneVal();
auto* for_loop = IrBuilder::create<ForLoop>(
stream_axis,
/*index=*/j,
start,
stop,
step,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);

auto* number_of_streams =
IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
auto* stream_index = mod(j, number_of_streams);
auto* stream = IrBuilder::create<hir::Stream>(stream_index);
auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);

TensorView* tva_j = select(tva, 0, j);
TensorView* tva_j_unsqueezed = tva_j; // unsqueeze(tva_j, 0);
TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
TensorView* tvc_j = select(tvc, 0, j);

// [TAG: adding artificial outputs]
// The following line is artificial but necessary to make tva_j_unsqueezed a
// consumer of tva_j.
//
// HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all
// **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a
// bit special among all transitive consumers of `j`. It doesn't use `j`
// directly but uses `tva_j` which is a TensorView. TensorView's uses are
// built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses
// only fixes TensorViews that can reach outputs. Therefore, we add
// tva_j_unsqueezed as an output. Other TensorViews don't need this
// treatment because they are direct users of `j`, a scalar whose uses are
// built eagerly upon registration.
//
// We could have added `tvc_j` instead as an output, which transitively
// consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a Select
// and a MatmulOp, and StmtSort::getExprs only traverses via the first
// registered definition (i.e. the Select). This sounds like a bug -- I wonder
// how nvFuser resets the TensorView uses of a kir::Kernel, also non-SSA.
hic->addOutput(tva_j_unsqueezed);

NVF_ERROR(
tva->hasDeviceMesh(),
"The matmul's input ",
tva,
"is expected to have a DeviceMesh");
for (auto tv : {tva_j, tva_allgathered_j, tva_j_unsqueezed, tvc_j}) {
tv->setDeviceMesh(tva->getDeviceMesh());
}

auto* communication = IrBuilder::create<Communication>(
CommunicationType::Allgather,
/*out=*/tva_allgathered_j,
/*in=*/tva_j_unsqueezed,
/*team=*/tva->getDeviceMesh().vector());
auto* wait = IrBuilder::create<hir::Wait>(communication);

auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);

auto* set_back_original_stream =
IrBuilder::create<hir::SetCurrentStream>(original_stream);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);

std::vector<Expr*> loop_body = {
set_stream,
tva_j->definition(),
tva_j_unsqueezed->definition(),
tva_allgathered_j->definition(),
communication,
wait,
tvc_j->definition(),
mm,
set_back_original_stream,
sync_stream};
for (Expr* expr : loop_body) {
for_loop->body().push_back(expr);
}

return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
}
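For orientation, the host program emitted by this lowering corresponds roughly to the pipelined loop below. This is a hedged pseudocode sketch of the top-level exprs, not the literal Host IR printout (tva_j_unsqueezed is omitted since it currently aliases tva_j):

original_stream = GetCurrentStream()
allocate tva_allgathered and tvc in global memory
for (j = 0; j < stream_axis->extent(); ++j) {
  SetCurrentStream(streams[j % numberOfStreams])
  tva_j             = select(tva, /*dim=*/0, j)            // this rank's slice of A for chunk j
  tva_allgathered_j = select(tva_allgathered, /*dim=*/0, j)
  Allgather(in=tva_j, out=tva_allgathered_j); Wait          // unshard A for chunk j
  tvc_j = select(tvc, /*dim=*/0, j)
  tvc_j = MatmulOp(tva_allgathered_j, tvb)                  // compute on the same worker stream
  SetCurrentStream(original_stream)
  Synchronize(streams[j % numberOfStreams])                 // event-based, non-blocking
}

Because consecutive chunks round-robin over numberOfStreams worker streams, the allgather of chunk j+1 can start while the matmul of chunk j is still running, which is the fine-grained overlap the PR title refers to; the trailing Synchronize only inserts a device-side dependency back onto the original stream, so the host loop itself never blocks.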

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
@@ -396,21 +541,19 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
"Communication segments must contain only one Expr");
for (auto* expr :
HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
hic->pushBackTopLevelExprs(expr);
// Allocate the recv buffers of communications
NVF_ERROR(
expr->isA<Communication>(),
"Expected a Communication but got ",
expr);
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
}
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
hic->pushBackTopLevelExprs(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
} else {
auto host_unit = IrBuilder::create<hir::HostUnit>(
3 changes: 3 additions & 0 deletions csrc/host_ir/lower.h
@@ -24,6 +24,9 @@ class HostIrLower {
static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
};

} // namespace nvfuser