Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,19 @@ class InferenceSession {
*/
const std::vector<std::string>& GetRegisteredProviderTypes() const;

/**
 * Accessor for the Execution Providers held by this session.
 *
 * May be invoked once EP registration has happened, even though Initialize()
 * has not yet completed. Its only intended use is early validation of compiled
 * model compatibility, at a point where the EPs cannot yet be reached through
 * the session state.
 *
 * @return const reference to the ExecutionProviders collection.
 */
const ExecutionProviders& GetExecutionProviders() const noexcept { return execution_providers_; }

/*
* Get the options this session was initialized with.
*/
Expand Down
20 changes: 14 additions & 6 deletions onnxruntime/core/session/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,10 @@ static Status ValidateCompiledModelCompatibility(InferenceSession& sess) {

const auto& registered_provider_types = sess.GetRegisteredProviderTypes();

// Access the execution providers through the session state (available after Initialize)
const auto& execution_providers = sess.GetSessionState().GetExecutionProviders();
// Access the execution providers directly from the session.
// This allows validation to run before Initialize() completes, avoiding expensive
// graph transformations for incompatible models. EPs are fully registered at this point.
const auto& execution_providers = sess.GetExecutionProviders();

for (const auto& ep_type : registered_provider_types) {
// Construct the full metadata key using the prefix + EP type
Expand Down Expand Up @@ -378,14 +380,20 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,
reinterpret_cast<PrepackedWeightsContainer*>(prepacked_weights_container)));
}

ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());

#if !defined(ORT_MINIMAL_BUILD)
// Validate compiled model compatibility for all registered execution providers
// This must be done after Initialize() so the session state is available
// Validate compiled model compatibility for all registered execution providers BEFORE Initialize().
// This is an optimization to fail fast for incompatible models, avoiding expensive graph transformations,
// partitioning, and kernel binding that occur during Initialize().
// This is safe because:
// 1. Model metadata (containing compatibility strings) is available after Load() completes.
// 2. Compiling EPs are fully registered at this point.
// 3. Non-compiling EPs (like CPU EP, which may be implicitly added during Initialize()) don't participate
// in compatibility validation - they return NOT_APPLICABLE by default.
ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess));
#endif // !defined(ORT_MINIMAL_BUILD)

ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());

return nullptr;
}

Expand Down
123 changes: 123 additions & 0 deletions onnxruntime/test/framework/ep_compatibility_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,63 @@ class TestCompatibilityExecutionProvider : public IExecutionProvider {
bool should_fail_validation_ = false;
};

// Test execution provider that tracks whether GetCapability is called.
// This is used to verify that early validation fails BEFORE Initialize() does expensive work.
class TestEarlyValidationExecutionProvider : public IExecutionProvider {
public:
static constexpr const char* kTestEarlyValidationExecutionProviderType = "TestEarlyValidationExecutionProvider";

TestEarlyValidationExecutionProvider() : IExecutionProvider(kTestEarlyValidationExecutionProviderType) {
}

std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
return std::make_shared<KernelRegistry>();
}

std::vector<AllocatorPtr> CreatePreferredAllocators() override {
return {};
}

// Override GetCapability to track if it's called (happens during Initialize())
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& kernel_lookup,
const GraphOptimizerRegistry& graph_optimizer_registry,
IResourceAccountant* resource_accountant = nullptr) const override {
ORT_UNUSED_PARAMETER(graph_viewer);
ORT_UNUSED_PARAMETER(kernel_lookup);
ORT_UNUSED_PARAMETER(graph_optimizer_registry);
ORT_UNUSED_PARAMETER(resource_accountant);
get_capability_called_ = true;
return {}; // Return empty - we don't actually want to handle any nodes
}

// Configurable mock behavior for validation
void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) {
mock_compatibility_status_ = status;
}

common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const override {
ORT_UNUSED_PARAMETER(compatibility_info);
model_compatibility = mock_compatibility_status_;
return Status::OK();
}

// Query whether GetCapability was called
bool WasGetCapabilityCalled() const {
return get_capability_called_;
}

void ResetGetCapabilityCalled() {
get_capability_called_ = false;
}

private:
OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
mutable bool get_capability_called_ = false;
};

// Helper class to create test models
class ModelBuilderWithCompatibility {
public:
Expand Down Expand Up @@ -388,6 +445,72 @@ TEST_F(EpCompatibilityTest, TestEpValidationFailure) {
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure"));
}

// Test that early validation optimization works: when a model is incompatible,
// validation should fail BEFORE Initialize() performs expensive graph partitioning.
// We verify this by checking that GetCapability() is NOT called when validation fails.
TEST_F(EpCompatibilityTest, TestEarlyValidation_FailsBeforeGetCapability) {
  const std::string provider_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
  const std::string compat_value = "test_compatibility_v1.0";

  auto mock_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
  // Non-owning observer: mock_ep is moved into RegisterExecutionProvider() below,
  // so this pointer is how the EP's state is inspected afterwards.
  auto* const ep_observer = mock_ep.get();

  // Make the mock report the compiled model as unsupported.
  ep_observer->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED);

  // Baseline: no capability query has been made yet.
  EXPECT_FALSE(ep_observer->WasGetCapabilityCalled());

  // Build a model whose metadata carries a compatibility entry for this EP type.
  std::map<std::string, std::string> ep_metadata;
  ep_metadata.emplace(provider_type, compat_value);
  auto model = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(ep_metadata);

  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model));
  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(mock_ep)));

  // The incompatible model must make initialization fail.
  const auto init_status = InitializeSessionWithValidation(*session);
  EXPECT_FALSE(init_status.IsOK());
  EXPECT_THAT(init_status.ErrorMessage(), testing::HasSubstr("not supported"));

  // Key check: a validation failure that happens ahead of Initialize() means
  // graph partitioning never ran, so GetCapability() must remain uncalled.
  EXPECT_FALSE(ep_observer->WasGetCapabilityCalled())
      << "GetCapability was called, indicating validation did not fail early before Initialize()";
}

// Test that when validation succeeds, GetCapability IS called (normal flow)
TEST_F(EpCompatibilityTest, TestEarlyValidation_SucceedsAndProceedsToGetCapability) {
  const std::string provider_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType;
  const std::string compat_value = "test_compatibility_v1.0";

  auto mock_ep = std::make_unique<TestEarlyValidationExecutionProvider>();
  // Non-owning observer: mock_ep is moved into RegisterExecutionProvider() below,
  // so this pointer is how the EP's state is inspected afterwards.
  auto* const ep_observer = mock_ep.get();

  // Make the mock report the compiled model as optimally supported.
  ep_observer->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL);

  // Baseline: no capability query has been made yet.
  EXPECT_FALSE(ep_observer->WasGetCapabilityCalled());

  // Build a model whose metadata carries a compatibility entry for this EP type.
  std::map<std::string, std::string> ep_metadata;
  ep_metadata.emplace(provider_type, compat_value);
  auto model = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(ep_metadata);

  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model));
  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(mock_ep)));

  // With a compatible model the whole initialization path must succeed.
  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));

  // After validation passes, Initialize() runs its normal graph partitioning,
  // which is expected to query the EP through GetCapability().
  EXPECT_TRUE(ep_observer->WasGetCapabilityCalled())
      << "GetCapability was not called, but it should have been after successful validation";
}

// Test session option configuration for fail on suboptimal
TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) {
SessionOptions so;
Expand Down
Loading