Skip to content

Commit

Permalink
Propagate driver context in parallel join build (facebookincubator#11329
Browse files Browse the repository at this point in the history
)

Summary:
When hash build is building join table, it suspends the current running driver, to prevent the parallel building threads from triggering arbitration and create a deadlock situation. The downside of this approach is, we lose the opportunity of reclaiming from it, because we will not wait for the table build to finish and become reclaimable when reclaiming from the operators.
This PR removes the suspension logic. Instead it pushes in the driver context to the parallel building threads so that these threads, when triggering arbitration (because suspension is removed), can correctly suspend the current running driver, so that the reclaiming can wait for it to finish the table build and get a chance to reclaim from it.

Pull Request resolved: facebookincubator#11329

Reviewed By: xiaoxmeng

Differential Revision: D64795445

Pulled By: tanjialiang

fbshipit-source-id: 3856645ede8acc82690c3a28aecec6e1c61b560c
  • Loading branch information
tanjialiang authored and facebook-github-bot committed Oct 25, 2024
1 parent c7fe8e7 commit 2e837a8
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 27 deletions.
12 changes: 10 additions & 2 deletions velox/exec/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,6 @@ std::string blockingReasonToString(BlockingReason reason) {
return "kWaitForArbitration";
}
VELOX_UNREACHABLE();
return "";
}

DriverThreadContext* driverThreadContext() {
Expand All @@ -1153,7 +1152,16 @@ DriverThreadContext* driverThreadContext() {

ScopedDriverThreadContext::ScopedDriverThreadContext(const DriverCtx& driverCtx)
: savedDriverThreadCtx_(driverThreadCtx),
currentDriverThreadCtx_{.driverCtx = driverCtx} {
currentDriverThreadCtx_(DriverThreadContext(&driverCtx)) {
driverThreadCtx = &currentDriverThreadCtx_;
}

ScopedDriverThreadContext::ScopedDriverThreadContext(
const DriverThreadContext* _driverThreadCtx)
: savedDriverThreadCtx_(driverThreadCtx),
currentDriverThreadCtx_(
_driverThreadCtx == nullptr ? nullptr
: _driverThreadCtx->driverCtx()) {
driverThreadCtx = &currentDriverThreadCtx_;
}

Expand Down
16 changes: 14 additions & 2 deletions velox/exec/Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,15 +735,27 @@ class SuspendedSection {

/// Provides the execution context of a driver thread. This is set to a
/// per-thread local variable if the running thread is a driver thread.
struct DriverThreadContext {
const DriverCtx& driverCtx;
class DriverThreadContext {
public:
explicit DriverThreadContext(const DriverCtx* driverCtx)
: driverCtx_(driverCtx) {}

const DriverCtx* driverCtx() const {
VELOX_CHECK_NOT_NULL(driverCtx_);
return driverCtx_;
}

private:
const DriverCtx* driverCtx_;
};

/// Object used to set/restore the driver thread context when driver execution
/// starts/leaves the driver thread.
class ScopedDriverThreadContext {
public:
explicit ScopedDriverThreadContext(const DriverCtx& driverCtx);
explicit ScopedDriverThreadContext(
const DriverThreadContext* _driverThreadCtx);
~ScopedDriverThreadContext();

private:
Expand Down
10 changes: 0 additions & 10 deletions velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,16 +745,6 @@ bool HashBuild::finishHashBuild() {
// https://github.com/facebookincubator/velox/issues/3567 is fixed.
CpuWallTiming timing;
{
// If there is a chance the join build is parallel, we suspend the driver
// while the hash table is being built. This is because off-driver thread
// memory allocations inside parallel join build might trigger memory
// arbitration.
std::unique_ptr<SuspendedSection> suspendedSection;
if (allowParallelJoinBuild) {
suspendedSection = std::make_unique<SuspendedSection>(
driverThreadContext()->driverCtx.driver);
}

CpuWallTimer cpuWallTimer{timing};
table_->prepareJoinTable(
std::move(otherTables),
Expand Down
11 changes: 9 additions & 2 deletions velox/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
rowPartitions.push_back(table->rows()->createRowPartitions(*rows_->pool()));
}

const auto* driverThreadCtx = driverThreadContext();
// The parallel table partitioning step.
for (auto i = 0; i < numPartitions; ++i) {
auto* table = getTable(i);
Expand All @@ -937,7 +938,10 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
return std::make_unique<bool>(true);
}));
VELOX_CHECK(!partitionSteps.empty());
buildExecutor_->add([step = partitionSteps.back()]() { step->prepare(); });
buildExecutor_->add([driverThreadCtx, step = partitionSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverThreadCtx);
step->prepare();
});
}

std::exception_ptr error;
Expand All @@ -961,7 +965,10 @@ void HashTable<ignoreNullKeys>::parallelJoinBuild() {
return std::make_unique<bool>(true);
}));
VELOX_CHECK(!buildSteps.empty());
buildExecutor_->add([step = buildSteps.back()]() { step->prepare(); });
buildExecutor_->add([driverThreadCtx, step = buildSteps.back()]() {
ScopedDriverThreadContext scopedDriverThreadContext(driverThreadCtx);
step->prepare();
});
}
syncWorkItems(buildSteps, error, offThreadBuildTiming_);

Expand Down
6 changes: 3 additions & 3 deletions velox/exec/MemoryReclaimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void MemoryReclaimer::enterArbitration() {
return;
}

Driver* const driver = driverThreadCtx->driverCtx.driver;
Driver* const driver = driverThreadCtx->driverCtx()->driver;
if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) {
// There is no need for arbitration if the associated task has already
// terminated.
Expand All @@ -47,7 +47,7 @@ void MemoryReclaimer::leaveArbitration() noexcept {
// request is not issued from a driver thread.
return;
}
Driver* const driver = driverThreadCtx->driverCtx.driver;
Driver* const driver = driverThreadCtx->driverCtx()->driver;
driver->task()->leaveSuspended(driver->state());
}

