2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -1441,7 +1441,7 @@ get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_
if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDA_HOME)
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
endif()
find_package(CUDAToolkit REQUIRED)

9 changes: 6 additions & 3 deletions cmake/onnxruntime_providers_nv.cmake
@@ -1,7 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the MIT License.
find_package(CUDAToolkit REQUIRED 12.8)
if(onnxruntime_CUDA_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDA_HOME} CUDAToolkit_ROOT)
endif()
find_package(CUDAToolkit REQUIRED)
enable_language(CUDA)
if(onnxruntime_DISABLE_CONTRIB_OPS)
message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." )
@@ -146,9 +149,9 @@ endif ()
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen)
add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
else()
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
endif()
target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
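The provider now links CUDA::cuda_driver in addition to CUDA::cudart because the new context-handling code calls CUDA Driver API entry points (cu*-prefixed functions), which are exported by the driver library rather than the runtime. A minimal sketch of the distinction, as a hypothetical standalone program that is not part of this change:

#include <cuda.h>              // Driver API (cu*): provided by CUDA::cuda_driver
#include <cuda_runtime_api.h>  // Runtime API (cuda*): provided by CUDA::cudart

int main() {
  // Runtime API call: resolves against cudart.
  cudaError_t rt = cudaSetDevice(0);
  // Driver API call: resolves against libcuda; without CUDA::cuda_driver this
  // line fails at link time with an undefined reference to cuCtxGetCurrent.
  CUcontext ctx = nullptr;
  CUresult drv = cuCtxGetCurrent(&ctx);
  return (rt == cudaSuccess && drv == CUDA_SUCCESS) ? 0 : 1;
}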
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/shared_inc/cuda_call.h
@@ -16,6 +16,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
const char* file, const int line);

#define CUDA_CALL(expr) (::onnxruntime::CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
#define CU_CALL(expr) (::onnxruntime::CudaCall<CUresult, false>((expr), #expr, "CUDA", CUDA_SUCCESS, "", __FILE__, __LINE__))
#define CUBLAS_CALL(expr) (::onnxruntime::CudaCall<cublasStatus_t, false>((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))

#define CUSPARSE_CALL(expr) (::onnxruntime::CudaCall<cusparseStatus_t, false>((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
@@ -26,6 +27,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
#define CUFFT_CALL(expr) (::onnxruntime::CudaCall<cufftResult, false>((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__))

#define CUDA_CALL_THROW(expr) (::onnxruntime::CudaCall<cudaError, true>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
#define CU_CALL_THROW(expr) (::onnxruntime::CudaCall<CUresult, true>((expr), #expr, "CUDA", CUDA_SUCCESS, "", __FILE__, __LINE__))
#define CUBLAS_CALL_THROW(expr) (::onnxruntime::CudaCall<cublasStatus_t, true>((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__))

#define CUSPARSE_CALL_THROW(expr) (::onnxruntime::CudaCall<cusparseStatus_t, true>((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
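The new CU_CALL/CU_CALL_THROW macros mirror the existing CUDA_CALL/CUDA_CALL_THROW pair, but wrap Driver API return codes (CUresult, success value CUDA_SUCCESS) instead of Runtime API codes (cudaError_t, cudaSuccess). A hedged usage sketch; the helper functions below are illustrative and not part of this change:

#include <cuda.h>
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"

namespace onnxruntime {

// Non-throwing form: the macro yields a Status the caller can propagate.
Status BindCurrentContext(CUcontext* ctx) {
  ORT_RETURN_IF_ERROR(CU_CALL(cuCtxGetCurrent(ctx)));
  return Status::OK();
}

// Throwing form: for places that cannot return a Status (e.g. constructors).
void PushCurrentContextOrThrow() {
  CUcontext ctx = nullptr;
  CU_CALL_THROW(cuCtxGetCurrent(&ctx));
  CU_CALL_THROW(cuCtxPushCurrent(ctx));
}

}  // namespace onnxruntime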
9 changes: 9 additions & 0 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc
@@ -31,6 +31,13 @@ const char* CudaErrString<cudaError_t>(cudaError_t x) {
return cudaGetErrorString(x);
}

template <>
const char* CudaErrString<CUresult>(CUresult x) {
const char* errorStr = NULL;
cuGetErrorString(x, &errorStr);
return errorStr;
}

#ifndef USE_CUDA_MINIMAL
template <>
const char* CudaErrString<cublasStatus_t>(cublasStatus_t e) {
@@ -141,5 +148,7 @@ std::conditional_t<THRW, void, Status> CudaCall(

template Status CudaCall<cudaError, false>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
template void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
template Status CudaCall<CUresult, false>(CUresult retCode, const char* exprString, const char* libName, CUresult successCode, const char* msg, const char* file, const int line);
template void CudaCall<CUresult, true>(CUresult retCode, const char* exprString, const char* libName, CUresult successCode, const char* msg, const char* file, const int line);

} // namespace onnxruntime
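The CudaErrString<CUresult> specialization lets CudaCall format Driver API failures via cuGetErrorString. Per the CUDA documentation (not something this change alters), cuGetErrorString returns CUDA_ERROR_INVALID_VALUE and leaves the output pointer at NULL when the code is unrecognized, so an illustrative caller-side fallback could look like this:

#include <cuda.h>

// Illustrative helper, not part of this PR: resolve a CUresult to a printable
// string, falling back to a placeholder when the driver cannot describe it.
const char* DriverErrorToString(CUresult rc) {
  const char* msg = nullptr;
  if (cuGetErrorString(rc, &msg) != CUDA_SUCCESS || msg == nullptr) {
    return "<unknown CUDA driver error>";
  }
  return msg;
}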
@@ -959,8 +959,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
device_id_(info.device_id) {
InitProviderOrtApi();

// TODO(maximlianm) remove this since we should be able to compile an AOT context file without GPU

if (!info.has_user_compute_stream) {
// If the app is passing in a compute stream, it already has initialized cuda and created a context.
// Calling cudaSetDevice() will set the default context in the current thread
onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
@@ -2,6 +2,8 @@
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <fstream>
#include <unordered_map>
#include <string>
@@ -10,10 +12,11 @@
#include <iostream>
#include <filesystem>
#include "flatbuffers/idl.h"
#include <NvInferVersion.h>
#include "nv_includes.h"

[cpplint] warning on line 15 of nv_execution_provider_utils.h: Include the directory when naming header files [build/include_subdir]
#include "core/providers/cuda/cuda_pch.h"
#include "core/common/path_string.h"
#include "core/framework/murmurhash3.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"

namespace fs = std::filesystem;

@@ -31,7 +34,7 @@
* }
*
*/
int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
static int GetNumProfiles(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
int num_profile = 0;
for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) {
num_profile = static_cast<int>(it->second.size());
@@ -52,7 +55,7 @@
*
* [Deprecated] Use SerializeProfileV2
*/
void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>& shape_ranges) {
static void SerializeProfile(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>& shape_ranges) {
// Serialize profile
flexbuffers::Builder builder;
auto profile_start = builder.StartMap();
@@ -78,7 +81,7 @@

// Deserialize engine profile
// [Deprecated] Use DeserializeProfileV2
std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> DeserializeProfile(std::ifstream& infile) {
static std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>> DeserializeProfile(std::ifstream& infile) {
// Load flexbuffer
infile.seekg(0, std::ios::end);
size_t length = infile.tellg();
@@ -153,7 +156,7 @@
* }
*
*/
void SerializeProfileV2(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>& shape_ranges) {
static void SerializeProfileV2(const std::string& file_name, std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>& shape_ranges) {
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In SerializeProfileV2()";
// Serialize profile
flexbuffers::Builder builder;
@@ -233,7 +236,7 @@
* }
* }
*/
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> DeserializeProfileV2(std::ifstream& infile) {
static std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> DeserializeProfileV2(std::ifstream& infile) {
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In DeserializeProfileV2()";
// Load flexbuffer
infile.seekg(0, std::ios::end);
@@ -278,10 +281,10 @@
* Return false meaning no need to rebuild engine if everything is same.
* Otherwise return true and engine needs to be rebuilt.
*/
bool CompareProfiles(const std::string& file_name,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
static bool CompareProfiles(const std::string& file_name,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
std::ifstream profile_file(file_name, std::ios::binary | std::ios::in);
if (!profile_file) {
LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << file_name << " doesn't exist.";
@@ -372,7 +375,7 @@
* Get cache by name
*
*/
std::string GetCachePath(const std::string& root, const std::string& name) {
static std::string GetCachePath(const std::string& root, const std::string& name) {
if (root.empty()) {
return name;
} else {
@@ -386,7 +389,7 @@
* Get compute capability
*
*/
std::string GetComputeCapability(const cudaDeviceProp& prop) {
static std::string GetComputeCapability(const cudaDeviceProp& prop) {
const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor);
return compute_capability;
}
@@ -397,7 +400,7 @@
* \param root root path of the cache
* \param file_extension It could be ".engine", ".profile" or ".timing"
*/
std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_extension) {
static std::vector<fs::path> GetCachesByType(const std::string& root, std::string file_extension) {
std::vector<fs::path> cache_files;
for (const auto& entry : fs::directory_iterator(root)) {
if (fs::path(file_extension) == fs::path(entry).extension()) {
@@ -407,15 +410,15 @@
return cache_files;
}

bool IsCacheExistedByType(const std::string& root, std::string file_extension) {
static bool IsCacheExistedByType(const std::string& root, std::string file_extension) {
auto cache_files = GetCachesByType(root, file_extension);
if (cache_files.size() == 0) {
return false;
}
return true;
}

void RemoveCachesByType(const std::string& root, std::string file_extension) {
static void RemoveCachesByType(const std::string& root, std::string file_extension) {
auto cache_files = GetCachesByType(root, file_extension);
for (const auto& entry : cache_files) {
fs::remove(entry);
@@ -431,7 +434,7 @@
* compiled kernels, so the name must be unique and deterministic across models and sessions.
* </remarks>
*/
HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) {
static HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version, std::string cuda_version) {
HashValue model_hash = 0;

// find the top level graph
@@ -507,9 +510,9 @@
return model_hash;
}

bool ValidateProfileShapes(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
static bool ValidateProfileShapes(std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes) {
if (profile_min_shapes.empty() && profile_max_shapes.empty() && profile_opt_shapes.empty()) {
return true;
}
@@ -552,7 +555,7 @@
*
* Return true if string can be successfully parsed or false if string has wrong format.
*/
bool MakeInputNameShapePair(std::string pair_string, std::pair<std::string, std::vector<int64_t>>& pair) {
static bool MakeInputNameShapePair(std::string pair_string, std::pair<std::string, std::vector<int64_t>>& pair) {
if (pair_string.empty()) {
return true;
}
@@ -595,7 +598,7 @@
*
* Return true if string can be successfully parsed or false if string has wrong format.
*/
bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
static bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_shapes) {
if (profile_shapes_string.empty()) {
return true;
}
@@ -628,7 +631,7 @@
return true;
}

std::vector<std::string> split(const std::string& str, char delimiter) {
static std::vector<std::string> split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
@@ -638,7 +641,7 @@
return tokens;
}

std::string join(const std::vector<std::string>& vec, const std::string& delimiter) {
static std::string join(const std::vector<std::string>& vec, const std::string& delimiter) {
std::string result;
for (size_t i = 0; i < vec.size(); ++i) {
result += vec[i];
@@ -657,7 +660,7 @@
* This func will generate the suffix "2068723788287043730_189_fp16"
*
*/
std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) {
static std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& trt_node_name_with_precision) {
std::vector<std::string> split_fused_node_name = split(fused_node_name, '_');
if (split_fused_node_name.size() >= 3) {
// Get index of model hash from fused_node_name
Expand Down Expand Up @@ -697,4 +700,26 @@
return checkTrtDimIsDynamic(tensor->getDimensions());
}
}

struct ScopedContext {
explicit ScopedContext(int device_id) {
CUcontext cu_context = 0;
CU_CALL_THROW(cuCtxGetCurrent(&cu_context));
if (!cu_context) {
// cuCtxGetCurrent succeeded but returned nullptr, which indicates that no CUDA context
// is currently set for this thread, which implies the user has not created a context.
// We use the runtime API to initialize a context for the specified device.
CUDA_CALL_THROW(cudaSetDevice(device_id));
CU_CALL_THROW(cuCtxGetCurrent(&cu_context));
}
CU_CALL_THROW(cuCtxPushCurrent(cu_context));
}

ScopedContext(const ScopedContext&) = delete;

~ScopedContext() {
// Destructor must not throw. Perform a best-effort pop of the current context.
cuCtxPopCurrent(nullptr);
}
};
} // namespace onnxruntime
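ScopedContext is an RAII guard: the constructor ensures the calling thread has a current CUDA context (reusing an existing one, or letting the runtime API initialize one for the requested device) and pushes it onto the thread's context stack; the destructor pops it without throwing. A hedged usage sketch, with an illustrative function that is not part of this change:

// Illustrative only: scope CUDA work to a guaranteed-current context.
void RunOnDevice(int device_id) {
  ScopedContext context_guard(device_id);  // pushes a current context for this thread
  // ... issue CUDA runtime or driver API calls here ...
}  // ~ScopedContext pops the context (best effort, never throws)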