Skip to content

Commit

Permalink
Adds conditions to keep QDQ
Browse files Browse the repository at this point in the history
The conditions are based on
 - keep uint8/int8 QDQs after Conv (except Conv -> QDQ -> Bias Add/Div/Mul)
 - keep DQ if used for initializer
 - QDQ pair before Conv should be stripped if the QDQ pair is uint16/int16
  • Loading branch information
rayngun committed Feb 11, 2025
1 parent f84614c commit 69647fc
Showing 1 changed file with 11 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -478,21 +478,32 @@ static void AddStandaloneNodeUnit(onnxruntime::Graph& dst_graph, const onnxrunti
};

if (node_unit.OpType() == "QuantizeLinear") {
const auto& node =node_unit.GetNode();
SkipReason reason;
// keep if next target is supported
if (CheckQRuleSet(node_unit, &node_unit.GetNode(), src_graph, reason))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
// #2 If input 0 is a constant initializer, then don't keep the Q
else if (src_graph.IsConstantInitializer(node_unit.GetNode().InputDefs().at(0)->Name(), true))
return;
else if (node.GetInputEdgesCount() == 1 &&
(node.InputNodesBegin()->OpType() == "Conv" || node.InputNodesBegin()->OpType() == "Add") &&
(GetQDQDataType(&node) == DT_UINT8 || GetQDQDataType(&node) == DT_INT8))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
else
add_identity_op(false);
} else if (node_unit.OpType() == "DequantizeLinear") {
const auto& node =node_unit.GetNode();
// keep if prev target is supported
if (node_unit.GetNode().Name().find(DuplicateDQ) != std::string::npos)
add_identity_op(true);
else if (IsConnectedQPresent(src_graph, dst_graph.Nodes(), &node_unit.GetNode(), node_unit.GetNode().InputDefs()))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
else if (src_graph.IsConstantInitializer(node_unit.GetNode().InputDefs().at(0)->Name(), true))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
else if (node.GetOutputEdgesCount() == 1 && node.OutputNodesBegin()->OpType() == "Conv" &&
(GetQDQDataType(&node) == DT_UINT16 || GetQDQDataType(&node) == DT_INT16))
add_identity_op(false);
else if (DQFeedsASupportedOp(&node_unit.GetNode()))
AddNode(initializers_to_keep, src_graph, dst_graph, node_unit.GetNode());
else
Expand Down

0 comments on commit 69647fc

Please sign in to comment.