4 changes: 4 additions & 0 deletions .github/workflows/pull.yml
@@ -934,6 +934,10 @@ jobs:
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d

# Run e2e testing for selected operators. More operators will be tested via this
# route in the future.
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"

nxp-build-test:
name: nxp-build-test
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -118,6 +118,19 @@ runtime.python_library(
],
)

runtime.python_library(
name = "fold_qdq",
srcs = ["fold_qdq.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan:utils_lib",
"//executorch/exir:pass_base",
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
@@ -144,6 +157,7 @@ runtime.python_library(
"//executorch/examples/...",
],
deps = [
":fold_qdq",
":fuse_patterns",
":fuse_quantized_ops",
":insert_prepack_nodes",
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -6,6 +6,7 @@

# pyre-strict

from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass
from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
@@ -30,6 +31,7 @@
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
"FoldQDQPass",
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"insert_prepack_nodes",
41 changes: 41 additions & 0 deletions backends/vulkan/_passes/fold_qdq.py
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.utils as utils
import torch

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class FoldQDQPass(ExportPass):
"""
Erase Q/DQ chains introduced by the PT2E quantization workflow. It is assumed that
all valid quantized op patterns have already been fused before this pass runs.
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
super(FoldQDQPass, self).__init__()
self.edge_program = edge_program

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if utils.is_quant_node(node):
original_node = node.args[0]
assert isinstance(original_node, torch.fx.Node)
# For each direct user that is a dequant node, connect the original
# node to the users of the dequant node.
for user in node.users:
if utils.is_dequant_node(user):
dq_node = user
dq_node.replace_all_uses_with(original_node)

graph_module.recompile()
dead_code_elimination_pass(graph_module)
# Re-trace to validate that the resulting graph is still well-formed
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
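
For illustration, a minimal sketch of how this pass might be invoked once the quantized-op fusion passes have run; the helper name and surrounding setup are assumptions for the example, not part of this diff:

```python
# Hypothetical usage sketch: fold leftover Q/DQ pairs after fusion passes.
from executorch.backends.vulkan._passes import FoldQDQPass


def fold_leftover_qdq(edge_program):
    # edge_program is assumed to be a torch.export.ExportedProgram whose graph
    # has already been processed by the quantized-op fusion passes.
    graph_module = edge_program.graph_module
    result = FoldQDQPass(edge_program)(graph_module)
    return result.graph_module
```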
131 changes: 131 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

@@ -321,6 +323,135 @@ def linear_qta8a_qga4w(
lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)

#######################
## linear_q8ta_q8csw ##
#######################


def linear_q8ta_q8csw(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales,
weight_zeros,
0,
-127,
127,
torch.int8,
)

# Perform linear operation
out = torch.nn.functional.linear(x, weights)
if bias is not None:
out = out + bias

return out


name = "linear_q8ta_q8csw"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias = None) -> Tensor
"""
)
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
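
As a quick sanity check, the reference (CompositeExplicitAutograd) implementation above can be exercised eagerly. The `et_vk` namespace is inferred from the op_registry entries below; the shapes and values are illustrative assumptions, and `input_scale`, `input_zero_point`, and `weight_sums` are not consumed by the reference implementation (they exist for the GPU kernel):

```python
# Hypothetical smoke test of the linear_q8ta_q8csw reference implementation.
import torch

out_features, in_features = 8, 16
x = torch.randn(2, in_features)
weights = torch.randint(-127, 128, (out_features, in_features), dtype=torch.int8)
weight_scales = torch.rand(out_features) * 0.1
# Placeholder per-output-channel sums; unused by the reference implementation.
weight_sums = weights.to(torch.int32).sum(dim=1)

out = torch.ops.et_vk.linear_q8ta_q8csw(
    x, 0.05, 0, weights, weight_sums, weight_scales
)
print(out.shape)  # torch.Size([2, 8])
```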

#######################
## conv2d_q8ta_q8csw ##
#######################


def conv2d_q8ta_q8csw(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
bias: Optional[torch.Tensor],
kernel_size: list,
stride: list,
padding: list,
dilation: list,
groups: int,
):
IC = x.shape[1]
K_h, K_w = kernel_size[0], kernel_size[1]

canonical_weight_K_dim = K_h * K_w * IC
# Remove any padding added to the flattened K dim (IC * K_h * K_w) to align to a multiple of 4
if weights.shape[-1] != canonical_weight_K_dim:
weights = weights[:, :canonical_weight_K_dim]
weight_scales = weight_scales[:canonical_weight_K_dim]
if bias is not None:
bias = bias[:canonical_weight_K_dim]

weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)

# Calculate dimensions
OC = weights.shape[0]
in_features = weights.shape[1]
IC = in_features // (K_h * K_w)

# Reshape to the original 4D format (OC, IC, K_h, K_w)
weights = weights.view(OC, IC, K_h, K_w)

# Dequantize weights
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales,
weight_zeros,
0, # axis=0 for output channel quantization
-127,
127,
torch.int8,
)

# Perform convolution
out = torch.nn.functional.conv2d(
x, weights, bias, stride, padding, dilation, groups
)

return out


name = "conv2d_q8ta_q8csw"
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
SymInt[] padding,
SymInt[] dilation,
SymInt groups) -> Tensor
"""
)
lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
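
A similar hedged smoke test for the conv reference implementation; the shapes and the `et_vk` namespace are assumptions based on the surrounding code, and the weights are passed in the flattened 2D layout the implementation expects:

```python
# Hypothetical smoke test of the conv2d_q8ta_q8csw reference implementation.
import torch

N, IC, H, W = 1, 4, 8, 8
OC, K_h, K_w = 8, 3, 3
x = torch.randn(N, IC, H, W)
# Flattened 2D weights (OC, IC * K_h * K_w), int8, per-output-channel quantized.
weights = torch.randint(-127, 128, (OC, IC * K_h * K_w), dtype=torch.int8)
weight_scales = torch.rand(OC) * 0.1
weight_sums = weights.to(torch.int32).sum(dim=1)  # placeholder; unused by the reference impl

out = torch.ops.et_vk.conv2d_q8ta_q8csw(
    x, 0.05, 0, weights, weight_sums, weight_scales, None,
    [K_h, K_w], [1, 1], [1, 1], [1, 1], 1,
)
print(out.shape)  # torch.Size([1, 8, 8, 8])
```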

######################
## apply_rotary_emb ##
######################
40 changes: 40 additions & 0 deletions backends/vulkan/op_registry.py
@@ -318,6 +318,19 @@ def register_int8_mm_op():
)


@update_features(
[
exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
]
)
def register_qa_qw_linear():
return OpFeatures(
inputs_storage=utils.CONTIGUOUS_ANY,
supports_prepacking=True,
supports_resize=False,
)


@update_features(
[
exir_ops.edge.et_vk.linear_weight_int4.default,
@@ -457,6 +470,33 @@ def register_convolution_op():
)


@update_features(
[
exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
]
)
def register_quantized_conv_op():
return OpFeatures(
inputs_storage=[
utils.CHANNELS_PACKED_TEXTURE, # input
utils.NO_STORAGE, # input_scale (non tensor)
utils.NO_STORAGE, # input_zero_point (non tensor)
utils.NO_STORAGE, # weight (prepacked)
utils.NO_STORAGE, # weight_sums (prepacked)
utils.NO_STORAGE, # weight_scales (prepacked)
utils.NO_STORAGE, # bias (prepacked)
utils.NO_STORAGE, # kernel_size (non tensor)
utils.NO_STORAGE, # stride (non tensor)
utils.NO_STORAGE, # padding (non tensor)
utils.NO_STORAGE, # dilation (non tensor)
utils.NO_STORAGE, # groups (non tensor)
utils.NO_STORAGE, # original OC count (non tensor)
],
supports_resize=False,
supports_prepacking=True,
)


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_with_kv_cache_op():
return OpFeatures(
9 changes: 5 additions & 4 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -22,6 +22,8 @@
vulkan_supported_ops,
)

from executorch.backends.vulkan.patterns import PatternMatch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkMemoryLayout,
VkStorageType,
@@ -41,7 +43,6 @@

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.matcher_utils import InternalMatch

# pyre-ignore
ops_not_to_decompose = [
@@ -60,7 +61,7 @@ def __init__(
require_dynamic_shape: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
fusable_subgraphs: Optional[List[InternalMatch]] = None,
fusable_subgraphs: Optional[List[PatternMatch]] = None,
nn_module_blocklist: Optional[Set[str]] = None,
nn_module_allowlist: Optional[Set[str]] = None,
) -> None:
@@ -72,13 +73,13 @@
operator_blocklist if operator_blocklist is not None else set()
)
self.operator_allowlist = operator_allowlist
self.fusable_subgraphs: List[InternalMatch] = (
self.fusable_subgraphs: List[PatternMatch] = (
fusable_subgraphs if fusable_subgraphs is not None else []
)
# Create a set of all nodes that are part of fusable subgraphs for quick lookup
self.fusable_nodes: Set[torch.fx.Node] = set()
for match in self.fusable_subgraphs:
self.fusable_nodes.update(match.nodes_map.values())
self.fusable_nodes.update(match.all_nodes)

self.nn_module_blocklist = nn_module_blocklist
self.nn_module_allowlist = nn_module_allowlist
1 change: 1 addition & 0 deletions backends/vulkan/patterns/TARGETS
@@ -10,6 +10,7 @@ runtime.python_library(
"pattern_registry.py",
"rope.py",
"quantized_linear.py",
"quantized_convolution.py",
],
visibility = [
"//executorch/backends/...",