From 2f5056e9a4328fcee1cc9fc350a95d6b3addccf7 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Mon, 25 Nov 2024 18:49:56 -0500
Subject: [PATCH] [GPU] Add pattern to fuse tensor.extract_slice into forall
 producer

Signed-off-by: Max Dawkins
---
 .../GPU/GPUFuseAndHoistParallelLoops.cpp      |  22 ++
 .../TransformExtensions/IREEGPUExtensions.cpp |  48 ++++
 .../IREEGPUExtensionsOps.td                   |  34 +++
 .../GPU/TransformExtensions/test/BUILD.bazel  |   1 +
 .../TransformExtensions/test/CMakeLists.txt   |   1 +
 ...nsform_fuse_extract_slice_with_forall.mlir |  60 +++++
 .../Dialect/GPU/Transforms/Transforms.cpp     | 238 ++++++++++++++++++
 .../Dialect/GPU/Transforms/Transforms.h       |   5 +
 8 files changed, 409 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
index f6bb86d8eba7..4953e1c961c1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
@@ -340,6 +340,27 @@ struct FuseCollapseShapeConsumers final
   }
 };
 
+struct FuseExtractSliceConsumers final
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Find the scf::ForallOp producer of the extract_slice source.
+    auto forallOp = extractSliceOp.getSource().getDefiningOp<scf::ForallOp>();
+    if (!forallOp) {
+      return rewriter.notifyMatchFailure(extractSliceOp,
+                                         "No forall op producer");
+    }
+
+    if (failed(fuseExtractSliceIntoProducerForall(rewriter, forallOp,
+                                                  extractSliceOp))) {
+      return failure();
+    }
+    return success();
+  }
+};
+
 void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
   MLIRContext *context = &getContext();
 
@@ -385,6 +406,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
   patterns.add(context);
   patterns.add(context);
   patterns.add(context);
+  patterns.add<FuseExtractSliceConsumers>(context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
   scf::ForallOp::getCanonicalizationPatterns(patterns, context);
   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index e745ba4b9203..e202de70f2c7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -266,6 +266,54 @@ void transform_dialect::FuseCollapseShapeWithForallOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===---------------------------------------------------------------------===//
+// FuseExtractSliceWithForallOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_dialect::FuseExtractSliceWithForallOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto producers = state.getPayloadOps(getProducer());
+  auto consumers = state.getPayloadOps(getConsumer());
+
+  int64_t numProducers = llvm::range_size(producers);
+  int64_t numConsumers = llvm::range_size(consumers);
+  if (numProducers != 1 || numConsumers != 1) {
+    return mlir::emitDefiniteFailure(
+        state.getTopLevel(), "Expected exactly one producer and one consumer");
+  }
+
+  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
+  if (!producer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non-forall producer");
+  }
+  auto consumer = dyn_cast<tensor::ExtractSliceOp>(*consumers.begin());
+  if (!consumer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non-extract_slice consumer");
+  }
+
+  FailureOr<scf::ForallOp> fusedForallOp =
+      GPU::fuseExtractSliceIntoProducerForall(rewriter, producer, consumer);
+  if (failed(fusedForallOp)) {
+    return mlir::emitSilenceableFailure(state.getTopLevel(),
+                                        "failed to fuse extract_slice op");
+  }
+
+  results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::FuseExtractSliceWithForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getProducerMutable(), effects);
+  transform::consumesHandle(getConsumerMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 } // namespace mlir::iree_compiler::IREE
 
 void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index e570ce62889e..e77ef3a1492e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -262,4 +262,38 @@ def FuseCollapseShapeWithForallOp : Op
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 }
 
+def FuseExtractSliceWithForallOp : Op<Transform_Dialect,
+    "iree.fuse_extract_slice_with_forall", [
+      FunctionalStyleTransformOpTrait,
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<TransformOpInterface>,
+      ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses a consumer tensor.extract_slice op into a producer scf.forall op.
+    The only users of the block argument for the corresponding forall output
+    operand must be a single tensor.parallel_insert_slice op and
+    tensor.extract_slice ops that extract an equivalent subset. After the
+    fusion, the output of the forall is the requested slice of the original
+    output, and all users of this block argument are clamped to the slice
+    size. tensor.pad ops are inserted after any tensor.extract_slice users
+    inside the forall so that their types still match, and a
+    tensor.extract_slice is inserted before the tensor.parallel_insert_slice
+    for the same reason.
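+
+    For example, in the spirit of the lit test added with this op (the value
+    names %empty, %v, %off, and %size below are purely illustrative), IR of
+    the form
+
+    ```mlir
+    %0 = scf.forall (%id) in (4) shared_outs(%out = %empty) -> (tensor<8xf32>) {
+      ...
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %v into %out[%off] [2] [1]
+          : tensor<2xf32> into tensor<8xf32>
+      }
+    }
+    %slice = tensor.extract_slice %0[0] [%size] [1]
+      : tensor<8xf32> to tensor<?xf32>
+    ```
+
+    is rewritten so that the forall iterates directly on a tensor<?xf32> slice
+    of %empty, the sizes written by the tensor.parallel_insert_slice are
+    clamped to the remaining size, and the trailing tensor.extract_slice of
+    the forall result goes away.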
+
+    #### Return modes
+    Emits a definite failure if the producer handle is not mapped to exactly
+    one scf.forall op, or if the consumer handle is not mapped to exactly one
+    tensor.extract_slice op. Emits a silenceable failure if the fusion itself
+    fails.
+  }];
+
+  let arguments = (
+    ins TransformHandleTypeInterface:$producer,
+        TransformHandleTypeInterface:$consumer
+  );
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = [{
+    $consumer `into` $producer attr-dict
+    `:` functional-type(operands, results)
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index 428211b3ea01..c137bef9afbe 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -25,6 +25,7 @@ iree_lit_test_suite(
             "lower_multi_mma.mlir",
             "lower_vector_barrier.mlir",
             "transform_fuse_collapse_shape_with_forall.mlir",
+            "transform_fuse_extract_slice_with_forall.mlir",
             "transform_fuse_forall.mlir",
             "transform_lower_barrier_region.mlir",
             "vectorize_iree_gpu_ops.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index 344da8cf34d9..abeff344d337 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -21,6 +21,7 @@ iree_lit_test_suite(
     "lower_multi_mma.mlir"
     "lower_vector_barrier.mlir"
     "transform_fuse_collapse_shape_with_forall.mlir"
+    "transform_fuse_extract_slice_with_forall.mlir"
     "transform_fuse_forall.mlir"
     "transform_lower_barrier_region.mlir"
     "unroll_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir
new file mode 100644
index 000000000000..fe790aaa126b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir
@@ -0,0 +1,60 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule -canonicalize -cse --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+module {
+  func.func @fuse_extract_slice_with_forall(%arg0: tensor<8xf32>, %arg1: index) -> tensor<?xf32> {
+    %0 = tensor.empty() : tensor<8xf32>
+    %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<8xf32>) {
+      %2 = affine.apply #map(%arg2)
+      %extracted_slice_0 = tensor.extract_slice %arg0[%2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg3[%2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+      %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg3[%2] [2] [1] : tensor<2xf32> into tensor<8xf32>
+      }
+    } {mapping = [#gpu.thread<x>]}
+    %extracted_slice = tensor.extract_slice %1[0] [%arg1] [1] : tensor<8xf32> to tensor<?xf32>
+    return %extracted_slice : tensor<?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %producer = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer = transform.get_consumers_of_result %producer[0] : (!transform.any_op) -> !transform.any_op
+    %2 = transform.iree.fuse_extract_slice_with_forall %consumer into %producer
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * -2 + s0, 2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0) -> (-d0 + 2)>
+
+// CHECK-LABEL: func @fuse_extract_slice_with_forall
+// CHECK-SAME:    %[[ARG0:[A-Za-z0-9]+]]: tensor<8xf32>
+// CHECK-SAME:    %[[ARG1:[A-Za-z0-9]+]]: index
+
+// CHECK-DAG:   %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK-DAG:   %[[SLICED_OUT:.+]] = tensor.extract_slice %[[EMPTY]][0] [%[[ARG1]]] [1] : tensor<8xf32> to tensor<?xf32>
+// CHECK:       %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX:.+]]) in (4) shared_outs(%[[SLICED_BBARG:.+]] = %[[SLICED_OUT]]) -> (tensor<?xf32>) {
+// CHECK-DAG:     %[[SLICE_IDX:.+]] = affine.apply #[[$MAP]](%[[IDX]])
+// CHECK-DAG:     %[[SIZE_CLAMPED_HIGH:.+]] = affine.min #[[$MAP1]](%[[IDX]])[%[[ARG1]]]
+// CHECK-DAG:     %[[SIZE_CLAMPED_LOW:.+]] = affine.max #[[$MAP2]](%[[SIZE_CLAMPED_HIGH]])
+// CHECK-DAG:     %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[SLICE_IDX]]] [2] [1] : tensor<8xf32> to tensor<2xf32>
+// CHECK-DAG:     %[[OUT_SLICE:.+]] = tensor.extract_slice %[[SLICED_BBARG]][%[[SLICE_IDX]]] [%[[SIZE_CLAMPED_LOW]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK-DAG:     %[[PAD_HIGH:.+]] = affine.apply #[[$MAP3]](%[[SIZE_CLAMPED_LOW]])
+// CHECK:         %[[PADDED_OUT_SLICE:.+]] = tensor.pad %[[OUT_SLICE]] low[0] high[%[[PAD_HIGH]]] {
+// CHECK:         ^bb0({{.*}}):
+// CHECK:           tensor.yield %[[ZERO]] : f32
+// CHECK:         } : tensor<?xf32> to tensor<2xf32>
+// CHECK:         %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2xf32>) outs(%[[PADDED_OUT_SLICE]] : tensor<2xf32>) -> tensor<2xf32>
+// CHECK:         %[[SLICED_COPY:.+]] = tensor.extract_slice %[[COPY]][0] [%[[SIZE_CLAMPED_LOW]]] [1] : tensor<2xf32> to tensor<?xf32>
+// CHECK:         scf.forall.in_parallel {
+// CHECK:           tensor.parallel_insert_slice %[[SLICED_COPY]] into %[[SLICED_BBARG]][%[[SLICE_IDX]]] [%[[SIZE_CLAMPED_LOW]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK:         }
+// CHECK:       } {mapping = [#gpu.thread<x>]}
+// CHECK:       return %[[FORALL_RESULT]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 6b03508fc905..bd70f47fdfdb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -604,6 +604,244 @@ fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter,
   return newForallOp;
 }
 
+static FailureOr<tensor::ParallelInsertSliceOp>
+clampParallelInsertSliceOp(RewriterBase &rewriter,
+                           tensor::ParallelInsertSliceOp parallelInsertOp,
+                           SmallVector<OpFoldResult> upperBoundSizes) {
+  // Find a valid insertion point to compute the new clamped sizes.
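+  // The clamped sizes are built from the values that the parallel_insert_slice
+  // subset itself needs (its offsets and sizes), so they can only be computed
+  // after the latest definition of those values inside the loop body.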
+  auto subsetOp =
+      cast<SubsetInsertionOpInterface>(parallelInsertOp.getOperation());
+  SmallVector<Value> neededValues =
+      subsetOp.getValuesNeededToBuildSubsetExtraction();
+  auto forallOp = parallelInsertOp->getParentOfType<scf::ForallOp>();
+  Operation *insertionPoint = &forallOp.getBody()->front();
+  if (failed(getLatestInsertionPoint(insertionPoint, neededValues))) {
+    return rewriter.notifyMatchFailure(parallelInsertOp,
+                                       "could not find valid insertion point");
+  }
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(insertionPoint);
+  Location loc = insertionPoint->getLoc();
+
+  // Clamp the sizes of the parallel insert source.
+  auto clampSizes = [&](SmallVector<OpFoldResult> offsets,
+                        SmallVector<OpFoldResult> sizes, Location loc) {
+    SmallVector<OpFoldResult> clampedSizes;
+    for (auto [offset, size, ub] :
+         llvm::zip_equal(offsets, sizes, upperBoundSizes)) {
+      AffineExpr d0, d1, d2;
+      auto ctx = rewriter.getContext();
+      bindDims(ctx, d0, d1, d2);
+      auto ubClampMap = AffineMap::get(3, 0, {d0 - d1, d2}, ctx);
+      auto lbClampMap = rewriter.getMultiDimIdentityMap(2);
+      auto ubClamped = affine::makeComposedFoldedAffineMin(
+          rewriter, loc, ubClampMap, {ub, offset, size});
+      auto lbClamped = affine::makeComposedFoldedAffineMax(
+          rewriter, loc, lbClampMap, {ubClamped, rewriter.getIndexAttr(0)});
+      clampedSizes.push_back(lbClamped);
+    }
+    return clampedSizes;
+  };
+  SmallVector<OpFoldResult> clampedSizes =
+      clampSizes(parallelInsertOp.getMixedOffsets(),
+                 parallelInsertOp.getMixedSizes(), loc);
+
+  // Create an extract_slice to extract the correct size from the parallel
+  // insert source.
+  SmallVector<int64_t> clampedShape;
+  SmallVector<Value> d;
+  dispatchIndexOpFoldResults(clampedSizes, d, clampedShape);
+  RankedTensorType clampedType =
+      parallelInsertOp.getSourceType().clone(clampedShape);
+  SmallVector<OpFoldResult> zeros(clampedSizes.size(),
+                                  rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> ones(clampedSizes.size(), rewriter.getIndexAttr(1));
+  Operation *combiningOp =
+      parallelInsertOp.getParallelCombiningParent().getOperation();
+  rewriter.setInsertionPoint(combiningOp);
+  loc = combiningOp->getLoc();
+  auto extractOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, clampedType, parallelInsertOp.getSource(), zeros, clampedSizes,
+      ones);
+
+  // Replace the parallel insert op with the clamped version, and return the
+  // new parallel insert slice.
+  rewriter.setInsertionPoint(parallelInsertOp);
+  loc = parallelInsertOp->getLoc();
+  return rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
+      parallelInsertOp, extractOp.getResult(), parallelInsertOp.getDest(),
+      parallelInsertOp.getMixedOffsets(), clampedSizes,
+      parallelInsertOp.getMixedStrides());
+}
+
+FailureOr<scf::ForallOp>
+fuseExtractSliceIntoProducerForall(RewriterBase &rewriter,
+                                   scf::ForallOp forallOp,
+                                   tensor::ExtractSliceOp extractSliceOp) {
+  auto forallResult = cast<OpResult>(extractSliceOp.getSource());
+  if (!forallResult.hasOneUse()) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "forall result has multiple uses");
+  }
+  BlockArgument initBbarg =
+      forallOp.getRegionIterArgs()[forallResult.getResultNumber()];
+  SmallVector<Operation *> parallelInsertOps =
+      forallOp.getCombiningOps(initBbarg);
+  if (parallelInsertOps.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Expected a single parallel_insert_slice");
+  }
+
+  auto parallelInsertOp =
+      dyn_cast<tensor::ParallelInsertSliceOp>(parallelInsertOps.front());
+  if (!parallelInsertOp) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Expected parallel_insert_slice combining op");
+  }
+
+  // Only zero offset extract_slice ops are supported.
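+  // The rewrite below only clamps the sizes of the users inside the loop and
+  // keeps their offsets unchanged, which only lines up with the consumer
+  // slice when that slice starts at offset zero.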
+  if (!areAllConstantIntValue(extractSliceOp.getMixedOffsets(), 0)) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "extract_slice has non-zero offsets");
+  }
+
+  // The extract_slice index operands must dominate the forall loop in order
+  // to extract a slice of the init operand later.
+  DominanceInfo domInfo;
+  int64_t indexOperandStartIdx =
+      extractSliceOp.getOffsetSizeAndStrideStartOperandIndex();
+  SmallVector<Value> indexOperands(extractSliceOp->getOperands().begin() +
+                                       indexOperandStartIdx,
+                                   extractSliceOp->getOperands().end());
+  if (!llvm::all_of(indexOperands,
+                    [&](Value v) { return domInfo.dominates(v, forallOp); })) {
+    return rewriter.notifyMatchFailure(
+        extractSliceOp,
+        "Extract slice index operands do not dominate the forall op");
+  }
+
+  // Compute the rank reduction mask of the extract_slice for resolving rank
+  // reduction at the end. For rank reducing slices, the extract_slice is
+  // fused into the loop as a non rank reducing slice, and then a collapse
+  // shape is added on the result of the loop. This simplifies the logic in
+  // this pattern, and other patterns for collapse shape fusion can then fuse
+  // this collapse shape into the loop if needed.
+  auto maybeRankReductionMask = computeRankReductionMask(
+      extractSliceOp.getSourceType().getShape(),
+      extractSliceOp.getType().getShape(), /*matchDynamic=*/true);
+  if (!maybeRankReductionMask) {
+    return rewriter.notifyMatchFailure(extractSliceOp,
+                                       "Could not compute rank reduction mask");
+  }
+
+  // Get all users of the loop init argument, and verify that they operate on
+  // an equivalent subset.
+  int64_t resultIdx = forallResult.getResultNumber();
+  FailureOr<SmallVector<Operation *>> maybeEqualSubsetUsers =
+      collectAllUsersIfAreEquivalentSubsets(
+          rewriter, forallOp.getRegionIterArgs()[resultIdx]);
+  if (failed(maybeEqualSubsetUsers)) {
+    return failure();
+  }
+
+  // Now extract from or pad the users inside the forallOp body. This is
+  // necessary to ensure that types match after the transformation. For any
+  // parallel_insert_slice ops, add an extract_slice with clamped sizes, and
+  // for any extract_slice ops, add a pad to the original size. Start with
+  // the parallel insert slice, and use the SubsetInsertionOpInterface to
+  // build the correct extract_slice ops.
+  SmallVector<OpFoldResult> newInitSizes = extractSliceOp.getMixedSizes();
+  FailureOr<tensor::ParallelInsertSliceOp> maybeClampedParallelInsertSliceOp =
+      clampParallelInsertSliceOp(rewriter, parallelInsertOp, newInitSizes);
+  if (failed(maybeClampedParallelInsertSliceOp)) {
+    return failure();
+  }
+  tensor::ParallelInsertSliceOp clampedParallelInsertSliceOp =
+      maybeClampedParallelInsertSliceOp.value();
+
+  // Clamp the extract_slice users.
+  SmallVector<Operation *> equalSubsetUsers = maybeEqualSubsetUsers.value();
+  auto subsetInsertionOp = cast<SubsetInsertionOpInterface>(
+      clampedParallelInsertSliceOp.getOperation());
+  for (Operation *op : equalSubsetUsers) {
+    if (isa<tensor::ParallelInsertSliceOp>(op)) {
+      continue;
+    }
+    rewriter.setInsertionPoint(op);
+    Location loc = op->getLoc();
+    Value newExtract = subsetInsertionOp.buildSubsetExtraction(rewriter, loc);
+    auto newExtractOp = newExtract.getDefiningOp<tensor::ExtractSliceOp>();
+
+    // Create tensor.pad so that types match.
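+    // The clamped subset extraction built above can be smaller than the
+    // original static slice, so pad it back up to the original type with
+    // zeros. Only the clamped portion is re-inserted by the clamped
+    // parallel_insert_slice, so the padded elements never reach the loop
+    // result.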
+    auto extractOp = cast<tensor::ExtractSliceOp>(op);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(extractOp.getType().getElementType()));
+    SmallVector<OpFoldResult> highPadding;
+    for (auto [ub, size] : llvm::zip_equal(extractOp.getMixedSizes(),
+                                           newExtractOp.getMixedSizes())) {
+      auto addMap = AffineMap::get(
+          2, 0, {rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
+          rewriter.getContext());
+      OpFoldResult padSize = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, addMap, {ub, size});
+      highPadding.push_back(padSize);
+    }
+    SmallVector<OpFoldResult> lowPadding(highPadding.size(),
+                                         rewriter.getIndexAttr(0));
+    auto padOp = rewriter.create<tensor::PadOp>(
+        loc, extractOp.getType(), newExtract, lowPadding, highPadding, zero);
+    rewriter.replaceOp(extractOp, padOp);
+  }
+
+  // Clone the extract_slice, and replace the source with the forall init
+  // operand.
+  Value forallInit = forallOp.getOutputs()[resultIdx];
+  rewriter.setInsertionPoint(forallOp);
+  auto extractedInit = rewriter.create<tensor::ExtractSliceOp>(
+      forallOp->getLoc(), forallInit, extractSliceOp.getMixedOffsets(),
+      extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+
+  // Clone the forall op with the extracted init operand to replace the
+  // original forall op.
+  Location loc = forallOp->getLoc();
+  rewriter.setInsertionPoint(forallOp);
+  SmallVector<Value> newForallOutputs(forallOp.getOutputs());
+  newForallOutputs[resultIdx] = extractedInit.getResult();
+
+  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newForallOutputs, forallOp.getMappingAttr());
+
+  SmallVector<Value> argReplacements(newForallOp.getInductionVars());
+  argReplacements.append(newForallOp.getRegionIterArgs().begin(),
+                         newForallOp.getRegionIterArgs().end());
+  newForallOp.getTerminator()->erase();
+  rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
+                       argReplacements);
+
+  // Create a collapse_shape to handle rank reduction.
+  Value extractedResult = newForallOp->getResult(resultIdx);
+  auto forallResultType = cast<RankedTensorType>(extractedResult.getType());
+  SmallVector<ReassociationIndices> reassociations;
+  ReassociationIndices reassociation;
+  for (int i = 0; i < forallResultType.getRank(); ++i) {
+    if (maybeRankReductionMask->contains(i)) {
+      reassociation.push_back(i);
+      continue;
+    }
+    reassociation.push_back(i);
+    reassociations.push_back(reassociation);
+    reassociation = {};
+  }
+  auto collapseShape = rewriter.create<tensor::CollapseShapeOp>(
+      extractSliceOp->getLoc(), extractedResult, reassociations);
+
+  // Replace forall and extract_slice ops with the new operations.
+  rewriter.replaceAllOpUsesWith(extractSliceOp, collapseShape);
+  rewriter.replaceOp(forallOp, newForallOp);
+  return newForallOp;
+}
+
 //===----------------------------------------------------------------------===//
 // MultiMmaOp Lowering
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index a119e7e976d7..36a8c363ae3f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -61,6 +61,11 @@ fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter,
                                     scf::ForallOp forallOp,
                                     tensor::CollapseShapeOp collapseOp);
 
+FailureOr<scf::ForallOp>
+fuseExtractSliceIntoProducerForall(RewriterBase &rewriter,
+                                   scf::ForallOp forallOp,
+                                   tensor::ExtractSliceOp extractSliceOp);
+
 // Helper to convert a contraction-like linalg op to an iree_gpu.multi_mma.
 FailureOr convertContractionToMultiMma(RewriterBase &rewriter,
                                        linalg::LinalgOp linalgOp,