Skip to content

Commit

Permalink
rust: Use c_char for ffi.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Oct 26, 2023
1 parent 2a1fd25 commit b80fe5b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
16 changes: 9 additions & 7 deletions rust/onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ to download.
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/main/onnxruntime/examples/sample.rs)
//! example for more details.
use std::ffi::c_char;

use onnxruntime_sys as sys;

// Make functions `extern "stdcall"` for Windows 32bit.
Expand Down Expand Up @@ -187,8 +189,8 @@ use sys::OnnxEnumInt;
// Re-export ndarray as it's part of the public API anyway
pub use ndarray;

fn char_p_to_string(raw: *const i8) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
fn char_p_to_string(raw: *const c_char) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut c_char).to_owned() };

match c_string.into_string() {
Ok(string) => Ok(string),
Expand All @@ -201,7 +203,7 @@ mod onnxruntime {
//! Module containing a custom logger, used to catch the runtime's own logging and send it
//! to Rust's tracing logging instead.
use std::ffi::CStr;
use std::ffi::{CStr, c_char};
use tracing::{debug, error, info, span, trace, warn, Level};

use onnxruntime_sys as sys;
Expand Down Expand Up @@ -240,10 +242,10 @@ mod onnxruntime {
pub(crate) fn custom_logger(
_params: *mut std::ffi::c_void,
severity: sys::OrtLoggingLevel,
category: *const i8,
logid: *const i8,
code_location: *const i8,
message: *const i8,
category: *const c_char,
logid: *const c_char,
code_location: *const c_char,
message: *const c_char,
) {
let log_level = match severity {
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => Level::TRACE,
Expand Down
16 changes: 8 additions & 8 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryFrom, ffi::{CString, c_char}, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -423,17 +423,17 @@ impl Session {
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const i8> = output_names_cstring
let output_names_ptr: Vec<*const c_char> = output_names_cstring
.iter()
.map(|n| n.as_ptr().cast::<i8>())
.map(|n| n.as_ptr().cast::<c_char>())
.collect();

let input_names_ptr: Vec<*const i8> = self
let input_names_ptr: Vec<*const c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const i8)
.map(|n| n.into_raw() as *const c_char)
.collect();

{
Expand Down Expand Up @@ -508,7 +508,7 @@ impl Session {
.into_iter()
.map(|p| {
assert_not_null_pointer(p, "i8 for CString")?;
unsafe { Ok(CString::from_raw(p as *mut i8)) }
unsafe { Ok(CString::from_raw(p as *mut c_char)) }
})
.collect();
cstrings?;
Expand Down Expand Up @@ -694,14 +694,14 @@ mod dangerous {
*const sys::OrtSession,
usize,
*mut sys::OrtAllocator,
*mut *mut i8,
*mut *mut c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: usize,
env: _Environment,
) -> Result<String> {
let mut name_bytes: *mut i8 = std::ptr::null_mut();
let mut name_bytes: *mut c_char = std::ptr::null_mut();

let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::InputName)?;
Expand Down

0 comments on commit b80fe5b

Please sign in to comment.