Skip to content

Commit

Permalink
Quarter Wide Int Arith Pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Wouter Legiest committed Dec 24, 2024
1 parent 276bc7c commit 1f8c618
Show file tree
Hide file tree
Showing 15 changed files with 667 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> {
}
};

struct ConvertExt : public OpConversionPattern<mlir::arith::ExtSIOp> {
struct ConvertExt : public OpConversionPattern<mlir::arith::ExtUIOp> {
ConvertExt(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ExtSIOp>(context) {}
: OpConversionPattern<mlir::arith::ExtUIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
::mlir::arith::ExtSIOp op, OpAdaptor adaptor,
::mlir::arith::ExtUIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/Arith/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":QuarterWideInt",
":pass_inc_gen",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "QuarterWideInt",
srcs = ["QuarterWideInt.cpp"],
hdrs = ["QuarterWideInt.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

add_heir_transforms(
header_filename = "Passes.h.inc",
pass_name = "Arith",
td_file = "Passes.td",
)
17 changes: 17 additions & 0 deletions lib/Dialect/Arith/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_
#define LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_

#include "lib/Dialect/Arith/Transforms/QuarterWideInt.h"

namespace mlir {
namespace heir {
namespace arith {

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Arith/Transforms/Passes.h.inc"

} // namespace arith
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_ARITH_TRANSFORMS_PASSES_H_
26 changes: 26 additions & 0 deletions lib/Dialect/Arith/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_
#define LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def QuarterWideInt : Pass<"arith-quarter-wide-int"> {

let summary = "Convert high precision arithmetic operations to a sequence of lower precision operations";
let description = [{
This pass converts high precision arithmetic operations, i.e. operations on 32 bit integer,
into a sequence of lower precision operations, i.e 8b operations.
Currently, the pass splits the 32b integer into four 8b integers, using the tensor dialect.
These smaller integers are stored in an 16b integer, so that we don't lose the carry information.

Based on the `arith-emulate-wide-int` pass from the MLIR arith dialect.

General assumption: the first element in the tensor is also the LSB element.
}];
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::tensor::TensorDialect",
];
}

#endif // LIB_DIALECT_ARITH_TRANSFORMS_PASSES_TD_
Loading

0 comments on commit 1f8c618

Please sign in to comment.