From b5ea730c326639f21f0e98ccc909bf4f7b9c5750 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Thu, 12 Dec 2024 15:06:49 -0800 Subject: [PATCH] Add more safety checks to BlockingCounter This makes it harder to have use-after-free situations. PiperOrigin-RevId: 705643798 --- tsl/platform/BUILD | 5 ++- tsl/platform/blocking_counter.h | 77 +++++++++++++++++++++++++++------ 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/tsl/platform/BUILD b/tsl/platform/BUILD index f160011f6..be1f120fc 100644 --- a/tsl/platform/BUILD +++ b/tsl/platform/BUILD @@ -60,7 +60,10 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ ":logging", - ":mutex", + "//util/symbolize", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) diff --git a/tsl/platform/blocking_counter.h b/tsl/platform/blocking_counter.h index c085e4d66..1c1122ee0 100644 --- a/tsl/platform/blocking_counter.h +++ b/tsl/platform/blocking_counter.h @@ -17,9 +17,14 @@ limitations under the License. 
#define TENSORFLOW_TSL_PLATFORM_BLOCKING_COUNTER_H_ #include +#include // NOLINT +#include +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" #include "tsl/platform/logging.h" -#include "tsl/platform/mutex.h" +#include "util/symbolize/symbolized_stacktrace.h" namespace tsl { @@ -28,10 +33,13 @@ class BlockingCounter { BlockingCounter(int initial_count) : state_(initial_count << 1), notified_(false) { CHECK_GE(initial_count, 0); - DCHECK_EQ((initial_count << 1) >> 1, initial_count); + DCHECK_EQ((static_cast(initial_count) << 1) >> 1, + initial_count); } - ~BlockingCounter() {} + ~BlockingCounter() = default; + + static thread_local constexpr char kNonce = 0; inline void DecrementCount() { unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; @@ -39,29 +47,70 @@ class BlockingCounter { DCHECK_NE(((v + 2) & ~1), 0); return; // either count has not dropped to 0, or waiter is not waiting } - mutex_lock l(mu_); + absl::MutexLock l(&mu_); DCHECK(!notified_); notified_ = true; - cond_var_.notify_all(); + cond_var_.SignalAll(); } inline void Wait() { + LOG(INFO) << "kNonce: " << (const void*)&kNonce; + const void* prior_last_waiter_addr = + last_waiter_addr_.load(std::memory_order_relaxed); + if (prior_last_waiter_addr != nullptr) { + CHECK_EQ(prior_last_waiter_addr, (const void*)&kNonce) + << "multiple threads called WaitFor()"; + } else { + auto expected = prior_last_waiter_addr; + if (!last_waiter_addr_.compare_exchange_strong( + expected, &kNonce, std::memory_order_relaxed)) { + LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with " + << (const void*)&kNonce << " but found " << expected; + } + } unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); if ((v >> 1) == 0) return; - mutex_lock l(mu_); + absl::MutexLock l(&mu_); + + // only one thread may call Wait(). To support more than one thread, + // implement a counter num_to_exit, like in the Barrier class. 
+ CHECK_EQ(num_waiting_, 0) << "multiple threads called Wait()"; + num_waiting_++; + while (!notified_) { - cond_var_.wait(l); + cond_var_.Wait(&mu_); } } // Wait for the specified time, return false iff the count has not dropped to // zero before the timeout expired. inline bool WaitFor(std::chrono::milliseconds ms) { + LOG(INFO) << "this: " << this << " kNonce: " << (const void*)&kNonce; + const void* prior_last_waiter_addr = + last_waiter_addr_.load(std::memory_order_relaxed); + if (prior_last_waiter_addr != nullptr) { + CHECK_EQ(prior_last_waiter_addr, (const void*)&kNonce) + << "multiple threads called WaitFor(): " << last_waiter_addr_ << " " + << &kNonce; + } else { + auto expected = prior_last_waiter_addr; + if (!last_waiter_addr_.compare_exchange_strong( + expected, &kNonce, std::memory_order_relaxed)) { + LOG(FATAL) << "Tried to swap " << prior_last_waiter_addr << " with " + << (const void*)&kNonce << " but found " << expected; + } + LOG(INFO) << util::CurrentStackTrace(); + } + unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel); if ((v >> 1) == 0) return true; - mutex_lock l(mu_); + absl::Duration timeout = absl::FromChrono(ms); + absl::MutexLock l(&mu_); + + // only one thread may call Wait(). To support more than one thread, + // implement a counter num_to_exit, like in the Barrier class. + while (!notified_) { - const std::cv_status status = cond_var_.wait_for(l, ms); - if (status == std::cv_status::timeout) { + if (cond_var_.WaitWithTimeout(&mu_, timeout)) { return false; } } @@ -69,10 +118,12 @@ class BlockingCounter { } private: - mutex mu_; - condition_variable cond_var_; + absl::Mutex mu_; + absl::CondVar cond_var_; std::atomic<int> state_; // low bit is waiter flag - bool notified_; + std::atomic<const void*> last_waiter_addr_ = nullptr; + int num_waiting_ ABSL_GUARDED_BY(mu_) = 0; + bool notified_ ABSL_GUARDED_BY(mu_); }; } // namespace tsl