
[Bug] CUDA_ERROR_INVALID_IMAGE when deploying ResNet18 model to Jetson AGX Orin 32GB using TVM #17543

Open
JuneJulyAugust opened this issue Nov 24, 2024 · 6 comments
Labels: needs-triage, type: bug

Comments


JuneJulyAugust commented Nov 24, 2024

Expected behavior

The model should run successfully on the Jetson AGX Orin without any errors, producing a top-1 prediction from the TVM runtime.

Actual behavior

The code fails with the error CUDA_ERROR_INVALID_IMAGE when attempting to run the model on the Jetson AGX Orin. The error occurs when using the RPC connection.

Environment

  • Host Machine:
    • OS: Ubuntu 22.04, WSL2 on Windows 10
    • CUDA SDK version: 12.4
    • TVM version: 0.19.dev45+g4d99ec5d9
    • GPU: NVIDIA RTX A3000 12GB
    • Driver Version: 553.35
    • CUDA Version: 12.4
    • LLVM Version: 14.0.0
nvidia-smi output:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.134                Driver Version: 553.35         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A3000 12GB La...    On  |   00000000:01:00.0 Off |                  Off |
| N/A   43C    P8             13W /   91W |       0MiB /  12288MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

TVM build flags:

USE_NVTX: OFF
USE_GTEST: OFF
SUMMARIZE: OFF
TVM_DEBUG_WITH_ABI_CHANGE: OFF
USE_IOS_RPC: OFF
USE_MSC: OFF
USE_ETHOSU: OFF
CUDA_VERSION: 12.4
USE_LIBBACKTRACE: AUTO
DLPACK_PATH: 3rdparty/dlpack/include
USE_TENSORRT_CODEGEN: OFF
USE_OPENCL_EXTN_QCOM: NOT-FOUND
USE_TARGET_ONNX: OFF
USE_AOT_EXECUTOR: ON
BUILD_DUMMY_LIBTVM: OFF
USE_CUDNN: ON
USE_TENSORRT_RUNTIME: OFF
USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR: OFF
USE_THRUST: OFF
USE_CCACHE: AUTO
USE_ARM_COMPUTE_LIB: OFF
USE_CPP_RTVM: OFF
USE_OPENCL_GTEST: /path/to/opencl/gtest
TVM_LOG_BEFORE_THROW: OFF
USE_MKL: OFF
USE_PT_TVMDSOOP: OFF
MLIR_VERSION: NOT-FOUND
USE_CLML: OFF
USE_STACKVM_RUNTIME: OFF
USE_GRAPH_EXECUTOR_CUDA_GRAPH: OFF
ROCM_PATH: /opt/rocm
USE_DNNL: OFF
USE_MSCCL: OFF
USE_NNAPI_RUNTIME: OFF
USE_VITIS_AI: OFF
USE_MLIR: OFF
USE_RCCL: OFF
USE_LLVM: llvm-config --ignore-libllvm --link-static
USE_VERILATOR: OFF
USE_TF_TVMDSOOP: OFF
USE_THREADS: ON
USE_MSVC_MT: OFF
BACKTRACE_ON_SEGFAULT: OFF
USE_GRAPH_EXECUTOR: ON
USE_NCCL: OFF
USE_ROCBLAS: OFF
GIT_COMMIT_HASH: 4d99ec5d9e77f2fe1a0e13fe5339865e6536b371
USE_VULKAN: OFF
USE_RUST_EXT: OFF
USE_CUTLASS: OFF
USE_CPP_RPC: OFF
USE_HEXAGON: OFF
USE_CUSTOM_LOGGING: OFF
USE_UMA: OFF
USE_FALLBACK_STL_MAP: OFF
USE_SORT: ON
USE_RTTI: ON
GIT_COMMIT_TIME: 2024-11-23 20:57:00 +0800
USE_HIPBLAS: OFF
USE_HEXAGON_SDK: /path/to/sdk
USE_BLAS: none
USE_ETHOSN: OFF
USE_LIBTORCH: OFF
USE_RANDOM: ON
USE_CUDA: ON
USE_COREML: OFF
USE_AMX: OFF
BUILD_STATIC_RUNTIME: OFF
USE_CMSISNN: OFF
USE_KHRONOS_SPIRV: OFF
USE_CLML_GRAPH_EXECUTOR: OFF
USE_TFLITE: OFF
USE_HEXAGON_GTEST: /path/to/hexagon/gtest
PICOJSON_PATH: 3rdparty/picojson
USE_OPENCL_ENABLE_HOST_PTR: OFF
INSTALL_DEV: OFF
USE_PROFILER: ON
USE_NNPACK: OFF
LLVM_VERSION: 14.0.0
USE_MRVL: OFF
USE_OPENCL: OFF
COMPILER_RT_PATH: 3rdparty/compiler-rt
USE_NNAPI_CODEGEN: OFF
RANG_PATH: 3rdparty/rang/include
USE_SPIRV_KHR_INTEGER_DOT_PRODUCT: OFF
USE_OPENMP: none
USE_BNNS: OFF
USE_FLASHINFER: OFF
USE_CUBLAS: ON
USE_METAL: OFF
USE_MICRO_STANDALONE_RUNTIME: OFF
USE_HEXAGON_EXTERNAL_LIBS: OFF
USE_ALTERNATIVE_LINKER: AUTO
USE_BYODT_POSIT: OFF
USE_NVSHMEM: OFF
USE_HEXAGON_RPC: OFF
USE_MICRO: OFF
DMLC_PATH: 3rdparty/dmlc-core/include
INDEX_DEFAULT_I64: ON
USE_RELAY_DEBUG: OFF
USE_RPC: ON
USE_TENSORFLOW_PATH: none
TVM_CLML_VERSION:
USE_MIOPEN: OFF
USE_ROCM: OFF
USE_PAPI: OFF
USE_CURAND: OFF
TVM_CXX_COMPILER_PATH: /usr/bin/c++
HIDE_PRIVATE_SYMBOLS: ON
  • Jetson AGX Orin 32GB:
    • CUDA SDK version: 11.4
    • JetPack version: 5.1.3
    • OS: Ubuntu 20.04
    • TVM runtime code commit: db6d205

Steps to reproduce

  1. Follow the instructions in the "Deploy the Pretrained Model on Jetson Nano" tutorial, adapting them to the Jetson AGX Orin 32GB.

  2. Use this script on the host machine:

import torch
import torchvision
from PIL import Image
import numpy as np
import os
import time
import tvm
from tvm import te
import tvm.relay as relay
from tvm import rpc
from tvm.contrib import utils, graph_executor as runtime
from tvm.contrib.download import download_testdata

def load_image():
    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_name = "cat.png"
    img_path = download_testdata(img_url, img_name, module="data")
    image = Image.open(img_path).resize((224, 224))


    def transform_image(image):
        image = np.array(image) - np.array([123.0, 117.0, 104.0])
        image /= np.array([58.395, 57.12, 57.375])
        image = image.transpose((2, 0, 1))
        image = image[np.newaxis, :]
        return image

    x = transform_image(image)
    return x

def get_synset():
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    synset_path = download_testdata(synset_url, synset_name, module="data")
    with open(synset_path) as f:
        synset = eval(f.read())
    return synset

def build_model():
    # one line to get the model
    model_name = "resnet18"
    model = getattr(torchvision.models, model_name)(pretrained=True)
    model = model.eval()

    # We grab the TorchScripted model via tracing
    input_shape = [1, 3, 224, 224]
    input_data = torch.randn(input_shape)
    scripted_model = torch.jit.trace(model, input_data).eval()

    x = load_image()
    input_name = "input0"
    shape_list = [(input_name, x.shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

    # we want a probability so add a softmax operator
    func = mod["main"]
    func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)

    local_demo = False
    if local_demo:
        # target = tvm.target.Target("llvm")
        target = tvm.target.Target("cuda")
    else:
        target = tvm.target.Target("nvidia/jetson-agx-orin-32gb")
    print(f"TVM target: {target}")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(func, target, params=params)

    # Save the compiled library to a local path.
    lib_fname = "models/net.tar"
    lib.export_library(lib_fname)

    if local_demo:
        remote = rpc.LocalSession()
    else:
        remote = rpc.connect("192.168.2.10", 9091)

    # upload the library to remote device and load it
    remote.upload(lib_fname)
    rlib = remote.load_module("net.tar")

    # create the remote runtime module
    if local_demo:
        # dev = remote.cpu(0)
        dev = remote.cuda(0)
    else:
        dev = remote.cuda(0)

    module = runtime.GraphModule(rlib["default"](dev))
    # set input data
    module.set_input(input_name, tvm.nd.array(x.astype("float32")))
    start_time = time.time()
    # run
    module.run()
    # get output
    out = module.get_output(0)
    end_time = time.time()
    print("Time taken: ", end_time - start_time)
    # get top1 result
    top1 = np.argmax(out.numpy())
    synset = get_synset()
    print("TVM prediction top-1: {}".format(synset[top1]))


def main():
    build_model()
    

if __name__ == "__main__":
    main()
  3. Start the TVM RPC server on the Jetson AGX Orin:
python3 -m tvm.exec.rpc_server --host 0.0.0.0 --port=9091
  4. Run the script on the host machine; it fails with this error:
/home/fang/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/fang/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
TVM target: cuda -keys=cuda,gpu -arch=sm_87 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/d/projects/vcas_onnx_tensorrt/src/tvm_deploy_restnet_to_jetson.py", line 119, in <module>
    main()
  File "/mnt/d/projects/vcas_onnx_tensorrt/src/tvm_deploy_restnet_to_jetson.py", line 115, in main
    build_model()
  File "/mnt/d/projects/vcas_onnx_tensorrt/src/tvm_deploy_restnet_to_jetson.py", line 103, in build_model
    module.run()
  File "/mnt/d/projects/open_source/tvm/python/tvm/contrib/graph_executor.py", line 264, in run
    self._run()
  File "/mnt/d/projects/open_source/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/mnt/d/projects/open_source/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_module.cc", line 129, in tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
    sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 1087, in tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)
    endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 870, in tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)
    code = HandleUntilReturnEvent(true, encode_return);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 714, in tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
    code = handler_->HandleNextEvent(client_mode, false, setreturn);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 136, in tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, std::function<void (tvm::runtime::TVMArgs)>)
    this->HandleProcessPacket(setreturn);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 375, in tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void (tvm::runtime::TVMArgs)>)
    this->HandleReturn(code, setreturn);
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 439, in tvm::runtime::RPCEndpoint::EventHandler::HandleReturn(tvm::runtime::RPCCode, std::function<void (tvm::runtime::TVMArgs)>)
    LOG(FATAL) << msg;
tvm.error.RPCError: Traceback (most recent call last):
  6: tvm::runtime::RPCWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_module.cc:129
  5: tvm::runtime::RPCClientSession::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)> const&)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:1087
  4: tvm::runtime::RPCEndpoint::CallFunc(void*, TVMValue const*, int const*, int, std::function<void (tvm::runtime::TVMArgs)>)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:870
  3: tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void (tvm::runtime::TVMArgs)>)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:714
  2: tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, std::function<void (tvm::runtime::TVMArgs)>)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:136
  1: tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void (tvm::runtime::TVMArgs)>)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:375
  0: tvm::runtime::RPCEndpoint::EventHandler::HandleReturn(tvm::runtime::RPCCode, std::function<void (tvm::runtime::TVMArgs)>)
        at /mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc:439
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::detail::PackFuncVoidAddr_<4, tvm::runtime::CUDAWrappedFunc>(tvm::runtime::CUDAWrappedFunc, std::vector<tvm::runtime::detail::ArgConvertCode, std::allocator<tvm::runtime::detail::ArgConvertCode> > const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::runtime::CUDAWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, void**) const
  File "/mnt/d/projects/open_source/tvm/src/runtime/rpc/rpc_endpoint.cc", line 439
RPCError: Error caught from RPC call:
[00:30:45] /data/projects/tvm/src/runtime/cuda/cuda_module.cc:110: CUDAError: cuModuleLoadData(&(module_[device_id]), data_.c_str()) failed with error: CUDA_ERROR_INVALID_IMAGE

Triage

  • needs-triage
@JuneJulyAugust added the needs-triage and type: bug labels on Nov 24, 2024
@JuneJulyAugust (Author):

I tested the jetson-agx-xavier target today with a Jetson Xavier device and got the same CUDA_ERROR_INVALID_IMAGE error. Has anyone run a model on a Jetson device successfully recently? Or should I try earlier release tags?

@JuneJulyAugust (Author):

I printed out the command in python/tvm/contrib/nvcc.py when compiling the model for the "nvidia/jetson-agx-xavier" target:

Command to compile: ['nvcc', '--fatbin', '-O3', '-gencode', 'arch=compute_72,code=sm_72', '-o', '/tmp/tmpz1pl_3jt/tvm_kernels.fatbin', '/tmp/tmpz1pl_3jt/tvm_kernels.cu']

The arch and code arguments seem right. Can any expert help check?

mshr-h (Contributor) commented Nov 27, 2024

I actually have an AGX Xavier, an AGX Orin, and an Orin Nano, but I haven't hit a similar error before. I would recommend compiling the model on the device side instead of on the host machine to see if it works.
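
For reference, a minimal on-device variant might look like this (an untested sketch, assuming TVM is built with CUDA support on the Orin itself; it skips RPC entirely and feeds a random input):

# Run directly on the Jetson AGX Orin: the device's own nvcc and driver
# produce and load the fatbin, so no host/device CUDA mismatch is possible.
import numpy as np
import torch
import torchvision
import tvm
from tvm import relay
from tvm.contrib import graph_executor

model = torchvision.models.resnet18(pretrained=True).eval()
scripted = torch.jit.trace(model, torch.randn(1, 3, 224, 224)).eval()
mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3, 224, 224))])

target = tvm.target.Target("nvidia/jetson-agx-orin-32gb")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)

dev = tvm.cuda(0)  # local GPU, no RPC session involved
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("input0", tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32")))
module.run()
print("top-1 class id:", int(np.argmax(module.get_output(0).numpy())))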

@JuneJulyAugust (Author):

Thanks for the reply. I will build TVM and compile the model on the AGX Orin to see if the error still happens.

cgerum (Contributor) commented Nov 28, 2024

This error is most likely caused by different CUDA versions on the host and the device. Note that even if TVM is linked against the same CUDA version on both the host and the device, it might still be using the wrong nvcc command-line utility.

Please make sure that nvcc --version gives the same result on the host and on the device.
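
If it helps, here is a quick way to compare the two from Python (a small sketch; run the same snippet on both the host and the device and compare the output):

import subprocess

# Print the "release" line of `nvcc --version`, e.g.
# "Cuda compilation tools, release 12.4, V12.4.131".
out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
release_lines = [line for line in out.stdout.splitlines() if "release" in line]
print(release_lines[0] if release_lines else out.stdout)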

@JuneJulyAugust (Author):

@mshr-h @cgerum Thank you for your replies. I built TVM on the AGX Orin and compiled the model there; the model now loads successfully. Previously, my host system was x86_64 with CUDA 12.4, while the target's CUDA version was 11.4, as mentioned in my report.
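
For anyone who still wants to cross-compile from the host, a possible workaround (an untested sketch; the CUDA 11.4 install path on the host is an assumption) is to put a toolkit matching the device's version first on PATH before building, since TVM invokes whatever nvcc it finds there:

import os
import subprocess

# Assumption: a CUDA 11.4 toolkit is also installed on the host at this path.
os.environ["PATH"] = "/usr/local/cuda-11.4/bin:" + os.environ["PATH"]

# Confirm which nvcc will be picked up before calling relay.build(...).
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)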
