Skip to content

Commit

Permalink
[GPU] Add pattern to fuse tensor.extract_slice into forall producer
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max Dawkins committed Nov 26, 2024
1 parent 04c4f2b commit 8f9be22
Show file tree
Hide file tree
Showing 9 changed files with 789 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,27 @@ struct FuseCollapseShapeConsumers final
}
};

struct FuseExtractSliceConsumers final
: OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
PatternRewriter &rewriter) const override {
// Find the scf::ForallOp producer, and get the corresponding
// tensor::ParallelInsertSliceOp.
auto forallOp = extractSliceOp.getSource().getDefiningOp<scf::ForallOp>();
if (!forallOp) {
return rewriter.notifyMatchFailure(extractSliceOp,
"No forall op producer");
}

if (failed(fuseExtractSliceIntoProducerForall(rewriter, forallOp,
extractSliceOp))) {
return failure();
}
return success();
}
};

void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
MLIRContext *context = &getContext();

Expand Down Expand Up @@ -385,6 +406,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
patterns.add<FuseUnitLoopDestination>(context);
patterns.add<FuseTilableForallConsumers>(context);
patterns.add<FuseCollapseShapeConsumers>(context);
patterns.add<FuseExtractSliceConsumers>(context);
tensor::populateFoldTensorEmptyPatterns(patterns);
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,30 @@ func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_extract_slice_rank_reduced(%arg0: tensor<4x8xf32>, %size1: index) -> tensor<?xf32> {
%0 = tensor.empty() : tensor<4x8xf32>
%1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<4x8xf32>) {
%2 = affine.apply #map(%arg2)
%extracted_slice_0 = tensor.extract_slice %arg0[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32>
%extracted_slice_1 = tensor.extract_slice %arg3[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32>
%3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg3[0, %2] [1, 2] [1, 1] : tensor<2xf32> into tensor<4x8xf32>
}
} {mapping = [#gpu.thread<x>]}
%extracted_slice = tensor.extract_slice %1[0, 0] [1, %size1] [1, 1] : tensor<4x8xf32> to tensor<?xf32>
return %extracted_slice : tensor<?xf32>
}

// CHECK-LABEL: func @no_fuse_extract_slice_rank_reduced
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<4x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2xf32> into tensor<4x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[FORALL_RESULT]]
// CHECK: return %[[EXTRACT]]
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,54 @@ void transform_dialect::FuseCollapseShapeWithForallOp::getEffects(
transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// FuseExtractSliceWithForallOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::FuseExtractSliceWithForallOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
auto producers = state.getPayloadOps(getProducer());
auto consumers = state.getPayloadOps(getConsumer());

int64_t numProducers = llvm::range_size(producers);
int64_t numConsumers = llvm::range_size(consumers);
if (numProducers != 1 || numConsumers != 1) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"More than one producer or consumer");
}

auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
if (!producer) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"Non-forall producer");
}
auto consumer = dyn_cast<tensor::ExtractSliceOp>(*consumers.begin());
if (!consumer) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
"Non-extract_slice consumer");
}

FailureOr<scf::ForallOp> fusedForallOp =
GPU::fuseExtractSliceIntoProducerForall(rewriter, producer, consumer);
if (failed(fusedForallOp)) {
return mlir::emitSilenceableFailure(state.getTopLevel(),
"failed to fuse extract_slice op");
}

results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
return DiagnosedSilenceableFailure::success();
}

void transform_dialect::FuseExtractSliceWithForallOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getProducerMutable(), effects);
transform::consumesHandle(getConsumerMutable(), effects);
transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}

} // namespace mlir::iree_compiler::IREE

void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,38 @@ def FuseCollapseShapeWithForallOp : Op<Transform_Dialect, "iree.fuse_collapse_sh
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def FuseExtractSliceWithForallOp : Op<Transform_Dialect, "iree.fuse_extract_slice_with_forall",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Fuses a consumer tensor.extract_slice op into a producer scf.forall op.
The users of the block argument for the corresponding forall output operand
should be only a tensor.parallel_insert_slice op, and tensor.extract_slice
ops that extract an equivalent subset. After the fusion, the output of the
forall will be an equal subset slice of the original output, and all users
of this block arg will be clamped to the slice size. Additional tensor.pad
ops will be inserted after any tensor.extract_slice users inside the forall
so that types match. Similarly, a tensor.extract_slice op will be inserted
before the tensor.parallel_insert_slice.

#### Return modes
Emits a definite failure if either the producer is not an scf.forall op or
if the consumer is not a tensor.extract_slice op.
}];

let arguments = (
ins TransformHandleTypeInterface:$producer,
TransformHandleTypeInterface:$consumer
);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat = [{
$consumer `into` $producer attr-dict
`:` functional-type(operands, results)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"lower_multi_mma.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_collapse_shape_with_forall.mlir",
"transform_fuse_extract_slice_with_forall.mlir",
"transform_fuse_forall.mlir",
"transform_lower_barrier_region.mlir",
"vectorize_iree_gpu_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_lit_test_suite(
"lower_multi_mma.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_collapse_shape_with_forall.mlir"
"transform_fuse_extract_slice_with_forall.mlir"
"transform_fuse_forall.mlir"
"transform_lower_barrier_region.mlir"
"unroll_multi_mma.mlir"
Expand Down
Loading

0 comments on commit 8f9be22

Please sign in to comment.