
[AutoBump] Merge with 9dcf0a95 (2) #155

Merged 4 commits on Aug 15, 2024
1 change: 1 addition & 0 deletions src/Conversion/KrnlToAffine/CMakeLists.txt
@@ -4,6 +4,7 @@ add_onnx_mlir_library(OMKrnlToAffine
ConvertKrnlToAffine.cpp
KrnlCopyFromBuffer.cpp
KrnlCopyToBuffer.cpp
KrnlGetLinearOffsetIndex.cpp
KrnlLoad.cpp
KrnlMatmul.cpp
KrnlMemset.cpp
2 changes: 2 additions & 0 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -834,6 +834,8 @@ void populateKrnlToAffineConversion(TypeConverter &typeConverter,
krnl::populateLoweringKrnlCopyToBufferOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlLoadOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlStoreOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlGetLinearOffsetIndexOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx);
4 changes: 4 additions & 0 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
@@ -71,6 +71,10 @@ void populateLoweringKrnlLoadOpPattern(mlir::TypeConverter &typeConverter,
void populateLoweringKrnlStoreOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

void populateLoweringKrnlGetLinearOffsetIndexOpPattern(
mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
mlir::MLIRContext *ctx);

void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

87 changes: 87 additions & 0 deletions src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp
@@ -0,0 +1,87 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------- KrnlGetLinearOffsetIndex.cpp - -----------------------===//
//
// Copyright 2024- The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the KrnlGetLinearOffsetIndexOp operator.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/IR/BuiltinTypes.h"

#include "src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "krnl_to_affine"

using namespace mlir;

namespace onnx_mlir {
namespace krnl {

class KrnlGetLinearOffsetIndexLowering : public ConversionPattern {
public:
explicit KrnlGetLinearOffsetIndexLowering(
TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter,
KrnlGetLinearOffsetIndexOp::getOperationName(), 1, context) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MultiDialectBuilder<IndexExprBuilderForKrnl> create(rewriter, loc);
IndexExprScope scope(create.krnlIE);

auto krnlOp = llvm::cast<KrnlGetLinearOffsetIndexOp>(op);
KrnlGetLinearOffsetIndexOpAdaptor operandAdaptor(krnlOp);
// Get the input memref.
Value memref = operandAdaptor.getMemref();
// Get indices.
SmallVector<Value, 4> mapOperands(krnlOp.getMapOperands());
auto mapResults = mlir::affine::expandAffineMap(
rewriter, loc, krnlOp.getMap(), mapOperands);
if (!mapResults)
return failure();
SmallVector<Value, 8> indices = mapResults.value();

auto memrefTy = llvm::dyn_cast<MemRefType>(memref.getType());
int64_t rank = memrefTy.getRank();
assert(mapResults.value().size() == rank && "Invalid indices");

// Only lower this op after the memref is normalized.
if (!memrefTy.getLayout().isIdentity())
return failure();

// Get dimension sizes.
SmallVector<IndexExpr, 4> dims;
create.krnlIE.getShapeAsDims(memref, dims);
// Compute the linear offset using strides.
IndexExpr offsetIE = LiteralIndexExpr(0);
IndexExpr strideIE = LiteralIndexExpr(1);
for (int64_t i = rank - 1; i >= 0; --i) {
IndexExpr strideOffset = strideIE * DimIndexExpr(indices[i]);
offsetIE = offsetIE + strideOffset;
if (i > 0)
strideIE = strideIE * dims[i];
}

rewriter.replaceOp(op, offsetIE.getValue());
return success();
}
};

void populateLoweringKrnlGetLinearOffsetIndexOpPattern(
TypeConverter &typeConverter, RewritePatternSet &patterns,
MLIRContext *ctx) {
patterns.insert<KrnlGetLinearOffsetIndexLowering>(typeConverter, ctx);
}

} // namespace krnl
} // namespace onnx_mlir
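
Note for reviewers: the loop in `matchAndRewrite` accumulates row-major strides from the innermost dimension outward. Below is a minimal standalone C++ sketch of the same computation (illustration only, not part of the patch; the shape and expected values are taken from the lit test added at the end of this PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major linear offset: walk dimensions right-to-left, multiplying up the
// stride, mirroring what KrnlGetLinearOffsetIndexLowering does with IndexExprs.
int64_t linearOffset(const std::vector<int64_t> &indices,
                     const std::vector<int64_t> &dims) {
  int64_t offset = 0;
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    offset += stride * indices[i];
    if (i > 0)
      stride *= dims[i];
  }
  return offset;
}

int main() {
  // memref<?x2x8x32x64xf32> indexed at (d0, 0, 0, 10, 5); the leading
  // (dynamic) extent never feeds a stride, so its value is irrelevant here.
  std::vector<int64_t> dims = {1, 2, 8, 32, 64};
  assert(linearOffset({0, 0, 0, 10, 5}, dims) == 645);
  assert(linearOffset({3, 0, 0, 10, 5}, dims) == 3 * 32768 + 645);
  return 0;
}
```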
12 changes: 12 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.cpp
@@ -92,6 +92,18 @@ void KrnlBuilder::storeIE(
b().create<KrnlStoreOp>(loc(), val, memref, indexValues);
}

Value KrnlBuilder::getLinearOffsetIndex(
Value memref, ValueRange indices) const {
return b().create<KrnlGetLinearOffsetIndexOp>(loc(), memref, indices);
}

Value KrnlBuilder::getLinearOffsetIndexIE(
Value memref, ArrayRef<IndexExpr> indices) const {
SmallVector<Value, 4> indexValues;
IndexExpr::getValues(indices, indexValues);
return b().create<KrnlGetLinearOffsetIndexOp>(loc(), memref, indexValues);
}

void KrnlBuilder::seqstore(
mlir::Value element, mlir::Value seq, mlir::Value index) const {
b().create<KrnlSeqStoreOp>(loc(), element, seq, index);
5 changes: 5 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.hpp
@@ -43,6 +43,11 @@ struct KrnlBuilder : public DialectBuilder {
void storeIE(mlir::Value val, mlir::Value memref,
mlir::ArrayRef<IndexExpr> indices) const;

mlir::Value getLinearOffsetIndex(
mlir::Value memref, mlir::ValueRange indices = {}) const;
mlir::Value getLinearOffsetIndexIE(
mlir::Value memref, mlir::ArrayRef<IndexExpr> indices) const;

void seqstore(mlir::Value element, mlir::Value seq, mlir::Value index) const;
void seqstore(mlir::Value element, mlir::Value seq, IndexExpr index) const;

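
For context, a hypothetical call site for the two new KrnlBuilder helpers could look like the sketch below. Only `getLinearOffsetIndex`/`getLinearOffsetIndexIE` come from this patch; the surrounding helper name and setup are made up for illustration and assume an OpBuilder and Location are already available in the lowering:

```cpp
#include "mlir/IR/Builders.h"
#include "src/Dialect/Krnl/DialectBuilder.hpp"

using namespace mlir;
using namespace onnx_mlir;

// Hypothetical helper: returns the flat (linear) offset of element (i, j)
// of a 2-D memref value, using the builder entry point declared above.
static Value emitFlatOffset(
    OpBuilder &builder, Location loc, Value memref2D, Value i, Value j) {
  MultiDialectBuilder<KrnlBuilder> create(builder, loc);
  return create.krnl.getLinearOffsetIndex(memref2D, ValueRange{i, j});
}
```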
49 changes: 49 additions & 0 deletions src/Dialect/Krnl/Krnl.td
@@ -25,6 +25,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "src/Interface/SpecializedKernelOpInterface.td"

def Krnl_Dialect : Dialect {
@@ -667,6 +668,54 @@ def KrnlStoreOp : Op<Krnl_Dialect, "store",
}];
}

def KrnlGetLinearOffsetIndexOp : Op<Krnl_Dialect, "get_linear_offset_index",
[DeclareOpInterfaceMethods<AffineReadOpInterface>,
DeclareOpInterfaceMethods<AffineMapAccessInterface>, MemRefsNormalizable]> {
let summary = "A Krnl operation to compute a linear offset index from a N-D index.";

let description = [{
Given a MemRef and an N-D index (id_1, id_2, ..., id_n), where n is
the rank of the MemRef, this operation computes a linear offset index.
}];

let arguments = (ins Arg<AnyMemRef, "the reference memref", [MemRead]>:$memref,
Variadic<Index>:$indices,
AffineMapAttr:$map);
let results = (outs Index:$result);

// let assemblyFormat = [{$memref `[` $indices `]` attr-dict `:` type($memref)}];
let builders = [
/// Builds an op with the specified map and operands.
OpBuilder<(ins "AffineMap":$map, "ValueRange":$operands)>,
/// Builds an op with an identity map and operands.
OpBuilder<(ins "Value":$memref, CArg<"ValueRange", "{}">:$indices)>,
/// Builds an op with the specified map and its operands.
OpBuilder<(ins "Value":$memref, "AffineMap":$map,
"ValueRange":$mapOperands)>
];
let extraClassDeclaration = [{
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }

void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }

MemRefType getMemRefType() {
return getMemref().getType().cast<MemRefType>();
}

/// Returns the affine map used to index the memref for this operation.
AffineMapAttr getAffineMapAttr() {
return getMapAttr();
}

static StringRef getMapAttrStrName() { return "map"; }
}];

let hasCustomAssemblyFormat = 1;
// let assemblyFormat = [{$memref `[` $indices `]` attr-dict `:` type($memref)}];

}

def KrnlMovableOp : Op<Krnl_Dialect, "movable", [ImplicitKrnlTerminator]> {
let summary = "Krnl movable operation";
let description = [{
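
Semantics side note (my reading of the op description together with the lowering added in this PR): for an identity-layout memref of shape (d1, ..., dn) and an index (id1, ..., idn), the result is the standard row-major offset id1*(d2*...*dn) + id2*(d3*...*dn) + ... + id(n-1)*dn + idn. For example, `memref<?x2x8x32x64xf32>` at `(id1, 0, 0, 10, 5)` gives id1*32768 + 10*64 + 5 = id1*32768 + 645, which matches the affine map checked in the new lit test at the bottom of this PR.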
66 changes: 66 additions & 0 deletions src/Dialect/Krnl/KrnlOps.cpp
@@ -963,6 +963,72 @@ std::optional<Value> KrnlSeqAllocOp::buildClone(
.getResult();
}

//===----------------------------------------------------------------------===//
// KrnlGetLinearOffsetIndexOp
//===----------------------------------------------------------------------===//

void KrnlGetLinearOffsetIndexOp::build(OpBuilder &builder,
OperationState &result, AffineMap map, ValueRange operands) {
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
result.addOperands(operands);
if (map)
result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());
}

void KrnlGetLinearOffsetIndexOp::build(OpBuilder &builder,
OperationState &result, Value memref, AffineMap map,
ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
result.types.push_back(builder.getIndexType());
}

void KrnlGetLinearOffsetIndexOp::build(OpBuilder &builder,
OperationState &result, Value memref, ValueRange indices) {
auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map =
rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
build(builder, result, memref, map, indices);
}

ParseResult KrnlGetLinearOffsetIndexOp::parse(
OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();

MemRefType type;
OpAsmParser::UnresolvedOperand memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
return failure(
parser.parseOperand(memrefInfo) || parser.parseKeyword("at") ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
KrnlGetLinearOffsetIndexOp::getMapAttrStrName(), result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands) ||
parser.addTypeToList(indexTy, result.types));
}

void KrnlGetLinearOffsetIndexOp::print(OpAsmPrinter &p) {
p << " " << getMemRef() << " at [";
if (AffineMapAttr mapAttr =
(*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
p << ']';
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/{getMapAttrStrName()});
p << " : " << getMemRefType();
}

} // namespace mlir

#define GET_OP_CLASSES
6 changes: 6 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -1020,6 +1020,9 @@ struct InstanceNormIntoLayerNormPattern
// Create output using layer norm.
Value Y = create.onnx.layerNorm(inputType, input, newScale, newBias, axis,
instanceNormOp.getEpsilonAttr());
// Set the type of the output to be the same as the output of the original
// operation we are trying to replace.
Y.setType(instanceNormOp.getResult().getType());
// Replace operation.
rewriter.replaceOp(instanceNormOp, Y);
return success();
@@ -1114,6 +1117,9 @@ struct GroupNormIntoLayerNormPattern
Value inputShape = create.onnx.shape(inputShapeType, input);
Type outputType = groupNormOp.getY().getType();
Value Y = create.onnx.reshape(outputType, layerNormY, inputShape);
// Set the type of the output to be the same as the output of the original
// operation we are trying to replace.
Y.setType(groupNormOp.getResult().getType());
// Replace operation.
rewriter.replaceOp(groupNormOp, Y);
return success();
6 changes: 2 additions & 4 deletions test/backend/inference_backend.py
@@ -1385,14 +1385,12 @@ def get_test_models():
# ==MIN== 6
"test_instancenorm_example_cpu": {
STATIC_SHAPE: {},
# Issue #2639: Dynamic test fails. Need to be fixed.
# DYNAMIC_SHAPE: {-1: {-1}},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
"test_instancenorm_epsilon_cpu": {
STATIC_SHAPE: {},
# Issue #2639: Dynamic test fails. Need to be fixed.
# DYNAMIC_SHAPE: {-1: {-1}},
DYNAMIC_SHAPE: {-1: {-1}},
CONSTANT_INPUT: {-1},
},
# ==OP== IsInf
@@ -73,3 +73,32 @@ func.func private @memref_with_affine(%arg0: memref<3xf32, #map>) -> memref<3xf32, #map>
// CHECK: return [[RES_]] : memref<3xf32, #map>
// CHECK: }
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 64, d2 floordiv 32, d2 mod 32, d1 mod 64)>
func.func @krnl_get_linear_offset_index_1(%arg0: memref<?x128x256xf32, #map>, %arg1: index, %arg2: index) -> index {
%c5 = arith.constant 5: index
%c10 = arith.constant 10: index
%0 = memref.alloc(%arg1) : memref<?x128x256xf32, #map>
%1 = krnl.get_linear_offset_index %arg0 at [%arg2, %c5, %c10] : memref<?x128x256xf32, #map>
return %1: index

// CHECK-LABEL: func.func @krnl_get_linear_offset_index
// CHECK: [[VAR_0_:%.+]] = krnl.get_linear_offset_index {{.*}} at {{.*}} : memref<?x128x256xf32, #map>
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0)>
func.func @krnl_get_linear_offset_index_2(%arg0: memref<?x2x8x32x64xf32>, %arg1: index, %arg2: index) -> index {
%0 = krnl.get_linear_offset_index %arg0 at [%arg2, 0, 0, 10, 5] : memref<?x2x8x32x64xf32>
return %0 : index

// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 32768 + 645)>
// CHECK-LABEL: func.func @krnl_get_linear_offset_index
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x2x8x32x64xf32>, [[PARAM_1_:%.+]]: index, [[PARAM_2_:%.+]]: index) -> index attributes {llvm.emit_c_interface} {
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[PARAM_2_]]{{.}}
// CHECK: return [[VAR_0_]] : index
// CHECK: }
}
26 changes: 26 additions & 0 deletions test/mlir/krnl/get_linear_offset_index.mlir
@@ -0,0 +1,26 @@
//RUN: onnx-mlir-opt --normalize-memrefs --split-input-file %s | FileCheck %s

#map = affine_map<(d0, d1, d2) -> (d0, d1 floordiv 64, d2 floordiv 32, d2 mod 32, d1 mod 64)>
module {
func.func @krnl_get_linear_offset_index(%arg0: memref<?x128x256xf32, #map>, %arg1: index, %arg2: index) -> index {
%c5 = arith.constant 5: index
%c10 = arith.constant 10: index
%0 = memref.alloc(%arg1) : memref<?x128x256xf32, #map>
%1 = krnl.get_linear_offset_index %arg0 at [%arg2, %c5, %c10] : memref<?x128x256xf32, #map>
return %1: index
}

// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-LABEL: func.func @krnl_get_linear_offset_index
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x2x8x32x64xf32>, [[PARAM_1_:%.+]]: index, [[PARAM_2_:%.+]]: index) -> index {
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index
// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]]([[PARAM_1_]], [[CST_128_]], [[CST_256_]])
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) : memref<?x2x8x32x64xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at [symbol([[PARAM_2_]]), 0, 0, 10, 5] : memref<?x2x8x32x64xf32>
// CHECK: return [[VAR_1_]] : index
// CHECK: }

}