Expand Down Expand Up @@ -169,7 +169,7 @@ uint64_t ParallelMemoryReclaimer::reclaim(
void memoryArbitrationStateCheck(memory::MemoryPool& pool) {
const auto* driverThreadCtx = driverThreadContext();
if (driverThreadCtx != nullptr) {
Driver* driver = driverThreadCtx->driverCtx.driver;
Driver* driver = driverThreadCtx->driverCtx()->driver;
if (!driver->state().suspended()) {
VELOX_FAIL(
"Driver thread is not suspended under memory arbitration processing: {}, request memory pool: {}",
Expand Down
4 changes: 2 additions & 2 deletions velox/exec/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ void Operator::MemoryReclaimer::enterArbitration() {
return;
}

Driver* const runningDriver = driverThreadCtx->driverCtx.driver;
Driver* const runningDriver = driverThreadCtx->driverCtx()->driver;
if (!FLAGS_velox_memory_pool_capacity_transfer_across_tasks) {
if (auto opDriver = ensureDriver()) {
// NOTE: the current running driver might not be the driver of the
Expand Down Expand Up @@ -652,7 +652,7 @@ void Operator::MemoryReclaimer::leaveArbitration() noexcept {
// is not issued from a driver thread.
return;
}
Driver* const runningDriver = driverThreadCtx->driverCtx.driver;
Driver* const runningDriver = driverThreadCtx->driverCtx()->driver;
if (!FLAGS_velox_memory_pool_capacity_transfer_across_tasks) {
if (auto opDriver = ensureDriver()) {
VELOX_CHECK_EQ(
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/tests/DriverTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,7 @@ DEBUG_ONLY_TEST_F(DriverTest, driverThreadContext) {
"facebook::velox::exec::Values::getOutput",
std::function<void(const exec::Values*)>([&](const exec::Values* values) {
ASSERT_TRUE(driverThreadContext() != nullptr);
capturedTask = driverThreadContext()->driverCtx.task.get();
capturedTask = driverThreadContext()->driverCtx()->task.get();
}));
std::vector<RowVectorPtr> batches;
for (int i = 0; i < 4; ++i) {
Expand Down
6 changes: 3 additions & 3 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6545,17 +6545,17 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringAllocation) {
return;
}

auto& driverCtx = driverThreadContext()->driverCtx;
const auto* driverCtx = driverThreadContext()->driverCtx();
ASSERT_EQ(
driverCtx.task->enterSuspended(driverCtx.driver->state()),
driverCtx->task->enterSuspended(driverCtx->driver->state()),
StopReason::kNone);
testData.abortFromRootMemoryPool ? abortPool(pool->root())
: abortPool(pool);
// We can't directly reclaim memory from this hash build operator
// as its driver thread is running and in suspegnsion state.
ASSERT_GE(pool->root()->usedBytes(), 0);
ASSERT_EQ(
driverCtx.task->leaveSuspended(driverCtx.driver->state()),
driverCtx->task->leaveSuspended(driverCtx->driver->state()),
StopReason::kAlreadyTerminated);
ASSERT_TRUE(pool->aborted());
ASSERT_TRUE(pool->root()->aborted());
Expand Down
4 changes: 2 additions & 2 deletions velox/exec/tests/utils/ArbitratorTestUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class FakeMemoryReclaimer : public exec::MemoryReclaimer {
if (driverThreadCtx == nullptr) {
return;
}
auto* driver = driverThreadCtx->driverCtx.driver;
auto* driver = driverThreadCtx->driverCtx()->driver;
ASSERT_TRUE(driver != nullptr);
if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) {
VELOX_FAIL("Terminate detected when entering suspension");
Expand All @@ -59,7 +59,7 @@ class FakeMemoryReclaimer : public exec::MemoryReclaimer {
if (driverThreadCtx == nullptr) {
return;
}
auto* driver = driverThreadCtx->driverCtx.driver;
auto* driver = driverThreadCtx->driverCtx()->driver;
ASSERT_TRUE(driver != nullptr);
driver->task()->leaveSuspended(driver->state());
}
Expand Down

0 comments on commit 2e837a8

Please sign in to comment.