
Commit cd38342

jfix71 authored and facebook-github-bot committed
Add support for 4-bit fused-rowwise-quantized SLWS to Interpreter (pytorch#3719)
Summary: This PR adds support for 4-bit rowwise quantization to FRWQ-SLWS.

Pull Request resolved: pytorch#3719

Test Plan: Added tests to cover this. I also added a new set of tests (see second commit) that has two columns in addition to being weighted -- previously all of our weighted tests only had a single column of data.

Related to pytorch#3463

CC: jsubag

Differential Revision: D18265376

Pulled By: jfix71

fbshipit-source-id: fbab62a867eb6306f9cde82abdc374de48a8d94e
1 parent 63c26e9 commit cd38342

16 files changed: +318 additions, -103 deletions


include/glow/Base/Type.h

Lines changed: 9 additions & 0 deletions
@@ -306,6 +306,15 @@ inline bool isFusedQuantizedElemKind(ElemKind e) {
          e == ElemKind::UInt4FusedFP16QTy;
 }
 
+/// \returns the scale and offset ElemKind used by the fused ElemKind \p e.
+inline ElemKind getScaleOffsetElemKindFromFused(ElemKind e) {
+  assert(isFusedQuantizedElemKind(e) && "Must pass Fused ElemKind.");
+  if (e == ElemKind::UInt8FusedQTy) {
+    return ElemKind::FloatTy;
+  }
+  return ElemKind::Float16Ty;
+}
+
 /// A class that represents a type of a tensor.
 struct Type final {
   /// Contains the dimensions (sizes) of the tensor. Ex: [sx, sy, sz, ...].
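For context, a minimal sketch of how a caller might use this new helper, assuming glow/Base/Type.h is included and the glow namespace is in scope; the variable names are hypothetical and not part of this diff:

#include "glow/Base/Type.h"
using namespace glow;

// UInt8FusedQTy fuses a float scale/offset per row; the FP16-fused kinds
// (UInt8FusedFP16QTy and the new UInt4FusedFP16QTy) fuse float16 values.
ElemKind fusedKind = ElemKind::UInt4FusedFP16QTy;
ElemKind scaleOffsetKind = getScaleOffsetElemKindFromFused(fusedKind);
// scaleOffsetKind == ElemKind::Float16Ty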

include/glow/Graph/Graph.h

Lines changed: 19 additions & 16 deletions
@@ -936,42 +936,45 @@ class Function final : public Named {
   /// Creates and \returns a node of \p name, performing the SparseLengthsSum
   /// operation, using fused rowwise quantization for the input \p data wherein
   /// the scales and offsets are fused inline with each row of data. \p data
-  /// must be ElemKind::UInt8FusedQTy. Gathers slices of the outer-most
-  /// dimension of data indexed by the \p indices vector, and then accumulates
-  /// them into len(\p lengths) entries: first Lengths[0] slices are aggregated
-  /// to Result[0], next Lengths[1] slices are aggregated to Result[1], etc.
-  /// I.e. sum(Lengths) must be equal to len(Indices). \p precision represents
-  /// what precision to use for Scale, Offset, and Result. If
-  /// \p useFP16Accumulation, then internal arithmetic will use FP16
+  /// must be of a fused ElemKind. Gathers slices of the outer-most dimension of
+  /// data indexed by the \p indices vector, and then accumulates them into
+  /// len(\p lengths) entries: first Lengths[0] slices are aggregated to
+  /// Result[0], next Lengths[1] slices are aggregated to Result[1], etc. I.e.
+  /// sum(Lengths) must be equal to len(Indices). The precision for the Result
+  /// is determined by the \p data input's ElemKind used for Scale and
+  /// Offset. If \p useFP16Accumulation, then internal arithmetic will use FP16
   /// accumulation; otherwise defaults to FP32.
   FusedRowwiseQuantizedSparseLengthsSumNode *
-  createFusedRowwiseQuantizedSparseLengthsSum(
-      llvm::StringRef name, Constant *data, NodeValue indices,
-      NodeValue lengths, ElemKind precision = ElemKind::FloatTy,
-      bool useFP16Accumulation = false);
+  createFusedRowwiseQuantizedSparseLengthsSum(llvm::StringRef name,
+                                              Constant *data, NodeValue indices,
+                                              NodeValue lengths,
+                                              bool useFP16Accumulation = false);
 
   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but expects
   /// float input \p data, which is rowwise-quantized and fused internally.
+  /// \p fusedElemKind represents the element kind to use for the final fused
+  /// rowwise-quantized data.
   FusedRowwiseQuantizedSparseLengthsSumNode *
   createFusedRowwiseQuantizedSparseLengthsSum(
       llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
-      ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false);
+      ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
+      bool useFP16Accumulation = false);
 
   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsSum(), but i-th slice
   /// is multiplied by weights[i]. len(weights) must be equal to len(indices).
   FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
   createFusedRowwiseQuantizedSparseLengthsWeightedSum(
       llvm::StringRef name, NodeValue data, NodeValue weights,
-      NodeValue indices, NodeValue lengths,
-      ElemKind precision = ElemKind::FloatTy, bool useFP16Accumulation = false);
+      NodeValue indices, NodeValue lengths, bool useFP16Accumulation = false);
 
   /// Same as \ref createFusedRowwiseQuantizedSparseLengthsWeightedSum(), but
   /// expects float input \p data, which is rowwise-quantized and fused
-  /// internally.
+  /// internally. \p fusedElemKind represents the element kind to use for the
+  /// final fused rowwise-quantized data.
   FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
   createFusedRowwiseQuantizedSparseLengthsWeightedSum(
       llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
-      NodeValue lengths, ElemKind precision = ElemKind::FloatTy,
+      NodeValue lengths, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
       bool useFP16Accumulation = false);
 
   /// Given a vector of segment lengths, calculates offsets of each segment and
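A minimal usage sketch of the updated builder overloads, assuming a hypothetical Function *F, a float Tensor dataTensor, and NodeValue indices/lengths (none of which appear in this diff); the float-input overloads now take a fused ElemKind, while the Constant/NodeValue overloads infer the scale/offset and result precision from the data's ElemKind:

// Quantize and fuse a float Tensor to the new 4-bit fused kind, then build
// the SLS node; the Float16 result precision follows from UInt4FusedFP16QTy.
auto *SLS = F->createFusedRowwiseQuantizedSparseLengthsSum(
    "sls_4bit", dataTensor, indices, lengths,
    /* fusedElemKind */ ElemKind::UInt4FusedFP16QTy,
    /* useFP16Accumulation */ true);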

lib/Backends/CPU/tests/CPUOperatorTest.cpp

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,11 @@ std::set<std::string> glow::backendTestBlacklist = {
     "back2/0",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Fused4Bit_Float16_AccumFloat16/0",
+    "SLWSTwoColumn_Float16_AccumFloat/0",
     "SparseToDenseMask1/0",
     "SparseToDenseMask2/0",
     "FP16Reshape/0",

lib/Backends/Habana/tests/HabanaOperatorTest.cpp

Lines changed: 5 additions & 0 deletions
@@ -147,8 +147,12 @@ std::set<std::string> glow::backendTestBlacklist = {
     "FP16SoftMax/0",
     "Fp16Splat/0",
     "FP16Transpose2Dims/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Fused4Bit_Float16_AccumFloat16",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
     "FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16/0",
     "FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_"
     "back/0",
@@ -266,6 +270,7 @@ std::set<std::string> glow::backendTestBlacklist = {
     "sliceReshape_Float16/0",
     "sliceVectors_Float16/0",
     "sliceVectors_Int64/0",
+    "SLWSTwoColumn_Float16_AccumFloat/0",
     "SLSAllZeroLengths_Float/0",
     "SLSAllZeroLengths_Float16/0",
     "SoftMax/0",

lib/Backends/Interpreter/Interpreter.cpp

Lines changed: 22 additions & 19 deletions
@@ -306,32 +306,35 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
                 RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
             ElemKind::Int32ITy);
 
-  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
+  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
     if (NI.getInElemTy(
-            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
-        ElemKind::UInt8FusedFP16QTy) {
+            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
+            ElemKind::Int64ITy ||
+        NI.getInElemTy(
+            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) !=
+            ElemKind::Int32ITy) {
+      return false;
+    }
+
+    switch (NI.getInElemTy(
+        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx)) {
+    case ElemKind::UInt4FusedFP16QTy:
+    case ElemKind::UInt8FusedFP16QTy:
       return (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                  WeightsIdx) == ElemKind::Float16Ty) &&
-             (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
-                                 IndicesIdx) == ElemKind::Int64ITy) &&
-             (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
-                                 LengthsIdx) == ElemKind::Int32ITy) &&
             (NI.getOutElemTy(
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                      ResultIdx) == ElemKind::Float16Ty);
+    case ElemKind::UInt8FusedQTy:
+      return (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
+                                 WeightsIdx) == ElemKind::FloatTy) &&
+             (NI.getOutElemTy(
+                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
+                      ResultIdx) == ElemKind::FloatTy);
+    default:
+      return false;
     }
-    return (NI.getInElemTy(
-                FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
-            ElemKind::UInt8FusedQTy) &&
-           (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
-                               WeightsIdx) == ElemKind::FloatTy) &&
-           (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
-                               IndicesIdx) == ElemKind::Int64ITy) &&
-           (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
-                               LengthsIdx) == ElemKind::Int32ITy) &&
-           (NI.getOutElemTy(
-                FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
-            ElemKind::FloatTy);
+  }
 
   case Kinded::Kind::LengthsRangeFillNodeKind:
   case Kinded::Kind::LengthsToRangesNodeKind:
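The new switch boils down to a per-data-kind mapping of allowed weights/result kinds. The following is a hedged, standalone paraphrase with made-up names (Kind and frwqSlwsComboSupported are not Glow identifiers); the indices/lengths checks are left out since they are the same for every kind:

enum class Kind { Float, Float16, UInt8Fused, UInt8FusedFP16, UInt4FusedFP16 };

// FP16-fused data (8-bit or the new 4-bit kind) requires Float16 weights and
// result; FP32-fused 8-bit data requires Float weights and result.
bool frwqSlwsComboSupported(Kind data, Kind weights, Kind result) {
  switch (data) {
  case Kind::UInt4FusedFP16:
  case Kind::UInt8FusedFP16:
    return weights == Kind::Float16 && result == Kind::Float16;
  case Kind::UInt8Fused:
    return weights == Kind::Float && result == Kind::Float;
  default:
    return false;
  }
}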

lib/Backends/Interpreter/InterpreterNodes.cpp

Lines changed: 14 additions & 5 deletions
@@ -3353,7 +3353,9 @@ void BoundInterpreterFunction::
   assert(totalLength <= indices->dims()[0] &&
          "sum(Lengths) must be equal to len(Indices)");
 
-  const size_t inLineSize = data->size() / data->dims()[0];
+  const bool using4BitQuantization =
+      data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;
+
   const size_t outLineSize = out->size() / out->dims()[0];
 
   auto DH = data->getHandle<uint8_t>();
@@ -3366,13 +3368,20 @@ void BoundInterpreterFunction::
     for (size_t j = 0, e = LH.raw(i); j < e; j++) {
       const float weight = static_cast<float>(WH.raw(curIdx));
       const size_t rowIdx = IH.raw(curIdx++);
-      size_t offsetIn = rowIdx * inLineSize;
       T scale, offset;
       std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx);
       for (size_t k = 0; k < outLineSize; k++) {
-        float d = quantization::dequantizeWithFloatOffset(
-            DH.raw(offsetIn++), static_cast<float>(scale),
-            static_cast<float>(offset));
+        float d = 0.0f;
+        if (!using4BitQuantization) {
+          d = quantization::dequantizeWithFloatOffset(
+              DH.at({rowIdx, k}), static_cast<float>(scale),
+              static_cast<float>(offset));
+        } else {
+          const bool isMSB = (k % 2 == 1);
+          d = quantization::dequantize4BitWithFloatOffset(
+              DH.at({rowIdx, k / 2}), static_cast<float>(scale),
+              static_cast<float>(offset), isMSB);
+        }
         accum[k] += d * weight;
       }
     }
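For intuition about the new 4-bit branch, here is a small self-contained sketch of the packing and dequantization arithmetic it relies on. This is an illustrative re-implementation, not Glow's quantization API; in particular, the assumption that even output columns live in the low nibble and odd columns in the high nibble is inferred from the isMSB = (k % 2 == 1) and k / 2 indexing above, and the affine value * scale + offset mapping is assumed:

#include <cstdint>
#include <cstdio>

// Dequantize one 4-bit element packed into a byte; isMSB selects the high
// nibble (odd output column) versus the low nibble (even output column).
float dequantize4Bit(uint8_t packed, float scale, float offset, bool isMSB) {
  const uint8_t nibble = isMSB ? (packed >> 4) : (packed & 0x0F);
  return static_cast<float>(nibble) * scale + offset;
}

int main() {
  // One packed byte holding quantized values 3 (low nibble) and 12 (high).
  const uint8_t packed = (12u << 4) | 3u;
  const float scale = 0.5f, offset = -1.0f;
  std::printf("col 0 -> %f\n", dequantize4Bit(packed, scale, offset, false)); // 0.5
  std::printf("col 1 -> %f\n", dequantize4Bit(packed, scale, offset, true));  // 5.0
  return 0;
}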

lib/Backends/NNPI/tests/NNPIOperatorTest.cpp

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,8 @@ std::set<std::string> glow::backendTestBlacklist = {
     "FullyConnected_Int16_BiasInt16/0",
     "FullyConnected_Int16_BiasInt32/0",
     "FullyConnected_Int8_BiasInt8/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Fused4Bit_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
     "GroupConv3D/0",
     "GroupwiseQuantizedConvolution/0",
     "insertTensorTest/0",

lib/Backends/OpenCL/tests/OpenCLOperatorTest.cpp

Lines changed: 6 additions & 0 deletions
@@ -181,6 +181,12 @@ std::set<std::string> glow::backendTestBlacklist = {
     "FusedRowwiseQuantizedSparseLengthsSum_Float/0",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
     "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat16/0",
+    "FusedRowwiseQuantizedSLWSTwoColumn_Fused4Bit_Float16_AccumFloat16/0",
+    "SLWSTwoColumn_Float16_AccumFloat/0",
     "SLSWithZeroLengths/0",
     "SparseToDense/0",
     "SparseToDenseMask1/0",

lib/Graph/Graph.cpp

Lines changed: 52 additions & 30 deletions
@@ -1680,38 +1680,48 @@ Function::createRowwiseQuantizedSparseLengthsSum(
 /// Helper used to get specific output type required for
 /// createRowwiseQuantizedSparseLengthsSum and
 /// createRowwiseQuantizedSparseLengthsWeightedSum.
-/// Function \p F is used to get the speficific type, using inputs \p inDims and
-/// \p lenghtsDims to compute output dimensions.
-static TypeRef getOutputTypeOfFusedRowwiseQuantizedSLS(
-    Function *F, const llvm::ArrayRef<size_t> &inDims,
-    const llvm::ArrayRef<size_t> &lengthsDims, ElemKind scaleOffsetKind) {
-  ShapeVector outDims(inDims.begin(), inDims.end());
+/// Function \p F is used to get the specific type, using inputs \p data and
+/// \p lengthsDims to compute output dimensions.
+static TypeRef
+getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
+                                        llvm::ArrayRef<size_t> lengthsDims) {
+  ShapeVector outDims(data.dims().begin(), data.dims().end());
   outDims[0] = lengthsDims[0];
   // The output column count is the same as the input column count, but
   // without the extra bytes for the fused scale/offset, as the output is not
   // fused.
-  outDims[1] -=
-      2 * ((scaleOffsetKind == ElemKind::FloatTy) ? sizeof(float)
-                                                  : sizeof(float16_t));
-  return F->getParent()->uniqueType(scaleOffsetKind, outDims);
+  CHECK(isFusedQuantizedElemKind(data.getElementType()))
+      << "Must use a fused ElemKind for data.";
+  outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy)
+                         ? sizeof(float)
+                         : sizeof(float16_t));
+  // If using 4-bit quantization, then the input data has packed two 4-bit
+  // elements into one byte, so we need to double the outDims.
+  if (data.getElementType() == ElemKind::UInt4FusedFP16QTy) {
+    outDims[1] *= 2;
+  }
+  const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy)
+                               ? ElemKind::FloatTy
+                               : ElemKind::Float16Ty;
+  return F->getParent()->uniqueType(outputK, outDims);
 }
 
 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
     llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
-    NodeValue lengths, ElemKind precision, bool useFP16Accumulation) {
-  auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(
-      this, data.dims(), lengths.dims(), precision);
+    NodeValue lengths, bool useFP16Accumulation) {
+  auto outTy =
+      getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
   return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(
       name, outTy, data, weights, indices, lengths, useFP16Accumulation));
 }
 
 FusedRowwiseQuantizedSparseLengthsSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsSum(
     llvm::StringRef name, Constant *data, NodeValue indices, NodeValue lengths,
-    ElemKind precision, bool useFP16Accumulation) {
-  auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(
-      this, data->dims(), lengths.dims(), precision);
+    bool useFP16Accumulation) {
+  auto outTy =
+      getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
   return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode(
       name, outTy, data, indices, lengths, useFP16Accumulation));
 }
@@ -1734,18 +1744,30 @@ static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
   // dimension to include space for the scale/offset, each 4 bytes
   // (float/int32_t).
   switch (precision) {
-  case ElemKind::FloatTy: {
+  case ElemKind::UInt8FusedQTy: {
     Constant *rwqData = F->getParent()->createConstant(
-        ElemKind::UInt8FusedQTy,
-        {fDims.first, fDims.second + 2 * sizeof(float)}, 0.0, 0, "data");
+        precision, {fDims.first, fDims.second + 2 * sizeof(float)}, 0.0, 0,
+        "data");
     quantization::tensorFusedRowwiseQuantization<float>(
         fData, rwqData->getPayloadMutable());
     return rwqData;
   }
-  case ElemKind::Float16Ty: {
+  case ElemKind::UInt8FusedFP16QTy: {
     Constant *rwqData = F->getParent()->createConstant(
-        ElemKind::UInt8FusedFP16QTy,
-        {fDims.first, fDims.second + 2 * sizeof(float16_t)}, 0.0, 0, "data");
+        precision, {fDims.first, fDims.second + 2 * sizeof(float16_t)}, 0.0, 0,
+        "data");
+    quantization::tensorFusedRowwiseQuantization<float16_t>(
+        fData, rwqData->getPayloadMutable());
+    return rwqData;
+  }
+  case ElemKind::UInt4FusedFP16QTy: {
+    // We pack 4-bit values into bytes, so given the input size in float we
+    // divide by two and take the ceiling to make sure we have enough space for
+    // all elements.
+    const size_t outerDim =
+        std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t);
+    Constant *rwqData = F->getParent()->createConstant(
+        precision, {fDims.first, outerDim}, 0.0, 0, "data");
     quantization::tensorFusedRowwiseQuantization<float16_t>(
         fData, rwqData->getPayloadMutable());
     return rwqData;
@@ -1758,23 +1780,23 @@ static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
     llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
-    NodeValue lengths, ElemKind precision, bool useFP16Accumulation) {
+    NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation) {
   Constant *rwqData =
-      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(this, data,
-                                                                   precision);
+      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
+          this, data, fusedElemKind);
   return createFusedRowwiseQuantizedSparseLengthsWeightedSum(
-      name, rwqData, weights, indices, lengths, precision, useFP16Accumulation);
+      name, rwqData, weights, indices, lengths, useFP16Accumulation);
 }
 
 FusedRowwiseQuantizedSparseLengthsSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsSum(
     llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
-    ElemKind precision, bool useFP16Accumulation) {
+    ElemKind fusedElemKind, bool useFP16Accumulation) {
  Constant *rwqData =
-      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(this, data,
-                                                                   precision);
+      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
+          this, data, fusedElemKind);
   return this->createFusedRowwiseQuantizedSparseLengthsSum(
-      name, rwqData, indices, lengths, precision, useFP16Accumulation);
+      name, rwqData, indices, lengths, useFP16Accumulation);
 }
 
 LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
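A worked example of the sizing arithmetic in the two helpers above, using an illustrative 3 x 10 float input quantized to ElemKind::UInt4FusedFP16QTy (these numbers are not from the diff):

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical float data: 3 rows x 10 columns; float16 is 2 bytes.
  const std::size_t rows = 3, cols = 10, sizeofFloat16 = 2;

  // Fused constant width: two 4-bit values per byte plus a fused float16
  // scale and offset per row: ceil(10 / 2) + 2 * 2 = 9.
  const std::size_t fusedCols =
      static_cast<std::size_t>(std::ceil(cols / 2.0)) + 2 * sizeofFloat16;

  // SLS output width: strip the scale/offset bytes, then double because each
  // remaining byte held two 4-bit elements: (9 - 4) * 2 = 10.
  const std::size_t outCols = (fusedCols - 2 * sizeofFloat16) * 2;

  std::printf("fused dims {%zu, %zu}, output cols %zu\n", rows, fusedCols,
              outCols);
  return 0;
}

The verifier change in lib/Graph/Nodes.cpp below checks the inverse relation: for 4-bit data, result.dims()[1] / 2 + extraCols must equal data.dims()[1], i.e. 10 / 2 + 4 == 9 in this example.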

lib/Graph/Nodes.cpp

Lines changed: 7 additions & 1 deletion
@@ -1264,10 +1264,16 @@ static bool verifyFusedRowwiseQuantizedSparseLengthsSum(
   // Wrap this in isValid to prevent potential segfault if the result is
   // incorrectly shaped.
   if (isValid) {
+    // If using 4-bit quantization for embeddings then the input is packed into
+    // two elements per byte.
+    size_t finalSize = result.dims()[1];
+    if (data.getType()->getElementType() == ElemKind::UInt4FusedFP16QTy) {
+      finalSize /= 2;
+    }
     isValid &=
         expectCompareTrue("Result output shape should have second dim without "
                           "extra columns from scale/offset in Data.",
-                          result.dims()[1] + extraCols, data.dims()[1], parent);
+                          finalSize + extraCols, data.dims()[1], parent);
   }
   return isValid;
 }
