Skip to content

Commit

Permalink
rust: Add TensorRT and OpenVINO execution providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Nov 30, 2023
1 parent 094b04e commit fd80903
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
6 changes: 6 additions & 0 deletions rust/onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ pub enum OrtError {
/// Error occurred when appending CUDA execution provider
#[error("Failed to append CUDA execution provider: {0}")]
AppendExecutionProviderCuda(OrtApiError),
/// Error occurred when appending TensorRT execution provider
#[error("Failed to append TensorRT execution provider: {0}")]
AppendExecutionProviderTensorRT(OrtApiError),
/// Error occurred when appending OpenVINO execution provider
#[error("Failed to append OpenVINO execution provider: {0}")]
AppendExecutionProviderOpenVINO(OrtApiError),
}

/// Error used when dimensions of input (from model and from inference call)
Expand Down
62 changes: 62 additions & 0 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,68 @@ impl<'a> SessionBuilder<'a> {
Ok(self)
}

/// Append a TensorRT execution provider
pub fn with_execution_provider_tensorrt(self) -> Result<SessionBuilder<'a>> {
let mut tensorrt_options: *mut sys::OrtTensorRTProviderOptionsV2 = null_mut();
// let status = unsafe {
// self.env
// .env()
// .api()
// .CreateTensorRTProviderOptions
// .unwrap()(&mut tensorrt_options)
// };
// status_to_result(status).map_err(OrtError::TensorRTProviderOptions)?;

let status = unsafe {
self.env
.env()
.api()
.SessionOptionsAppendExecutionProvider_TensorRT_V2
.unwrap()(self.session_options_ptr, tensorrt_options)
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderTensorRT)?;

// unsafe {
// self.env
// .env()
// .api()
// .ReleaseTensorRTProviderOptions
// .unwrap()(tensorrt_options);
// };
Ok(self)
}

/// Append a TensorRT execution provider
pub fn with_execution_provider_openvino(self) -> Result<SessionBuilder<'a>> {
let mut openvino_options: *mut sys::OrtOpenVINOProviderOptions = null_mut();
// let status = unsafe {
// self.env
// .env()
// .api()
// .CreateOpenVINOProviderOptions
// .unwrap()(&mut openvino_options)
// };
// status_to_result(status).map_err(OrtError::OpenVINOProviderOptions)?;

let status = unsafe {
self.env
.env()
.api()
.SessionOptionsAppendExecutionProvider_OpenVINO
.unwrap()(self.session_options_ptr, openvino_options)
};
status_to_result(status).map_err(OrtError::AppendExecutionProviderOpenVINO)?;

// unsafe {
// self.env
// .env()
// .api()
// .ReleaseOpenVINOProviderOptions
// .unwrap()(openvino_options);
// };
Ok(self)
}

/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
#[cfg(feature = "model-fetching")]
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session>
Expand Down

0 comments on commit fd80903

Please sign in to comment.