Skip to content

Commit

Permalink
Merge pull request #198 from Xilinx/bump_to_74f9df65
Browse files Browse the repository at this point in the history
[AutoBump] Merge with 74f9df6 (Aug 07) (27)
  • Loading branch information
mgehre-amd authored Oct 1, 2024
2 parents d9371bc + 8788179 commit a4e43f1
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 132 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
ONNXEntryPointOp::getEntryPointFuncAttrName());
StringRef entryPointName = funcRefAttr.getLeafReference().getValue();
Operation *entryPointOp = module.lookupSymbol(entryPointName);
assert(entryPointOp && "entry point name not found!");
func::FuncOp entryPointFunc = cast<func::FuncOp>(entryPointOp);

IntegerAttr numInputsAttr =
Expand Down
9 changes: 5 additions & 4 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ Value emitPostProcessingFor(ConversionPatternRewriter &rewriter, Location loc,

template <typename Op>
static void CheckIfCustomScalarOpIsSupported(Type elementType) {
Type actualElementType = MathBuilder::elementTypeWithVector(elementType);
Type actualElementType =
MathBuilder::elementTypeOfScalarOrVector(elementType);
if (mlir::isa<mlir::IntegerType>(actualElementType)) {
if constexpr (std::is_same<ScalarIOp<Op>, CustomScalarOp>::value)
return;
Expand Down Expand Up @@ -914,7 +915,7 @@ Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
// ConstantOp 0,
// %Y)
Value plusSelect;
if (create.math.isUnsignedIntegerWithVector(elementType)) {
if (create.math.isScalarOrVectorUnsignedInteger(elementType)) {
// Unsigned integers are by definition positive.
plusSelect = one;
} else {
Expand Down Expand Up @@ -1188,7 +1189,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
MultiDialectBuilder<MathBuilder, KrnlBuilder> create(rewriter, loc);

// TODO: here we assume fmod=1, what should if that is not the case?
if (create.math.isFloatWithVector(elementType)) {
if (create.math.isScalarOrVectorFloat(elementType)) {
// fmod is always 1. Behavior is like numpy.fmod.
// The sign of the remainder is the same as the dividend.
Value rem = create.math.rem(dividend, divisor);
Expand All @@ -1201,7 +1202,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
return create.math.copySign(rem, dividend);
#endif
}
if (create.math.isIntegerWithVector(elementType)) {
if (create.math.isScalarOrVectorInteger(elementType)) {
// "math.rem" returns "minus" for minus dividend and "plus or zero" for plus
// dividend. We call the math.rem's return value "mathRemainder". However
// onnx.ModOp should return "minus" for minus divisor and "plus or zero" for
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ mlir::Value emitScalarOpFor(mlir::ConversionPatternRewriter &rewriter,
// int. Thus we look at the type the first input argument, and not the output
// elementType.
mlir::Type actualElementType =
MathBuilder::elementTypeWithVector(scalarOperands[0].getType());
MathBuilder::elementTypeOfScalarOrVector(scalarOperands[0]);
// Perform int or float operation depending on the actual elementary type.
if (mlir::isa<mlir::IntegerType>(actualElementType)) {
// Generate the integer code only if the scalar integer op is non-void
Expand Down
Loading

0 comments on commit a4e43f1

Please sign in to comment.