diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 264a843e2d46e..1dbf0318c988c 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -505,6 +505,19 @@ class InferenceSession {
    */
   const std::vector<std::string>& GetRegisteredProviderTypes() const;
 
+  /**
+   * Get the registered Execution Providers.
+   *
+   * This method can be called after EP registration but before Initialize() completes.
+   * Used only for early validation of compiled model compatibility where accessing
+   * EPs through session state is not yet possible.
+   *
+   * @return const reference to the ExecutionProviders collection.
+   */
+  const ExecutionProviders& GetExecutionProviders() const noexcept {
+    return execution_providers_;
+  }
+
   /*
    * Get the options this session was initialized with.
    */
diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc
index 944e83d8cad66..a303f8483fbc1 100644
--- a/onnxruntime/core/session/utils.cc
+++ b/onnxruntime/core/session/utils.cc
@@ -258,8 +258,10 @@ static Status ValidateCompiledModelCompatibility(InferenceSession& sess) {
 
   const auto& registered_provider_types = sess.GetRegisteredProviderTypes();
 
-  // Access the execution providers through the session state (available after Initialize)
-  const auto& execution_providers = sess.GetSessionState().GetExecutionProviders();
+  // Access the execution providers directly from the session.
+  // This allows validation to run before Initialize() completes, avoiding expensive
+  // graph transformations for incompatible models. EPs are fully registered at this point.
+  const auto& execution_providers = sess.GetExecutionProviders();
 
   for (const auto& ep_type : registered_provider_types) {
     // Construct the full metadata key using the prefix + EP type
@@ -378,14 +380,20 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,
         reinterpret_cast<PrepackedWeightsContainer*>(prepacked_weights_container)));
   }
 
-  ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());
-
 #if !defined(ORT_MINIMAL_BUILD)
-  // Validate compiled model compatibility for all registered execution providers
-  // This must be done after Initialize() so the session state is available
+  // Validate compiled model compatibility for all registered execution providers BEFORE Initialize().
+  // This is an optimization to fail fast for incompatible models, avoiding expensive graph transformations,
+  // partitioning, and kernel binding that occur during Initialize().
+  // This is safe because:
+  // 1. Model metadata (containing compatibility strings) is available after Load() completes.
+  // 2. Compiling EPs are fully registered at this point.
+  // 3. Non-compiling EPs (like CPU EP, which may be implicitly added during Initialize()) don't participate
+  //    in compatibility validation - they return NOT_APPLICABLE by default.
   ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess));
 #endif  // !defined(ORT_MINIMAL_BUILD)
 
+  ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());
+
   return nullptr;
 }
 
diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc
index 0ae3fb746dd24..40e1669fb0129 100644
--- a/onnxruntime/test/framework/ep_compatibility_test.cc
+++ b/onnxruntime/test/framework/ep_compatibility_test.cc
@@ -90,6 +90,63 @@ class TestCompatibilityExecutionProvider : public IExecutionProvider {
   bool should_fail_validation_ = false;
 };
 
+// Test execution provider that tracks whether GetCapability is called.
+// This is used to verify that early validation fails BEFORE Initialize() does expensive work.
+class TestEarlyValidationExecutionProvider : public IExecutionProvider {
+ public:
+  static constexpr const char* kTestEarlyValidationExecutionProviderType = "TestEarlyValidationExecutionProvider";
+
+  TestEarlyValidationExecutionProvider() : IExecutionProvider(kTestEarlyValidationExecutionProviderType) {
+  }
+
+  std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
+    return std::make_shared<KernelRegistry>();
+  }
+
+  std::vector<AllocatorPtr> CreatePreferredAllocators() override {
+    return {};
+  }
+
+  // Override GetCapability to track if it's called (happens during Initialize())
+  std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
+      const onnxruntime::GraphViewer& graph_viewer,
+      const IKernelLookup& kernel_lookup,
+      const GraphOptimizerRegistry& graph_optimizer_registry,
+      IResourceAccountant* resource_accountant = nullptr) const override {
+    ORT_UNUSED_PARAMETER(graph_viewer);
+    ORT_UNUSED_PARAMETER(kernel_lookup);
+    ORT_UNUSED_PARAMETER(graph_optimizer_registry);
+    ORT_UNUSED_PARAMETER(resource_accountant);
+    get_capability_called_ = true;
+    return {};  // Return empty - we don't actually want to handle any nodes
+  }
+
+  // Configurable mock behavior for validation
+  void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) {
+    mock_compatibility_status_ = status;
+  }
+
+  common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
+                                                        OrtCompiledModelCompatibility& model_compatibility) const override {
+    ORT_UNUSED_PARAMETER(compatibility_info);
+    model_compatibility = mock_compatibility_status_;
+    return Status::OK();
+  }
+
+  // Query whether GetCapability was called
+  bool WasGetCapabilityCalled() const {
+    return get_capability_called_;
+  }
+
+  void ResetGetCapabilityCalled() {
+    get_capability_called_ = false;
+  }
+
+ private:
+  OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
+  mutable bool get_capability_called_ = false;
+};
+
 // Helper class to create test models
 class ModelBuilderWithCompatibility {
  public:
@@ -388,6 +445,72 @@ TEST_F(EpCompatibilityTest, TestEpValidationFailure) {
   EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure"));
 }
 
+// Test that early validation optimization works: when a model is incompatible,
+// validation should fail BEFORE Initialize() performs expensive graph partitioning.
+// We verify this by checking that GetCapability() is NOT called when validation fails.
+TEST_F(EpCompatibilityTest, TestEarlyValidation_FailsBeforeGetCapability) {
+  const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
+  const std::string compatibility_string = "test_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED);
+
+  // Verify GetCapability hasn't been called yet
+  EXPECT_FALSE(test_ep->WasGetCapabilityCalled());
+
+  // Create model with compatibility metadata for this EP
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
+
+  // Keep a raw pointer to check state after move
+  auto* test_ep_ptr = test_ep.get();
+
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+
+  // Initialization should fail due to incompatible model
+  auto status = InitializeSessionWithValidation(*session);
+  EXPECT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported"));
+
+  // CRITICAL: GetCapability should NOT have been called because validation failed early,
+  // before Initialize() could perform graph partitioning
+  EXPECT_FALSE(test_ep_ptr->WasGetCapabilityCalled())
+      << "GetCapability was called, indicating validation did not fail early before Initialize()";
+}
+
+// Test that when validation succeeds, GetCapability IS called (normal flow)
+TEST_F(EpCompatibilityTest, TestEarlyValidation_SucceedsAndProceedsToGetCapability) {
+  const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
+  const std::string compatibility_string = "test_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL);
+
+  // Verify GetCapability hasn't been called yet
+  EXPECT_FALSE(test_ep->WasGetCapabilityCalled());
+
+  // Create model with compatibility metadata for this EP
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
+
+  // Keep a raw pointer to check state after move
+  auto* test_ep_ptr = test_ep.get();
+
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+
+  // Initialization should succeed
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));
+
+  // GetCapability SHOULD have been called because validation succeeded and
+  // Initialize() proceeded normally with graph partitioning
+  EXPECT_TRUE(test_ep_ptr->WasGetCapabilityCalled())
+      << "GetCapability was not called, but it should have been after successful validation";
+}
+
 // Test session option configuration for fail on suboptimal
 TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) {
   SessionOptions so;