
mobilenetv2 doesn't work with Vulkan backend #6516

Open
sternezsl opened this issue Oct 28, 2024 · 2 comments

@sternezsl

🐛 Describe the bug

I can successfully export the vulkan pte. When I run the model with

./backends/vulkan/vulkan_executor_runner --model_path /scratch/models/vulkan_mobilenetv2.pte

I get the error:

I 00:00:00.001717 executorch:executor_runner.cpp:82] Model file /scratch/models/vulkan_mobilenetv2.pte is loaded.
I 00:00:00.001730 executorch:executor_runner.cpp:91] Using method forward
I 00:00:00.001732 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 606112.
libc++abi: terminating due to uncaught exception of type vkcompute::vkapi::Error: Exception raised from check_conv_args at /scratch/code/meta/executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp:225: (check_packed_dim_is(in, WHCN::kChannelsDim)) is false!
[1] 544559 IOT instruction (core dumped) ./backends/vulkan/vulkan_executor_runner --model_path

I inspected the code and found that the input tensor's packed_dim is 0 rather than 2 (WHCN::kChannelsDim). If I comment out the check_conv_args call, I run into another problem:

I 00:00:00.000351 executorch:executor_runner.cpp:82] Model file /scratch/models/vulkan_mobilenetv2.pte is loaded.
I 00:00:00.000358 executorch:executor_runner.cpp:91] Using method forward
I 00:00:00.000360 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 606112.
libc++abi: terminating due to uncaught exception of type vkcompute::vkapi::Error: Exception raised from mean at /scratch/code/meta/executorch/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp:112: (dims_list->size() == 1) is false!

I noticed that @SS-JIA replaced the mobilenet demo with an Add example in the tutorial a few days ago, so I guess you are already aware of the problem. I would like to fix it, but at the moment I do not know much about the Vulkan backend; could you give me some hints and I'll try to fix it?

The following is the model conversion script:

import torch
import torchvision.models as models

from torch.export import export, ExportedProgram
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.vulkan import VulkanPartitioner
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Export to an ATen dialect program, then lower the supported subgraphs to Vulkan.
exported_program: ExportedProgram = export(mobilenet_v2, sample_inputs)
edge: EdgeProgramManager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
)

# print(edge.exported_program().graph_module)

exec_prog = edge.to_executorch()

with open("vulkan_mobilenetv2.pte", "wb") as file:
    exec_prog.write_to_file(file)
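
For reference, a CPU counterpart to compare against (as done later in this thread) can be produced with the same pipeline by swapping in the XNNPACK partitioner. A minimal sketch, assuming the same environment as the script above; the output filename is only illustrative:

import torch
import torchvision.models as models

from torch.export import export
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Same export pipeline as above, but lowered to XNNPACK instead of Vulkan.
xnnpack_edge = to_edge_transform_and_lower(
    export(mobilenet_v2, sample_inputs),
    partitioner=[XnnpackPartitioner()],
)

with open("xnnpack_mobilenetv2.pte", "wb") as file:
    xnnpack_edge.to_executorch().write_to_file(file)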

Versions

PyTorch version: 2.5.0+git8a0ce38
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
VULKAN used to build PyTorch: True

OS: Fedora Linux Asahi Remix 40 (KDE Plasma) (aarch64)
GCC version: (GCC) 14.2.1 20240912 (Red Hat 14.2.1-3)
Clang version: 18.1.7 (https://github.com/llvm/llvm-project 768118d1ad38bf13c545828f67bd6b474d61fc55)
CMake version: version 3.20.0
Libc version: glibc-2.39
Vulkan Driver Version: Mesa 24.3.0-devel (git-f05157f591),
Vulkan Instance Version: 1.3.290

Python version: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 19:56:21) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-6.11.3-dimilar-4-1-edge-ARCH+-aarch64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Max

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.9.16
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] pytorch-sphinx-theme==0.0.19
[pip3] torchao==0.7.0+git6b529961
[pip3] torchvision==0.20.0a0+f851df1
[conda] numpy 1.25.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.19 pypi_0 pypi
[conda] torchao 0.7.0+git6b529961 pypi_0 pypi
[conda] torchfix 0.5.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+f851df1 pypi_0 pypi

@sternezsl
Author

  • The first error is due to the memory layouts of the input and output tensors not being TENSOR_CHANNELS_PACKED.
  • About the second error: I don't think the new reduce implementation of sum.dim_IntList is a drop-in replacement for the old one. If I restore the old Sum.cpp along with the sum_dim*.glsl and sum_dim*.yaml files, mobilenetv2 works and I get the same result as the xnnpack backend (a comparison sketch follows below).
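
A minimal sketch of the kind of comparison described above, checking a lowered .pte against the eager model from Python. This assumes the ExecuTorch pybindings (executorch.extension.pybindings.portable_lib) were built with the backend the .pte was lowered to; a Vulkan .pte would otherwise still need vulkan_executor_runner as shown earlier:

import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

# Requires a pybindings build that includes the backend used by the .pte.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    reference = mobilenet_v2(x)

# Load the lowered program (e.g. the XNNPACK .pte from the earlier sketch) and run it.
module = _load_for_executorch("xnnpack_mobilenetv2.pte")
lowered = module.forward([x])[0]

# Compare the top-1 prediction and the absolute error against the eager model.
print("top-1 match:", torch.argmax(reference).item() == torch.argmax(lowered).item())
print("max abs diff:", (reference - lowered).abs().max().item())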

@SS-JIA
Contributor

SS-JIA commented Nov 4, 2024

Hi @sternezsl, thanks for raising this issue!

I noticed that @SS-JIA replaced the mobilenet demo with an Add example in the tutorial a few days ago, so I guess you are already aware of the problem. I would like to fix it, but at the moment I do not know much about the Vulkan backend; could you give me some hints and I'll try to fix it?

I definitely appreciate the help on this :) I did encounter an issue when trying out the mobilenet example, but it was separate from the issues you have described. The problem I ran into back then was related to the runner binary, not the Vulkan export, and I decided to replace the example with a simpler one because I didn't have time to look into the failure before the beta release of ExecuTorch. The fact that you were able to get MobileNet V2 working suggests that the problem with the runner binary has since been fixed. If we can verify that mobilenet works on main, then it would make sense to make it the example once more (y)

About the second error: I don't think the new reduce implementation of sum.dim_IntList is a drop-in replacement for the old one. If I restore the old Sum.cpp along with the sum_dim*.glsl and sum_dim*.yaml files, mobilenetv2 works and I get the same result as the xnnpack backend.

I also became aware of this regression while trying to land the new reduce shader, but decided to go ahead with it for the sake of updating the implementation to a more modern pattern that is agnostic to memory layouts, and because the new implementation is more performant for the single-reduction-dim case. Also, at the time mobilenet v2 was the only model I was working with that used sum with multiple reduction dimensions, so I figured the impact of the regression was not very significant.

When I first committed the new reduce shader, the regression was a real issue, since the new implementation cannot handle reducing over multiple tensor dimensions. However, I have since landed #6488, which addresses this. That PR essentially adds the ability for operators to implement a specific check function to determine whether they are supported by Vulkan; after it was committed, the sum op in mobilenet is no longer delegated to Vulkan.

Please let me know if you're still running into problems on the current main branch.

The first error is due to the memory layouts of the input and output tensors not being TENSOR_CHANNELS_PACKED.

I am currently working on #6636, which will fix errors like these. However, in the meantime, if you lower to Vulkan with the following settings:

VulkanPartitioner(
    compile_options={
        "memory_layout_override": vulkan_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED,
    }
)

It should fix the issue as well.
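
For reference, a minimal sketch of wiring this override into the conversion script from the original post. The import path for vulkan_graph_schema is an assumption based on backends/vulkan/serialization/vulkan_graph_schema.py in the ExecuTorch tree; adjust it if your checkout differs:

import torch
import torchvision.models as models

from torch.export import export
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.vulkan import VulkanPartitioner
# Assumed location of the Vulkan serialization schema; verify against your checkout.
from executorch.backends.vulkan.serialization import vulkan_graph_schema
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Force channels-packed memory layouts so check_conv_args is satisfied.
partitioner = VulkanPartitioner(
    compile_options={
        "memory_layout_override": vulkan_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED,
    }
)

edge = to_edge_transform_and_lower(
    export(mobilenet_v2, sample_inputs),
    partitioner=[partitioner],
)

with open("vulkan_mobilenetv2.pte", "wb") as file:
    edge.to_executorch().write_to_file(file)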
