Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions onnxruntime/test/unittest_util/base_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session,
bool SetEpsForAllNodes(Graph& graph,
const std::vector<std::unique_ptr<IExecutionProvider>>& execution_providers,
const std::vector<std::shared_ptr<CustomRegistry>>* custom_registries,
const std::function<bool(const IExecutionProvider&)>& ep_uses_kernel_registry_fn) {
const std::function<bool(const IExecutionProvider&)>& ep_only_uses_kernel_registry_fn) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelRegistry::TypeConstraintMap type_constraint_map{};

Expand All @@ -440,7 +440,7 @@ bool SetEpsForAllNodes(Graph& graph,

node.SetExecutionProviderType(provider_type);

if (!ep_uses_kernel_registry_fn(*ep)) {
if (!ep_only_uses_kernel_registry_fn(*ep)) {
found = true;
break;
}
Expand Down Expand Up @@ -830,12 +830,12 @@ void BaseTester::ExecuteModelForEps(

ASSERT_TRUE(!execution_providers.empty()) << "Empty execution providers vector.";
if (try_assign_ep_for_nodes) {
auto ep_uses_kernel_registry = [](const IExecutionProvider& ep) {
auto ep_only_uses_kernel_registry = [](const IExecutionProvider& ep) {
const auto& provider_type = ep.Type();

constexpr std::array kEpsThatDoNotUseKernelRegistry{
constexpr std::array kEpsThatCompileNodes{
kOpenVINOExecutionProvider,
kTensorrtExecutionProvider,
kTensorrtExecutionProvider, // uses kernel registry for Memcpy* nodes, but compiles all others.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to confirm, the plugin EP name will be different, right? that assumption is made here:

ORT_ENFORCE(std::find(all_provider_types.begin(), all_provider_types.end(),
*dynamic_plugin_ep_name) == all_provider_types.end(),
"Dynamic plugin EP name conflicts with a known EP name: ", *dynamic_plugin_ep_name);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. @chilo-ms do you know if the plugin EP name is different?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume so, but let me confirm with Nvidia for their plugin TRT RTX EP.

Also, do we want to include provider-bridge TRT RTX EP, kNvTensorRTRTXExecutionProvider, in this list as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should only add kNvTensorRTRTXExecutionProvider to this list if its name differs from the plugin EP version, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should only add kNvTensorRTRTXExecutionProvider to this list if its name differs from the plugin EP version, right?

Yes, waiting for Nvidia's reply for the name of their plugin TRT RTX EP, will update.

kNnapiExecutionProvider,
kVSINPUExecutionProvider,
kCoreMLExecutionProvider,
Expand All @@ -844,24 +844,33 @@ void BaseTester::ExecuteModelForEps(
kSnpeExecutionProvider,
};

// check list of known EPs that do not use a kernel registry
if (const auto ep_it = std::find(kEpsThatDoNotUseKernelRegistry.begin(), kEpsThatDoNotUseKernelRegistry.end(),
// check list of known EPs that compile nodes
if (const auto ep_it = std::find(kEpsThatCompileNodes.begin(), kEpsThatCompileNodes.end(),
provider_type);
ep_it != kEpsThatDoNotUseKernelRegistry.end()) {
ep_it != kEpsThatCompileNodes.end()) {
return false;
}

// assume that a dynamic plugin EP which does not return a kernel registry does not use one
if (provider_type == dynamic_plugin_ep_infra::GetEpName() &&
ep.GetKernelRegistry() == nullptr) {
return false;
const OrtEp* ort_ep = ep.GetOrtEp();

if (ort_ep != nullptr) { // This is a plugin EP

if (ep.GetKernelRegistry() == nullptr) {
// assume that a dynamic plugin EP which does not return a kernel registry does not use one
return false;
}

if (ort_ep->Compile != nullptr) {
// assume that a plugin EP that compiles nodes does not use a kernel registry for all nodes
return false;
}
}

// otherwise, assume that the EP uses a kernel registry
return true;
};

if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_uses_kernel_registry)) {
if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_only_uses_kernel_registry)) {
std::string providers;
for (const auto& ep : execution_providers) {
providers.append(ep->Type() + " ");
Expand Down
Loading