Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
available_providers = C.get_available_providers()

# Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
if "TensorrtExecutionProvider" in available_providers:
if "NvTensorRTRTXExecutionProvider" in available_providers:
if (
providers
and any(
Expand All @@ -522,15 +522,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
for provider in providers
)
and any(
provider == "TensorrtExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
provider == "NvTensorRTRTXExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
for provider in providers
)
):
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
self._fallback_providers = ["CPUExecutionProvider"]
if "NvTensorRTRTXExecutionProvider" in available_providers:
elif "TensorrtExecutionProvider" in available_providers:
if (
providers
and any(
Expand All @@ -539,8 +539,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
for provider in providers
)
and any(
provider == "NvTensorRTRTXExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
provider == "TensorrtExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
for provider in providers
)
):
Expand Down