Harden assertBuffersHaveSameSize to check shapes. (#3531)
I wrote this to make the allgather-related issue discovered in
#3284 (comment) easier
to expose. It also seems like a good runtime check to have in general, because
`_allgather_base` treats its I/O tensors as flat buffers and ignores their
shapes.
wujingyue authored Dec 6, 2024
1 parent 4a215ef commit 76483fe
Showing 3 changed files with 34 additions and 15 deletions.
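To see why a numel-only check is too weak, consider two buffers with equal element counts but different shapes: `_allgather_base`'s flat-buffer view accepts the pair, and only a shape comparison catches the mismatch. A minimal standalone sketch (illustrative names and sizes, not code from this commit):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::zeros({2, 3});  // 6 elements, shape [2, 3]
  at::Tensor b = at::zeros({3, 2});  // 6 elements, shape [3, 2]

  // The old check compared element counts only: this passes.
  std::cout << std::boolalpha << (a.numel() == b.numel()) << "\n";  // true

  // The hardened check compares shapes: this exposes the mismatch.
  std::cout << (a.sizes() == b.sizes()) << "\n";  // false
}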
10 changes: 4 additions & 6 deletions csrc/host_ir/executor.cpp
@@ -378,15 +378,13 @@ void HostIrEvaluator::handle(Wait* wait) {

 namespace {
 
-void allConsumerValsOfHelper(
-    Val* val,
-    std::unordered_set<Val*>& visisted_vals) {
-  if (visisted_vals.find(val) != visisted_vals.end()) {
+void allConsumerValsOfHelper(Val* val, std::unordered_set<Val*>& visited_vals) {
+  if (visited_vals.find(val) != visited_vals.end()) {
     return;
   }
-  visisted_vals.insert(val);
+  visited_vals.insert(val);
   for (Val* consumer : ir_utils::consumerValsOf(val)) {
-    allConsumerValsOfHelper(consumer, visisted_vals);
+    allConsumerValsOfHelper(consumer, visited_vals);
   }
 }

9 changes: 6 additions & 3 deletions csrc/multidevice/communication.cpp
@@ -66,12 +66,15 @@ void assertBuffersHaveSameSize(
   if (bufs1.empty() && bufs2.empty()) {
     return;
   }
-  const auto numel = (bufs1.empty() ? bufs2 : bufs1).at(0).numel();
+  const auto shape = (bufs1.empty() ? bufs2 : bufs1).at(0).sizes();
   for (const auto& bufs : {bufs1, bufs2}) {
     for (const auto& buf : bufs) {
       NVF_ERROR(
-          buf.numel() == numel,
-          "all buffers must have the same number of elements");
+          buf.sizes() == shape,
+          "all buffers must have the same shape, but got: ",
+          buf.sizes(),
+          " vs ",
+          shape);
     }
   }
 }
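A minimal call sketch of the hardened check, assuming its parameters are vectors of at::Tensor (consistent with the .at(0)/.sizes() usage above); the buffers are made up:

std::vector<at::Tensor> bufs1 = {at::zeros({2, 3})};
std::vector<at::Tensor> bufs2 = {at::zeros({3, 2})};
// Before this commit: both buffers hold 6 elements, so the check passed.
// After this commit: NVF_ERROR fires with
//   "all buffers must have the same shape, but got: [3, 2] vs [2, 3]"
assertBuffersHaveSameSize(bufs1, bufs2);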
30 changes: 24 additions & 6 deletions tests/cpp/test_multidevice_overlap.cpp
@@ -164,7 +164,7 @@ class OverlapTest : public MultiDeviceTest {

   void validate() {
     auto tc_expected = getExpectedResult();
-    auto tc_cpu = tc_.to(torch::kCPU);
+    auto tc_cpu = tc_.cpu();
     EXPECT_TRUE(tc_cpu.allclose(tc_expected, 1e-1, 1e-1))
         << "Unexpected results, obtained:" << tc_cpu
         << "\n expected: " << tc_expected;
@@ -837,18 +837,19 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
         IrBuilder::create<hir::Stream>(stream_index));
 
     TensorView* tva_j = select(tva, 0, j);
+    TensorView* tva_j_unsqueezed = unsqueeze(tva_j, 0);
     TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
 
     // Setting the DeviceMesh of the communication's I/O is artificial but
     // required at this point
     DeviceMesh full_mesh(all_devices_);
     tva_allgathered_j->setDeviceMesh(full_mesh);
     tva_j->setDeviceMesh(full_mesh);
+    tva_j_unsqueezed->setDeviceMesh(full_mesh);
 
     auto* communication = IrBuilder::create<Communication>(
         CommunicationType::Allgather,
         /*out=*/tva_allgathered_j,
-        /*in=*/tva_j,
+        /*in=*/tva_j_unsqueezed,
         /*team=*/all_devices_);
     auto* wait = IrBuilder::create<hir::Wait>(communication);
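Why the new unsqueeze: with the hardened check, the allgather input presumably has to match the shape, not just the element count, of each per-device slice of the gathered output. A standalone ATen sketch of that shape bookkeeping (sizes are illustrative, not taken from the test):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor shard = at::zeros({4, 8});         // a device's local contribution
  at::Tensor out_slice = at::zeros({1, 4, 8});  // one per-device slice of the output

  // Same element count, different shapes: rejected by the hardened check.
  std::cout << std::boolalpha << (shard.sizes() == out_slice.sizes()) << "\n";  // false

  // Unsqueezing a leading device axis makes the shapes line up.
  std::cout << (shard.unsqueeze(0).sizes() == out_slice.sizes()) << "\n";  // true
}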

@@ -864,6 +865,7 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
     std::vector<Expr*> loop_body = {
         set_stream,
         tva_j->definition(),
+        tva_j_unsqueezed->definition(),
         tva_allgathered_j->definition(),
         communication,
         wait,
@@ -899,9 +901,25 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
   for_loop_stream->body().push_back(sync_stream);
   hic->pushBackTopLevelExprs(for_loop_stream);
 
-  // The following line is artificial but necessary to make
-  // tva_j->isProducerOf(tvc_j) == true
-  hic->addOutput(tvc_j);
+  // The following line is artificial but necessary to make tva_j_unsqueezed a
+  // consumer of tva_j.
+  //
+  // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all
+  // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a
+  // bit special among all transitive consumers of `j`: it doesn't use `j`
+  // directly but uses `tva_j`, which is a TensorView. A TensorView's uses are
+  // built lazily by Fusion::resetTvUses. For efficiency, Fusion::resetTvUses
+  // only fixes TensorViews that can reach outputs. Therefore, we add
+  // tva_j_unsqueezed as an output. Other TensorViews don't need this
+  // treatment because they are direct users of `j`, a scalar whose uses are
+  // built eagerly upon registration.
+  //
+  // We could have added `tvc_j` instead as an output, which transitively
+  // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a
+  // Select and a MatmulOp, and StmtSort::getExprs only traverses via the
+  // first registered definition (i.e. the Select). This sounds like a bug --
+  // I wonder how nvFuser resets the TensorView uses of a kir::Kernel, which is
+  // also non-SSA.
+  hic->addOutput(tva_j_unsqueezed);
 
   hir::HostIrEvaluator hie(std::move(hic), communicator_);
 
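For context on the comment above, a hedged sketch of the transitive-consumer walk it describes, mirroring allConsumerValsOfHelper from executor.cpp earlier in this commit (not the actual HostIrEvaluator code):

// Follow Val::uses() from the loop index through each expression's outputs.
void collectTransitiveConsumers(Val* v, std::unordered_set<Val*>& visited) {
  if (!visited.insert(v).second) {
    return;  // already visited
  }
  for (Expr* use : v->uses()) {
    for (Val* out : use->outputs()) {
      collectTransitiveConsumers(out, visited);
    }
  }
}
// If tva_j's uses are stale (missing the UnsqueezeOp), a walk starting from `j`
// stops at tva_j and never reaches tva_j_unsqueezed or its definition.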
