Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN EP] Fix multithread sync bug in ETW callback #23156

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ void QnnLogging(const char* format,
}
}

Status QnnBackendManager::InitializeQnnLog() {
Status QnnBackendManager::InitializeQnnLog(const logging::Logger& logger) {
logger_ = &logger;

// Set Qnn log level align with Ort log level
auto ort_log_level = logger_->GetSeverity();
QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level);
Expand Down Expand Up @@ -303,23 +305,15 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity
}
}

Status QnnBackendManager::ResetQnnLogLevel() {
Status QnnBackendManager::ResetQnnLogLevel(std::optional<logging::Severity> ort_log_level) {
std::lock_guard<std::mutex> lock(logger_mutex_);

if (backend_setup_completed_ && logger_ != nullptr) {
auto ort_log_level = logger_->GetSeverity();
LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level;
return UpdateQnnLogLevel(ort_log_level);
if (!backend_setup_completed_ || logger_ == nullptr) {
return Status::OK();
}
return Status::OK();
}

Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) {
ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle.");
ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed.");
ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger.");

QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level);
logging::Severity actual_log_level = ort_log_level.has_value() ? *ort_log_level : logger_->GetSeverity();
QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(actual_log_level);

LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level;

Expand All @@ -332,7 +326,8 @@ Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) {
LOGS(*logger_, ERROR) << "Invalid log handle provided to QnnLog_setLogLevel.";
}
}
ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result, "Failed to set log level in Qnn backend. Error: ", QnnErrorHandleToString(result));
ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result,
"Failed to set log level in Qnn backend. Error: ", QnnErrorHandleToString(result));
return Status::OK();
}

Expand Down Expand Up @@ -823,7 +818,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
LOGS(logger, VERBOSE) << "Backend build version: "
<< sdk_build_version_;

SetLogger(&logger);
ORT_RETURN_IF_ERROR(InitializeQnnLog(logger));
LOGS(logger, VERBOSE) << "SetLogger succeed.";

ORT_RETURN_IF_ERROR(InitializeBackend());
Expand Down Expand Up @@ -1049,6 +1044,24 @@ Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id)
return Status::OK();
}

Status QnnBackendManager::TerminateQnnLog() {
std::lock_guard<std::mutex> lock(logger_mutex_);
if (logger_ == nullptr) {
return Status::OK();
}

if (nullptr != qnn_interface_.logFree && nullptr != log_handle_) {
auto ret_val = qnn_interface_.logFree(log_handle_);

// Reset QNN log handle to nullptr so other threads that are waiting on logger_mutex_ know it was freed.
log_handle_ = nullptr;
ORT_RETURN_IF(QNN_SUCCESS != ret_val,
"Unable to terminate logging in the backend.");
}

return Status::OK();
}

void QnnBackendManager::ReleaseResources() {
if (!backend_setup_completed_) {
return;
Expand All @@ -1074,7 +1087,6 @@ void QnnBackendManager::ReleaseResources() {
ORT_THROW("Failed to ShutdownBackend.");
}

std::lock_guard<std::mutex> lock(logger_mutex_);
result = TerminateQnnLog();
if (Status::OK() != result) {
ORT_THROW("Failed to TerminateQnnLog.");
Expand Down
46 changes: 18 additions & 28 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class QnnBackendManager {
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
int64_t max_spill_fill_size);

// Initializes handles to QNN resources (device, logger, etc.).
// NOTE: This function locks the internal `logger_mutex_`.
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);

Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
Expand All @@ -121,34 +123,10 @@ class QnnBackendManager {

const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }

void SetLogger(const logging::Logger* logger) {
if (logger_ == nullptr) {
logger_ = logger;
(void)InitializeQnnLog();
}
}

Status InitializeQnnLog();

Status UpdateQnnLogLevel(logging::Severity ort_log_level);

Status ResetQnnLogLevel();

// Terminate logging in the backend
Status TerminateQnnLog() {
if (logger_ == nullptr) {
return Status::OK();
}

if (nullptr != qnn_interface_.logFree && nullptr != log_handle_) {
ORT_RETURN_IF(QNN_SUCCESS != qnn_interface_.logFree(log_handle_),
"Unable to terminate logging in the backend.");
}

return Status::OK();
}

void ReleaseResources();
// Resets the QNN log level to the given ORT log level or to the default log level if the argument is
// std::nullopt.
// NOTE: This function locks the internal `logger_mutex_`.
Status ResetQnnLogLevel(std::optional<logging::Severity> ort_log_level = std::nullopt);

Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
Expand All @@ -171,6 +149,18 @@ class QnnBackendManager {
uint64_t& max_spill_fill_buffer_size);

private:
// Sets the ORT logger and creates a corresponding QNN logger with the same log level.
// NOTE: caller must lock the `logger_mutex_` before calling this function.
Status InitializeQnnLog(const logging::Logger& logger);

// Terminate logging in the backend
// NOTE: This function locks the internal `logger_mutex_`.
Status TerminateQnnLog();

// Releases all QNN resources. Called in the destructor.
// NOTE: This function indirectly locks the internal `logger_mutex_` via nested function calls.
void ReleaseResources();

void* LoadLib(const char* file_name, int flags, std::string& error_msg);

Status LoadQnnSystemLib();
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) {
if ((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) {
auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity();
(void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity);
(void)qnn_backend_manager_->ResetQnnLogLevel(ortETWSeverity);
adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved
}
if ((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) {
if (Level != 0) {
Expand All @@ -439,7 +439,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio

if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) {
// (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID);
(void)qnn_backend_manager_->ResetQnnLogLevel();
(void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt);
}
});
etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_);
Expand Down
Loading