Skip to content

Commit 4eb7167

Browse files
author
Max Dawkins
committed
[GPU] Add pattern to fuse tensor.collapse_shape into forall producer
Signed-off-by: Max Dawkins <[email protected]>
1 parent 53e9601 commit 4eb7167

File tree

8 files changed

+475
-3
lines changed

8 files changed

+475
-3
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp

+19-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
1212
#include "iree/compiler/Codegen/Transforms/Transforms.h"
1313
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
14-
#include "llvm/ADT/TypeSwitch.h"
15-
#include "llvm/Support/Casting.h"
1614
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1715
#include "mlir/Dialect/Func/IR/FuncOps.h"
1816
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -22,7 +20,6 @@
2220
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2321
#include "mlir/Interfaces/FunctionInterfaces.h"
2422
#include "mlir/Interfaces/LoopLikeInterface.h"
25-
#include "mlir/Support/LogicalResult.h"
2623
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2724

2825
#define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
@@ -325,6 +322,24 @@ struct FuseTilableForallConsumers final
325322
}
326323
};
327324

325+
struct FuseCollapseShapeConsumers final
326+
: OpRewritePattern<tensor::CollapseShapeOp> {
327+
using OpRewritePattern::OpRewritePattern;
328+
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
329+
PatternRewriter &rewriter) const override {
330+
auto forallOp = collapseOp.getSrc().getDefiningOp<scf::ForallOp>();
331+
if (!forallOp) {
332+
return rewriter.notifyMatchFailure(collapseOp, "No forall op producer");
333+
}
334+
335+
if (failed(fuseCollapseShapeIntoProducerForall(rewriter, forallOp,
336+
collapseOp))) {
337+
return failure();
338+
}
339+
return success();
340+
}
341+
};
342+
328343
void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
329344
MLIRContext *context = &getContext();
330345

@@ -369,6 +384,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
369384
patterns.add<FuseTilableDestinationProducers>(context);
370385
patterns.add<FuseUnitLoopDestination>(context);
371386
patterns.add<FuseTilableForallConsumers>(context);
387+
patterns.add<FuseCollapseShapeConsumers>(context);
372388
tensor::populateFoldTensorEmptyPatterns(patterns);
373389
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
374390
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,54 @@ void transform_dialect::FuseForallOp::getEffects(
218218
transform::modifiesPayload(effects);
219219
}
220220

221+
//===---------------------------------------------------------------------===//
222+
// FuseCollapseShapeWithForallOp
223+
//===---------------------------------------------------------------------===//
224+
225+
DiagnosedSilenceableFailure
226+
transform_dialect::FuseCollapseShapeWithForallOp::apply(
227+
transform::TransformRewriter &rewriter,
228+
transform::TransformResults &results, transform::TransformState &state) {
229+
auto producers = state.getPayloadOps(getProducer());
230+
auto consumers = state.getPayloadOps(getConsumer());
231+
232+
int64_t numProducers = llvm::range_size(producers);
233+
int64_t numConsumers = llvm::range_size(consumers);
234+
if (numProducers != 1 || numConsumers != 1) {
235+
return mlir::emitDefiniteFailure(state.getTopLevel(),
236+
"More than one producer or consumer");
237+
}
238+
239+
auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
240+
if (!producer) {
241+
return mlir::emitDefiniteFailure(state.getTopLevel(),
242+
"Non-forall producer");
243+
}
244+
auto consumer = dyn_cast<tensor::CollapseShapeOp>(*consumers.begin());
245+
if (!consumer) {
246+
return mlir::emitDefiniteFailure(state.getTopLevel(),
247+
"Non-collapse_shape consumer");
248+
}
249+
250+
FailureOr<scf::ForallOp> fusedForallOp =
251+
GPU::fuseCollapseShapeIntoProducerForall(rewriter, producer, consumer);
252+
if (failed(fusedForallOp)) {
253+
return mlir::emitSilenceableFailure(state.getTopLevel(),
254+
"failed to fuse collapse_shape op");
255+
}
256+
257+
results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
258+
return DiagnosedSilenceableFailure::success();
259+
}
260+
261+
void transform_dialect::FuseCollapseShapeWithForallOp::getEffects(
262+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
263+
transform::consumesHandle(getProducerMutable(), effects);
264+
transform::consumesHandle(getConsumerMutable(), effects);
265+
transform::producesHandle(getOperation()->getOpResults(), effects);
266+
transform::modifiesPayload(effects);
267+
}
268+
221269
} // namespace mlir::iree_compiler::IREE
222270

223271
void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td

+34
Original file line numberDiff line numberDiff line change
@@ -228,4 +228,38 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
228228
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
229229
}
230230

