Merge pull request #283 from Xilinx/jrickert.unranked_types
Make ONNX to TOSA lowering more resistant to dynamic shapes and unranked types.
jorickert authored Feb 3, 2025
2 parents 3939ed0 + 063fbb0 commit a96118c
Showing 14 changed files with 256 additions and 63 deletions.
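
Two defensive patterns recur throughout the diff below. Builders that used to fabricate a ranked all-dynamic result type now keep the operand's type when it has no rank, and conversion patterns that genuinely need a rank (for example, to splat a broadcastable constant) reject the match instead. A minimal sketch of both, condensed from the hunks that follow (simplified names, not a verbatim excerpt):

    // Pattern 1 (TosaBuilder helpers): never invent a rank.
    mlir::Type resultTypeFor(mlir::ShapedType lhsType) {
      if (!lhsType.hasRank())
        return lhsType; // propagate the unranked type unchanged
      return mlir::RankedTensorType::get(
          llvm::SmallVector<int64_t, 4>(
              lhsType.getRank(), mlir::ShapedType::kDynamic),
          lhsType.getElementType());
    }

    // Pattern 2 (conversion patterns): bail out when a rank is required,
    // leaving the ONNX op in place for other lowerings:
    //   if (!outputType.hasRank())
    //     return rewriter.notifyMatchFailure(op, "rank required");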
58 changes: 35 additions & 23 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -57,6 +57,12 @@ Value TosaBuilder::createConst(
 }
 
 bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) {
+  if (llvm::any_of(valueRange, [](const auto value) {
+        return !mlir::cast<ShapedType>(value.getType()).hasRank();
+      })) {
+    return false; // we have no way to determine the broadcast, so do not
+                  // attempt it
+  }
   int64_t firstRank = mlir::cast<ShapedType>(valueRange[0].getType()).getRank();
   for (Value operand : valueRange) {
     auto operandType = mlir::cast<ShapedType>(operand.getType());
@@ -129,9 +135,8 @@ Value TosaBuilder::getConst(ArrayRef<float> vec, ArrayRef<int64_t> shape) {
   return constOp;
 }
 
-Value TosaBuilder::getSplattedConst(
-    float val, Type dtype, llvm::ArrayRef<int64_t> shape) {
-  auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type());
+Value TosaBuilder::getSplattedConst(float val, Type dtype, int64_t rank) {
+  auto constType = tosa::reduceAxisToOne(rank, rewriter().getF32Type());
   auto constAttr = DenseElementsAttr::get(constType, val);
 
   auto constOp =
@@ -150,8 +155,7 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {
   auto valueType = mlir::cast<ShapedType>(value.getType());
   // get new value type
   Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(
-          valueType.getShape().size(), ShapedType::kDynamic),
+      llvm::SmallVector<int64_t, 4>(perm.size(), ShapedType::kDynamic),
       valueType.getElementType());
   // create transpose for value
   Value newValue = tosa::CreateOpAndInfer<mlir::tosa::TransposeOp>(
@@ -195,9 +199,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
       rewriter(), loc(), newValueType, lhs, rhs, shift);
 }
@@ -215,9 +222,12 @@ Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
   }
 
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsElementType);
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsElementType);
   return tosa::CreateOpAndInfer<mlir::tosa::IntDivOp>(
       rewriter(), loc(), newValueType, lhs, rhs);
 }
@@ -230,9 +240,12 @@ Value TosaBuilder::binaryOp(Value &lhs, Value &rhs) {
     rhs = valueVec[1];
   }
   auto lhsType = mlir::cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, lhs, rhs);
 }
 
@@ -246,11 +259,7 @@ template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(
 
 template <typename T>
 Value TosaBuilder::unaryOp(mlir::Value &input) {
-  auto inputType = cast<ShapedType>(input.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(inputType.getRank(), ShapedType::kDynamic),
-      inputType.getElementType());
-  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), newValueType, input);
+  return tosa::CreateOpAndInfer<T>(rewriter(), loc(), input.getType(), input);
 }
 
 template Value TosaBuilder::unaryOp<mlir::tosa::ExpOp>(mlir::Value &input);
@@ -305,9 +314,12 @@ Value TosaBuilder::select(
     rhs = valueVec[2];
   }
   auto lhsType = cast<ShapedType>(lhs.getType());
-  Type newValueType = RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
-      lhsType.getElementType());
+  Type newValueType =
+      (!lhsType.hasRank())
+          ? lhsType
+          : RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
+                                      lhsType.getRank(), ShapedType::kDynamic),
+                lhsType.getElementType());
   return tosa::CreateOpAndInfer<mlir::tosa::SelectOp>(
       rewriter(), loc(), newValueType, cond, lhs, rhs);
 }
@@ -328,7 +340,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType(
 Value TosaBuilder::sqrt(mlir::Value &input) {
   auto inputType = cast<ShapedType>(input.getType());
   auto oneHalf = this->getSplattedConst(
-      0.5, inputType.getElementType(), inputType.getShape());
+      0.5, inputType.getElementType(), inputType.getRank());
   return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
 }
 
3 changes: 1 addition & 2 deletions src/Conversion/ONNXToTOSA/DialectBuilder.hpp
@@ -91,8 +91,7 @@ struct TosaBuilder : DialectBuilder {
   // The tensor will have the same rank as shape but all dimensions will
   // have size 1 (differs from tensorflow impl.)
   // If dtype is provided, it also cast the value to the appropriate dtype.
-  mlir::Value getSplattedConst(
-      float val, mlir::Type dtype, llvm::ArrayRef<int64_t> shape = {});
+  mlir::Value getSplattedConst(float val, mlir::Type dtype, int64_t rank);
 
   // Creates a constant of shape <1x1x...x1> of rank `rank` with all values set
   // to `value`.
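
Since every dimension of a splatted constant is 1, the rank alone now determines its type; TOSA's implicit broadcasting stretches the <1x1x...x1> tensor against the real operands. A hypothetical call site under the new signature:

    // For a rank-3 fp32 output this builds a tensor<1x1x1xf32> constant,
    // which later tosa ops broadcast against tensor<?x?x?xf32> operands.
    Value oneHalf = tosaBuilder.getSplattedConst(
        0.5f, outputType.getElementType(), outputType.getRank());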
56 changes: 39 additions & 17 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -239,12 +239,12 @@ static LogicalResult legalizeFloatingPointPrelu(Operation *op,
   auto loc = op->getLoc();
   TosaBuilder tosaBuilder(rewriter, loc);
   Value constZero = tosaBuilder.getSplattedConst(
-      0.0, outputType.getElementType(), outputType.getShape());
+      0.0, outputType.getElementType(), outputType.getRank());
 
   auto mul = tosaBuilder.mul(input, alphaOrSlope);
   auto greaterEqual = tosaBuilder.greaterEqual(input, constZero);
   auto select = tosaBuilder.select(greaterEqual, input, mul);
-
+  copySingleResultType(op, select);
   rewriter.replaceOp(op, {select});
   return success();
 }
@@ -274,7 +274,7 @@ class ONNXLeakyReluOpLoweringToTOSA
     TosaBuilder tosaBuilder(rewriter, loc);
     return legalizeFloatingPointPrelu(op, rewriter, adaptor.getX(),
         tosaBuilder.getSplattedConst(
-            alpha, outputType.getElementType(), outputType.getShape()),
+            alpha, outputType.getElementType(), outputType.getRank()),
         outputType);
   }
 };
@@ -312,6 +312,7 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern<OnnxCompOp> {
     } else if constexpr (std::is_same_v<OnnxCompOp, ONNXLessOp>) {
       res = tosaBuilder.less(input1, input2);
     }
+    copySingleResultType(op, res);
     rewriter.replaceOp(op, {res});
     return success();
   }
@@ -393,7 +394,7 @@ class ONNXCastOpLoweringToTOSA : public OpConversionPattern<ONNXCastOp> {
     // onnx.Cast and tosa.cast.
     if (resultTy.getElementType().getIntOrFloatBitWidth() != 1) {
       auto zero = tosaBuilder.getSplattedConst(
-          0.0f, inputTy.getElementType(), resultTy.getShape());
+          0.0f, inputTy.getElementType(), resultTy.getRank());
       auto positive = tosaBuilder.greaterEqual(input, zero);
 
       auto floor = tosaBuilder.unaryOp<mlir::tosa::FloorOp>(input);
@@ -421,13 +422,15 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern<ONNXDivOp> {
 
     if (isa<IntegerType>(resultElementType)) {
       Value divOp = tosaBuilder.intdiv(lhs, rhs);
+      copySingleResultType(op, divOp);
       rewriter.replaceOp(op, {divOp});
       return success();
     }
     // For floating point types, decompose ONNXDivOp into
     // tosa::ReciprocalOp and tosa::MulOp.
     Value reciprocalOp = tosaBuilder.unaryOp<mlir::tosa::ReciprocalOp>(rhs);
     Value mulOp = tosaBuilder.mul(lhs, reciprocalOp);
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
@@ -472,20 +475,21 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
 
     Value one = tosaBuilder.getSplattedConst(
-        1.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        1.0, resultTensorType.getElementType(), resultTensorType.getRank());
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            resultTensorType.getElementType(), resultTensorType.getShape());
+            resultTensorType.getElementType(), resultTensorType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, resultTensorType.getElementType(), resultTensorType.getShape());
+        0.0, resultTensorType.getElementType(), resultTensorType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, exp);
     Value expMinusOne = tosaBuilder.binaryOp<mlir::tosa::SubOp>(exp, one);
     Value alphaTimesExpMinusOne = tosaBuilder.mul(expMinusOne, alpha);
     Value greaterEqual = tosaBuilder.greaterEqual(input, constZero);
     auto select =
         tosaBuilder.select(greaterEqual, input, alphaTimesExpMinusOne);
-
+    copySingleResultType(op, select);
     rewriter.replaceOp(op, {select});
     return success();
   }
@@ -516,11 +520,16 @@ class ONNXHardSigmoidOpLoweringToTOSA
     APFloat oneOverAlpha(alpha.getSemantics(), 1);
     oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven);
 
+    if (!resultType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "HardSigmoid: Static shape required to create splatted const");
+    }
+
     Value constBetaOverAlpha =
         tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(),
-            resultElementType, resultType.getShape());
+            resultElementType, resultType.getRank());
     Value constAlpha = tosaBuilder.getSplattedConst(
-        alpha.convertToDouble(), resultElementType, resultType.getShape());
+        alpha.convertToDouble(), resultElementType, resultType.getRank());
 
     auto addOp =
         tosaBuilder.binaryOp<mlir::tosa::AddOp>(input, constBetaOverAlpha);
@@ -530,7 +539,7 @@ class ONNXHardSigmoidOpLoweringToTOSA
             rewriter.getF32FloatAttr(0),
             rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
     auto mulOp = tosaBuilder.mul(clampOp, constAlpha);
-
+    copySingleResultType(op, mulOp);
     rewriter.replaceOp(op, {mulOp});
     return success();
   }
@@ -565,14 +574,19 @@ class ONNXSoftplusOpLoweringToTOSA
     if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSoftplusOp: Rank required to create splatted const");
+    }
 
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     auto one = tosaBuilder.getSplattedConst(
-        1.0, outputType.getElementType(), outputType.getShape());
+        1.0, outputType.getElementType(), outputType.getRank());
 
     auto expOp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
+    copySingleResultType(op, expOp);
     auto expPlusOne = tosaBuilder.binaryOp<mlir::tosa::AddOp>(expOp, one);
     auto logOp = tosaBuilder.unaryOp<mlir::tosa::LogOp>(expPlusOne);
     rewriter.replaceOp(op, {logOp});
@@ -594,15 +608,19 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern<ONNXSeluOp> {
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXSeluOp: Rank required to create splatted const");
+    }
 
     Value alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value gamma =
         tosaBuilder.getSplattedConst(adaptor.getGamma().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     Value constZero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     Value exp = tosaBuilder.unaryOp<mlir::tosa::ExpOp>(input);
     Value expTimesAlpha = tosaBuilder.mul(exp, alpha);
@@ -630,15 +648,19 @@ class ONNXThresholdedReluOpLoweringToTOSA
             rewriter, outputType.getElementType(), op))) {
       return failure();
     }
+    if (!outputType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "ONNXThresholdedReluOp: Rank required to create splatted const");
+    }
 
     Value input = adaptor.getX();
 
     TosaBuilder tosaBuilder(rewriter, op->getLoc());
     auto alpha =
         tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(),
-            outputType.getElementType(), outputType.getShape());
+            outputType.getElementType(), outputType.getRank());
     auto zero = tosaBuilder.getSplattedConst(
-        0.0, outputType.getElementType(), outputType.getShape());
+        0.0, outputType.getElementType(), outputType.getRank());
 
     auto greater = tosaBuilder.greater(input, alpha);
     auto select = tosaBuilder.select(greater, input, zero);
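
The `copySingleResultType` helper called throughout this file is defined in one of the changed files not shown on this page. Judging purely from its call sites, where it is always invoked on a freshly built value right before `replaceOp`, a plausible sketch, under the assumption that it re-stamps the original op's known result type onto a value whose inferred type may be less precise:

    // Assumed behavior only; the real definition is not visible above.
    static void copySingleResultType(mlir::Operation *op, mlir::Value val) {
      if (op->getNumResults() == 1)
        val.setType(op->getResult(0).getType());
    }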
8 changes: 6 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -49,6 +49,10 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     FloatAttr beta = adaptor.getBetaAttr();
     auto AType = mlir::cast<TensorType>(A.getType());
     auto BType = mlir::cast<TensorType>(B.getType());
+    if (!AType.hasRank() || !BType.hasRank()) {
+      return rewriter.notifyMatchFailure(
+          op, "Lowering Gemm to MatMul requires ranked A and B.");
+    }
     auto shapeA = AType.getShape();
     auto shapeB = BType.getShape();
     auto resultType = mlir::cast<TensorType>(
@@ -103,7 +107,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (alpha && alpha.getValueAsDouble() != 1.) {
       Value splattedConstAlpha = tosaBuilder.getSplattedConst(
           static_cast<float>(alpha.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       alphaMulResult = tosaBuilder.mul(splattedConstAlpha, A, 0);
     }
 
@@ -112,7 +116,7 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     if (beta && isCPresent && beta.getValueAsDouble() != 1.) {
       Value splattedConstBeta = tosaBuilder.getSplattedConst(
           static_cast<float>(beta.getValueAsDouble()), AType.getElementType(),
-          newShapeA);
+          newShapeA.size());
       betaMulResult = tosaBuilder.mul(splattedConstBeta, C, 0);
     }
 
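
The new rank guard at the top of the Gemm pattern exists because the rest of the lowering (only partially shown here) reads `shapeA`/`shapeB` to build the rank-3 operands `tosa.matmul` expects, which is impossible for unranked inputs. Illustrative only, assuming the usual 2-D Gemm operands:

    // ONNX Gemm takes 2-D A and B; tosa.matmul wants rank 3, so the
    // pattern must be able to read the static rank/shape of A.
    llvm::SmallVector<int64_t> newShapeA = {1, shapeA[0], shapeA[1]};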
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/Math/Reduce.cpp
@@ -179,7 +179,7 @@ LogicalResult reduceMeanLowering(ONNXReduceMeanOp op,
 
   TosaBuilder tosaBuilder(rewriter, op->getLoc());
   Value divConst = tosaBuilder.getSplattedConst(
-      divScale, outputType.getElementType(), outputType.getShape());
+      divScale, outputType.getElementType(), outputType.getRank());
   auto output = tosaBuilder.mul(val, divConst);
 
   if (!output) {
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
@@ -56,7 +56,7 @@ LogicalResult handleIncludePadAttr(
   Value padding = tosa::buildOnnxToTosaPaddingConstOp(
       rewriter, pads, loc, {0, 0, 0, 0}, {});
   auto constTosaTensor =
-      tosaBuilder.getSplattedConst(0.0, inputType.getElementType());
+      tosaBuilder.getSplattedConst(0.0, inputType.getElementType(), 0);
 
   auto padOp = tosa::CreateOpAndInfer<mlir::tosa::PadOp>(rewriter, loc,
       mlir::RankedTensorType::get(
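
Note the new third argument: rank 0. Through `reduceAxisToOne` this yields an empty shape, i.e. a rank-0 scalar tensor of the input's element type, presumably matching the scalar pad value that `tosa.pad` consumes. Sketch of the resulting type (illustrative):

    // getSplattedConst(0.0, f32, /*rank=*/0)
    //   -> reduceAxisToOne(0, f32) -> tensor<f32> (rank 0)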
7 changes: 6 additions & 1 deletion src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp
@@ -29,6 +29,11 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
       OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
 
     auto outType = getTypeConverter()->convertType(op.getResult().getType());
+    if (!cast<ShapedType>(outType).hasRank()) {
+      return rewriter.notifyMatchFailure(op,
+          "ONNXBatchNormalizationInferenceModeOp to "
+          "TOSA requires a ranked result type");
+    }
     auto outTensorType = cast<RankedTensorType>(outType);
 
     // The layout of the output is N x C x D1 x D2 … Dn. For batch
@@ -60,7 +65,7 @@ class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA
     // epsilon's shape: constant -> {1, 1, 1, ...}
     newShape[1] = 1;
     auto eps = tosaBuilder.getSplattedConst(op.getEpsilon().convertToFloat(),
-        outTensorType.getElementType(), newShape);
+        outTensorType.getElementType(), newShape.size());
 
     // output = (input - mean) * scale * rsqrt(var + eps) + bias
     auto op1SubInputMean =
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
@@ -418,7 +418,7 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,
 
   if (!input_is_qtype) {
     Value div_const = tosaBuilder.getSplattedConst(
-        div_scale, output_type.getElementType(), output_type.getShape());
+        div_scale, output_type.getElementType(), output_type.getRank());
     return tosaBuilder.mul(val.value(), div_const);
   }
 
4 changes: 2 additions & 2 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -56,9 +56,9 @@ llvm::SmallVector<int64_t> createInt64VectorFromIndexExpr(
 }
 
 mlir::RankedTensorType reduceAxisToOne(
-    llvm::ArrayRef<int64_t> shape, Type elementType, Attribute encoding) {
+    int64_t rank, Type elementType, Attribute encoding) {
   return mlir::RankedTensorType::get(
-      llvm::SmallVector<int64_t, 4>(shape.size(), 1), elementType, encoding);
+      llvm::SmallVector<int64_t, 4>(rank, 1), elementType, encoding);
 }
 
 mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val) {