@@ -1680,38 +1680,48 @@ Function::createRowwiseQuantizedSparseLengthsSum(
 /// Helper used to get specific output type required for
 /// createRowwiseQuantizedSparseLengthsSum and
 /// createRowwiseQuantizedSparseLengthsWeightedSum.
-/// Function \p F is used to get the speficific type, using inputs \p inDims and
-/// \p lenghtsDims to compute output dimensions.
-static TypeRef getOutputTypeOfFusedRowwiseQuantizedSLS(
-    Function *F, const llvm::ArrayRef<size_t> &inDims,
-    const llvm::ArrayRef<size_t> &lengthsDims, ElemKind scaleOffsetKind) {
-  ShapeVector outDims(inDims.begin(), inDims.end());
+/// Function \p F is used to get the specific type, using inputs \p data and
+/// \p lengthsDims to compute output dimensions.
+static TypeRef
+getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
+                                        llvm::ArrayRef<size_t> lengthsDims) {
+  ShapeVector outDims(data.dims().begin(), data.dims().end());
   outDims[0] = lengthsDims[0];
   // The output column count is the same as the input column count, but
   // without the extra bytes for the fused scale/offset, as the output is not
   // fused.
-  outDims[1] -=
-      2 * ((scaleOffsetKind == ElemKind::FloatTy) ? sizeof(float)
-                                                  : sizeof(float16_t));
-  return F->getParent()->uniqueType(scaleOffsetKind, outDims);
+  CHECK(isFusedQuantizedElemKind(data.getElementType()))
+      << "Must use a fused ElemKind for data.";
+  outDims[1] -= 2 * ((data.getElementType() == ElemKind::UInt8FusedQTy)
+                         ? sizeof(float)
+                         : sizeof(float16_t));
+  // If using 4-bit quantization, then the input data has packed two 4-bit
+  // elements into one byte, so we need to double the outDims.
+  if (data.getElementType() == ElemKind::UInt4FusedFP16QTy) {
+    outDims[1] *= 2;
+  }
+  const ElemKind outputK = (data.getElementType() == ElemKind::UInt8FusedQTy)
+                               ? ElemKind::FloatTy
+                               : ElemKind::Float16Ty;
+  return F->getParent()->uniqueType(outputK, outDims);
 }

 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
     llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
-    NodeValue lengths, ElemKind precision, bool useFP16Accumulation) {
-  auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(
-      this, data.dims(), lengths.dims(), precision);
+    NodeValue lengths, bool useFP16Accumulation) {
+  auto outTy =
+      getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
   return addNode(new FusedRowwiseQuantizedSparseLengthsWeightedSumNode(
       name, outTy, data, weights, indices, lengths, useFP16Accumulation));
 }

 FusedRowwiseQuantizedSparseLengthsSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsSum(
     llvm::StringRef name, Constant *data, NodeValue indices, NodeValue lengths,
-    ElemKind precision, bool useFP16Accumulation) {
-  auto outTy = getOutputTypeOfFusedRowwiseQuantizedSLS(
-      this, data->dims(), lengths.dims(), precision);
+    bool useFP16Accumulation) {
+  auto outTy =
+      getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, lengths.dims());
   return addNode(new FusedRowwiseQuantizedSparseLengthsSumNode(
       name, outTy, data, indices, lengths, useFP16Accumulation));
 }
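The helper's column arithmetic is worth making concrete: a fused row stores its quantized bytes followed by a per-row scale/offset pair, so the unfused output width is the row's byte width minus those trailing bytes, then doubled for the 4-bit kind because two elements share each byte. Below is a minimal standalone sketch of just that math; it is not Glow code (the `Kind` enum and `outputColumns` helper are hypothetical stand-ins, and `uint16_t` fakes `float16_t` since only its size matters here):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>

using float16_t = uint16_t; // stand-in: only sizeof(float16_t) == 2 matters

enum class Kind { UInt8FusedQTy, UInt8FusedFP16QTy, UInt4FusedFP16QTy };

// Mirrors the column math in getOutputTypeOfFusedRowwiseQuantizedSLS: strip
// the trailing fused scale/offset bytes, then undo the two-per-byte packing
// for the 4-bit kind.
size_t outputColumns(size_t fusedRowBytes, Kind k) {
  size_t cols =
      fusedRowBytes -
      2 * ((k == Kind::UInt8FusedQTy) ? sizeof(float) : sizeof(float16_t));
  if (k == Kind::UInt4FusedFP16QTy) {
    cols *= 2; // two 4-bit elements were packed into each byte
  }
  return cols;
}

int main() {
  // A UInt8FusedQTy row of 72 bytes holds 64 uint8 elements plus a float
  // scale and float offset (8 bytes).
  assert(outputColumns(72, Kind::UInt8FusedQTy) == 64);
  // A UInt4FusedFP16QTy row of 36 bytes holds 32 packed bytes (64 4-bit
  // elements) plus a float16 scale/offset pair (4 bytes).
  assert(outputColumns(36, Kind::UInt4FusedFP16QTy) == 64);
  printf("ok\n");
  return 0;
}
```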
@@ -1734,18 +1744,30 @@ static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
   // dimension to include space for the scale/offset, each 4 bytes
   // (float/int32_t).
   switch (precision) {
-  case ElemKind::FloatTy: {
+  case ElemKind::UInt8FusedQTy: {
     Constant *rwqData = F->getParent()->createConstant(
-        ElemKind::UInt8FusedQTy,
-        {fDims.first, fDims.second + 2 * sizeof(float)}, 0.0, 0, "data");
+        precision, {fDims.first, fDims.second + 2 * sizeof(float)}, 0.0, 0,
+        "data");
     quantization::tensorFusedRowwiseQuantization<float>(
         fData, rwqData->getPayloadMutable());
     return rwqData;
   }
-  case ElemKind::Float16Ty: {
+  case ElemKind::UInt8FusedFP16QTy: {
     Constant *rwqData = F->getParent()->createConstant(
-        ElemKind::UInt8FusedFP16QTy,
-        {fDims.first, fDims.second + 2 * sizeof(float16_t)}, 0.0, 0, "data");
+        precision, {fDims.first, fDims.second + 2 * sizeof(float16_t)}, 0.0, 0,
+        "data");
+    quantization::tensorFusedRowwiseQuantization<float16_t>(
+        fData, rwqData->getPayloadMutable());
+    return rwqData;
+  }
+  case ElemKind::UInt4FusedFP16QTy: {
+    // We pack 4-bit values into bytes, so given the input size in float we
+    // divide by two and take the ceiling to make sure we have enough space for
+    // all elements.
+    const size_t outerDim =
+        std::ceil(((float)fDims.second) / 2) + 2 * sizeof(float16_t);
+    Constant *rwqData = F->getParent()->createConstant(
+        precision, {fDims.first, outerDim}, 0.0, 0, "data");
     quantization::tensorFusedRowwiseQuantization<float16_t>(
         fData, rwqData->getPayloadMutable());
     return rwqData;
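The sizing in the new `UInt4FusedFP16QTy` case follows from packing two 4-bit values per byte: a row of N floats needs ceil(N / 2) data bytes plus 2 bytes each for the float16 scale and offset. A minimal sketch of that computation under the same assumptions (the `fused4BitRowBytes` helper is hypothetical, with `uint16_t` again standing in for `float16_t`):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Bytes needed to store one fused 4-bit quantized row of numElems values:
// ceil(numElems / 2) packed data bytes + float16 scale + float16 offset.
size_t fused4BitRowBytes(size_t numElems) {
  return static_cast<size_t>(std::ceil(static_cast<float>(numElems) / 2)) +
         2 * sizeof(uint16_t); // sizeof(float16_t) == 2
}

int main() {
  printf("%zu\n", fused4BitRowBytes(7));  // ceil(3.5) = 4, + 4 -> 8 bytes
  printf("%zu\n", fused4BitRowBytes(64)); // 32 + 4 -> 36 bytes
  return 0;
}
```

Note that an odd N wastes the low nibble of the last byte; the output-type helper above simply doubles the stripped byte width, so it effectively assumes an even number of elements per row.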
@@ -1758,23 +1780,23 @@ static Constant *quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
 FusedRowwiseQuantizedSparseLengthsWeightedSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsWeightedSum(
     llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
-    NodeValue lengths, ElemKind precision, bool useFP16Accumulation) {
+    NodeValue lengths, ElemKind fusedElemKind, bool useFP16Accumulation) {
   Constant *rwqData =
-      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(this, data,
-                                                                   precision);
+      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
+          this, data, fusedElemKind);
   return createFusedRowwiseQuantizedSparseLengthsWeightedSum(
-      name, rwqData, weights, indices, lengths, precision, useFP16Accumulation);
+      name, rwqData, weights, indices, lengths, useFP16Accumulation);
 }

 FusedRowwiseQuantizedSparseLengthsSumNode *
 Function::createFusedRowwiseQuantizedSparseLengthsSum(
     llvm::StringRef name, Tensor &data, NodeValue indices, NodeValue lengths,
-    ElemKind precision, bool useFP16Accumulation) {
+    ElemKind fusedElemKind, bool useFP16Accumulation) {
   Constant *rwqData =
-      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(this, data,
-                                                                   precision);
+      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
+          this, data, fusedElemKind);
   return this->createFusedRowwiseQuantizedSparseLengthsSum(
-      name, rwqData, indices, lengths, precision, useFP16Accumulation);
+      name, rwqData, indices, lengths, useFP16Accumulation);
 }

 LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
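Net effect of the signature changes: the fused `ElemKind` is now passed only to the `Tensor`-taking creators, where quantization actually happens (hence the `fusedElemKind` rename), while the `NodeValue`/`Constant` overloads derive the output type from the data's element kind alone. A rough usage sketch, assuming a `Module mod` and `Function *F` already exist and using Glow's usual `Int64ITy` indices and `Int32ITy` lengths:

```cpp
// Float source table to be rowwise-quantized into a fused 4-bit constant.
Tensor data(ElemKind::FloatTy, {100, 64});
// ... fill `data` with the embedding table ...

auto *indices =
    mod.createPlaceholder(ElemKind::Int64ITy, {10}, "indices", false);
auto *lengths =
    mod.createPlaceholder(ElemKind::Int32ITy, {5}, "lengths", false);

// The fused kind is chosen once, at quantization time; the node's output
// type (FloatTy for UInt8FusedQTy, Float16Ty otherwise) follows from it.
auto *SLS = F->createFusedRowwiseQuantizedSparseLengthsSum(
    "SLS", data, indices, lengths, ElemKind::UInt4FusedFP16QTy,
    /* useFP16Accumulation */ false);
```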