Run model with a cupy array on CUDA #10238
Hi, I am facing the same issue. Have you figured out how to use from_dlpack() and to_dlpack() with ONNX Runtime?
After some research I found a plausible solution: take a look at #10286, as you may need to build onnxruntime from source. Once you have everything working, you will find the function you were looking for in onnxruntime.training.ortmodule._utils:
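The exact snippet from `_utils` is not preserved in this thread, but the general DLPack pattern it relies on can be sketched with NumPy, which also implements the protocol (a hedged illustration of the idea, not the onnxruntime helper itself):

```python
import numpy as np

# DLPack is a standard for zero-copy tensor exchange between libraries.
# NumPy implements the consumer side as np.from_dlpack (NumPy >= 1.22).
a = np.arange(4, dtype=np.float32)
b = np.from_dlpack(a)  # no copy: b is a view over a's buffer

print(np.shares_memory(a, b))  # True
```

With cupy the equivalent consumer call is `cp.from_dlpack`, and PyTorch's is `torch.utils.dlpack.from_dlpack`; the onnxruntime helper mentioned above plays the producer role for `OrtValue`s.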
I also came across this task and found that the solution is quite simple: use IO binding and just pass the cupy data pointer. I also made sure that the array is contiguous, because the PyTorch example shows this as well. There is also a nice cupy interoperability guide which I used.

```python
image_gpu = cp.array(image, dtype=cp.float32)
image_gpu = cp.ascontiguousarray(image_gpu)

binding = onnx_sess.io_binding()
binding.bind_input(name=onnx_sess.get_inputs()[0].name, device_type='cuda', device_id=0,
                   element_type=cp.float32, shape=tuple(image_gpu.shape),
                   buffer_ptr=image_gpu.data.ptr)
binding.bind_output(name=onnx_sess.get_outputs()[0].name)
onnx_sess.run_with_iobinding(binding)
results = binding.copy_outputs_to_cpu()[0]
```

Here is an example where only cupy arrays are used, so the output stays on the GPU:

```python
import numpy as np

image_gpu = cp.array(image, dtype=cp.float32)
image_gpu = cp.ascontiguousarray(image_gpu)

binding = onnx_sess.io_binding()
binding.bind_input(name=onnx_sess.get_inputs()[0].name, device_type='cuda', device_id=0,
                   element_type=cp.float32, shape=tuple(image_gpu.shape),
                   buffer_ptr=image_gpu.data.ptr)
binding.bind_output("output", "cuda")
onnx_sess.run_with_iobinding(binding)

ort_output = binding.get_outputs()[0]  # OrtValue holding a device memory pointer
if ort_output.data_ptr():  # always check that the pointer is non-zero
    # UnownedMemory takes a size in bytes, not an element count.
    size_in_bytes = int(np.prod(ort_output.shape())) * cp.float32().itemsize
    mem = cp.cuda.UnownedMemory(ort_output.data_ptr(), size_in_bytes, owner=ort_output)
    mem_ptr = cp.cuda.MemoryPointer(mem, 0)
    # Zero-copy view over ONNX Runtime's output buffer.
    results = cp.ndarray(ort_output.shape(), dtype=cp.float32, memptr=mem_ptr)
```

I guess the feature request #15963 can thus be closed. The code above was tested using onnxruntime 1.8.0, cupy 9.6.0 and CUDA 11.0.

Edit on 26th February: always check that the OrtValue pointer is non-zero!
I have been getting corrupted results on some calls with the solution above when doing repeated calls to the model, and after a day trying to solve the issue I found by chance that one needs to call
I tried to use cupy to convert data and then feed it to ONNX Runtime with the solution above, but an error occurred:
Does anyone know the reason? onnxruntime-gpu: 1.18.1, cupy-cuda12x: 13.3.0, cuda: 12.2
May I ask what your onnxruntime, cupy and cuda versions are?
I am currently running the code above with the following versions:
I also had trouble finding the correct versions for my python environment. You may find a solution here:
Thank you for your config; my code runs successfully under this config.
Update: my previous config was also OK, but I had mistakenly installed the CUDA 11 version of the onnxruntime-gpu library; my code ran successfully once I reinstalled using the command from the official website:
Similar to #10217
Can we run an onnxruntime model with a cupy array (with some conversions)?
I tried the dlpack way mentioned in #4162, but I got a module-not-found error with the statement

C.OrtValue.from_dlpack()

Have I missed something, or is the installation different from usual to support

from onnxruntime.capi import _pybind_state as C

? I basically install onnxruntime as follows:

```
# Ubuntu 18.04 with CUDA 11.2
pip install onnxruntime-gpu==1.9.0
```
Thanks!
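Since the dlpack helpers are only present in some onnxruntime builds (the thread suggests training-enabled builds), it may help to feature-detect them rather than assume they exist. A hedged sketch, checking the attribute used above:

```python
# Feature-detect the dlpack helper instead of assuming the build has it.
# A plain pip wheel without training support may not expose OrtValue.from_dlpack.
try:
    from onnxruntime.capi import _pybind_state as C
    has_dlpack = hasattr(getattr(C, "OrtValue", object()), "from_dlpack")
except ImportError:
    has_dlpack = False

print(has_dlpack)
```

If this prints False, the IO-binding approach above works with any onnxruntime-gpu build and is the simpler route.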