231+
def FuseCollapseShapeWithForallOp : Op<Transform_Dialect, "iree.fuse_collapse_shape_with_forall",
232+
[FunctionalStyleTransformOpTrait,
233+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
234+
DeclareOpInterfaceMethods<TransformOpInterface>,
235+
ReportTrackingListenerFailuresOpTrait]> {
236+
let description = [{
237+
Fuses a consumer tensor.collapse_shape op into a producer scf.forall op.
238+
The users of the block argument for the corresponding forall output operand
239+
should be only a tensor.parallel_insert_slice op, and tensor.extract_slice
240+
ops that extract an equivalent subset. After the fusion, the output of the
241+
forall will be collapsed, and all users of this block arg will also be
242+
collapsed. Additional tensor.expand_shape ops will be inserted after any
243+
tensor.extract_slice users inside the forall so that types match. Similarly,
244+
a tensor.collapse_shape will be inserted before the
245+
tensor.parallel_insert_slice.
246+
247+
#### Return modes
248+
Emits a definite failure if either the producer is not an scf.forall op or
249+
if the consumer is not a tensor.collapse_shape op.
250+
}];
251+
252+
let arguments = (
253+
ins TransformHandleTypeInterface:$producer,
254+
TransformHandleTypeInterface:$consumer
255+
);
256+
let results = (outs TransformHandleTypeInterface:$result);
257+
258+
let assemblyFormat = [{
259+
$consumer `into` $producer attr-dict
260+
`:` functional-type(operands, results)
261+
}];
262+
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
263+
}
264+
231265
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ iree_lit_test_suite(
2424
"drop_multi_mma_unit_dims.mlir",
2525
"lower_multi_mma.mlir",
2626
"lower_vector_barrier.mlir",
27+
"transform_fuse_collapse_shape_with_forall.mlir",
2728
"transform_fuse_forall.mlir",
2829
"transform_lower_barrier_region.mlir",
2930
"vectorize_iree_gpu_ops.mlir",

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ iree_lit_test_suite(
2020
"drop_multi_mma_unit_dims.mlir"
2121
"lower_multi_mma.mlir"
2222
"lower_vector_barrier.mlir"
23+
"transform_fuse_collapse_shape_with_forall.mlir"
2324
"transform_fuse_forall.mlir"
2425
"transform_lower_barrier_region.mlir"
2526
"unroll_multi_mma.mlir"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s
2+
3+
#map = affine_map<(d0) -> (d0 * 2)>
4+
module {
5+
func.func @fuse_collapse_shape_with_forall(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
6+
%0 = tensor.empty() : tensor<8x8xf32>
7+
%1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
8+
%2 = affine.apply #map(%arg1)
9+
%extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
10+
%extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
11+
%3 = linalg.copy ins(%extracted_slice : tensor<2x8xf32>) outs(%extracted_slice_0 : tensor<2x8xf32>) -> tensor<2x8xf32>
12+
scf.forall.in_parallel {
13+
tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<8x8xf32>
14+
}
15+
} {mapping = [#gpu.thread<x>]}
16+
%collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
17+
return %collapsed : tensor<64xf32>
18+
}
19+
}
20+
21+
module attributes {transform.with_named_sequence} {
22+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
23+
%producer = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
24+
%consumer = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
25+
%2 = transform.iree.fuse_collapse_shape_with_forall %consumer into %producer
26+
: (!transform.any_op, !transform.any_op) -> !transform.any_op
27+
transform.yield
28+
}
29+
}
30+
31+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
32+
33+
// CHECK-LABEL: func @fuse_collapse_shape_with_forall
34+
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<8x8xf32>
35+
36+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
37+
// CHECK: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[EMPTY]] {{\[}}[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
38+
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX:.+]]) in (4) shared_outs(%[[COLLAPSED_BBARG:.+]] = %[[COLLAPSED_OUT]]) -> (tensor<64xf32>) {
39+
// CHECK-DAG: %[[SLICE_IDX_0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
40+
// CHECK-DAG: %[[SLICE_IDX_1:.+]] = arith.constant 0 : index
41+
// CHECK: %[[LINEAR_SLICE_IDX:.+]] = affine.linearize_index [%[[SLICE_IDX_0]], %[[SLICE_IDX_1]]] by (8, 8) : index
42+
// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[SLICE_IDX_0]], 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
43+
// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[COLLAPSED_BBARG]][%[[LINEAR_SLICE_IDX]]] [16] [1] : tensor<64xf32> to tensor<16xf32>
44+
// CHECK: %[[EXPANDED_OUT_SLICE:.+]] = tensor.expand_shape %[[OUT_SLICE]] {{\[}}[0, 1]] output_shape [2, 8] : tensor<16xf32> into tensor<2x8xf32>
45+
// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2x8xf32>) outs(%[[EXPANDED_OUT_SLICE]] : tensor<2x8xf32>) -> tensor<2x8xf32>
46+
// CHECK: %[[COLLAPSED_COPY:.+]] = tensor.collapse_shape %[[COPY]] {{\[}}[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
47+
// CHECK: scf.forall.in_parallel {
48+
// CHECK: tensor.parallel_insert_slice %[[COLLAPSED_COPY]] into %[[COLLAPSED_BBARG]][%[[LINEAR_SLICE_IDX]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
49+
// CHECK: }
50+
// CHECK: } {mapping = [#gpu.thread<x>]}
51+
// CHECK: return %[[FORALL_RESULT]]

0 commit comments

Comments
 (0)