Skip to content

Commit 74cc6bc

Browse files
authored
[spirv] Vectorize Linalg ops with reduction dimensions (iree-org#8836)
This helps to generate better code for Linalg ops with both parallel and reduction iterator types but not using minor identity maps. To make the flow work, SPIRVVectorizePass is tweaked to pull in more patterns. Also along the way, enabled tiling along `linalg.generic` reduction dimensions.
1 parent 28bc420 commit 74cc6bc

File tree

9 files changed

+522
-45
lines changed

9 files changed

+522
-45
lines changed

compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -410,17 +410,14 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
410410
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
411411
bool vectorizable =
412412
allowVectorization &&
413-
// The vectorization pipeline assumes tensor semantics when tiling.
414-
!linalgOp.hasBufferSemantics() && !linalgOp.hasIndexSemantics() &&
415-
// Skip vectorization for non-minor identity inputs as it generates
416-
// vector.transfer_read ops with permutation maps that we currently
417-
// cannot lower.
418-
// TODO: Remove this restriction once the lowering of the permutation
419-
// map is supported in core.
420-
llvm::all_of(linalgOp.getIndexingMaps(),
421-
[](AffineMap &map) { return map.isMinorIdentity(); }) &&
422-
// TODO: Lowering of integers other than i32 may require emulation.
423-
// This is currently not supported for vector operation.
413+
// The vectorization pipeline assumes tensor semantics for tiling.
414+
linalgOp.hasTensorSemantics() && !linalgOp.hasIndexSemantics() &&
415+
// Require all affine maps to be projected permutation so that we can
416+
// generate vector transfer ops.
417+
llvm::all_of(
418+
linalgOp.getIndexingMaps(),
419+
[](AffineMap map) { return map.isProjectedPermutation(); }) &&
420+
// TODO: Fix non-32-bit element type vectorization and remove this.
424421
llvm::all_of(linalgOp->getOperands(), has32BitElementType) &&
425422
llvm::none_of(loopBounds, ShapedType::isDynamic);
426423

@@ -496,7 +493,7 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
496493
if (distributeToThreads(subgroupSize) != 1) {
497494
// Otherwise, allow larger and larger loss factor.
498495

499-
// Threads for distribution Use 32 at least.
496+
// Threads for distribution. Use 32 at least.
500497
int64_t numThreads = std::max(subgroupSize, 32);
501498
// We can tolerate (1 / lossFactor) of threads in the workgroup to be idle.
502499
int64_t lossFactor = 32;
@@ -515,6 +512,21 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
515512
tileSizes.push_back(workgroupTileSizes);
516513
tileSizes.push_back(threadTileSizes);
517514

515+
if (vectorizable) {
516+
// Try to tile all reductions by size 4 if possible. This gives us a chance
517+
// to perform vector4 load if an input has its innermost dimension being
518+
// reduction. It also avoids generating too many instructions when
519+
// unrolling vector later.
520+
SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
521+
for (const auto &it : llvm::enumerate(linalgOp.getIteratorTypes())) {
522+
if (isReductionIterator(it.value()) && loopBounds[it.index()] % 4 == 0)
523+
reductionTileSizes[it.index()] = 4;
524+
}
525+
if (llvm::any_of(reductionTileSizes, [](int64_t s) { return s != 0; })) {
526+
tileSizes.push_back(reductionTileSizes);
527+
}
528+
}
529+
518530
return setOpConfigAndEntryPointFnTranslation(funcOp, op, tileSizes, pipeline,
519531
workgroupSize);
520532
}

compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static void populateTilingReductionPatterns(RewritePatternSet &patterns) {
5252
auto filter = linalg::LinalgTransformationFilter({marker}, llvm::None);
5353

5454
linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
55-
linalg::DepthwiseConv2DNhwcHwcOp,
55+
linalg::DepthwiseConv2DNhwcHwcOp, linalg::GenericOp,
5656
linalg::MatmulOp>::insert(patterns, tilingOptions,
5757
filter);
5858
}
@@ -283,7 +283,8 @@ class SPIRVTilePass final : public SPIRVTileBase<SPIRVTilePass> {
283283
auto marker = builder.getStringAttr(getTileReductionMarker());
284284
funcOp.walk([&](linalg::LinalgOp op) {
285285
if (isa<linalg::ContractionOpInterface>(*op) ||
286-
isa<linalg::ConvolutionOpInterface>(*op)) {
286+
isa<linalg::ConvolutionOpInterface>(*op) ||
287+
isa<linalg::GenericOp>(*op)) {
287288
op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker, marker);
288289
}
289290
});

compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ Optional<SmallVector<int64_t, 4>> getNativeVectorShape(Operation *op) {
8787
SmallVector<int64_t, 4> nativeSize(contractOp.getIteratorTypes().size(), 1);
8888
nativeSize[lastParalleldim] = 4; // Map to vec4 fma operations.
8989
return nativeSize;
90+
} else if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
91+
// Unroll all reduction dimensions by size 1 for vector.multi_reduction.
92+
auto srcVectorType = reductionOp.getSourceVectorType();
93+
auto nativeSize = llvm::to_vector<4>(srcVectorType.getShape());
94+
auto dims = reductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
95+
for (const auto &dimAttr : dims) {
96+
nativeSize[dimAttr.getZExtValue()] = 1;
97+
}
98+
return nativeSize;
99+
} else if (auto transposeOp = dyn_cast<vector::TransposeOp>(op)) {
100+
auto vectorType = transposeOp.getResultType();
101+
SmallVector<int64_t, 4> nativeSize(vectorType.getRank(), 1);
102+
nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
103+
return nativeSize;
90104
}
91105
return llvm::None;
92106
}
@@ -100,6 +114,10 @@ void populateVectorizationPatterns(RewritePatternSet &patterns) {
100114
patterns.add<linalg::LinalgVectorizationPattern>(
101115
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>(),
102116
opt);
117+
// Additionally pull in patterns to canonicalize transfer ops and to shuffle
118+
// broadcast/transpose ops around in order to cancel them or embed into
119+
// contract ops. Embedding in the flexible contract ops will help to sustain
120+
// the structure through various transformations.
103121
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
104122
vector::populateVectorReductionToContractPatterns(patterns);
105123
}
@@ -163,6 +181,34 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
163181
llvm::dbgs() << "\n\n";
164182
});
165183

184+
// Lower vector.multi_reduction early if any operand is a transpose op.
185+
// The lowering itself generates transpose ops. This helps to cancel
186+
// transpose ops. vector.multi_reduction is arguably a higher level op and
187+
// the lowering also unrolls the multi_reduction op, so it makes sense to
188+
// happen before normal unrolling.
189+
{
190+
SmallVector<Operation *> reductionOps;
191+
funcOp.walk([&](vector::MultiDimReductionOp reductionOp) {
192+
if (llvm::any_of(reductionOp->getOperands(), [](Value operand) {
193+
return operand.getDefiningOp<vector::TransposeOp>();
194+
}))
195+
reductionOps.push_back(reductionOp);
196+
return WalkResult::advance();
197+
});
198+
RewritePatternSet patterns(context);
199+
vector::populateVectorMultiReductionLoweringPatterns(
200+
patterns, vector::VectorMultiReductionLowering::InnerParallel);
201+
FrozenRewritePatternSet frozenSet(std::move(patterns));
202+
applyOpPatternsAndFold(reductionOps, frozenSet,
203+
/*strict=*/false);
204+
}
205+
206+
LLVM_DEBUG({
207+
llvm::dbgs() << "--- After lowering multi_reduction ops ---\n";
208+
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
209+
llvm::dbgs() << "\n\n";
210+
});
211+
166212
// Then unroll vectors to native vector size. We try to use 128-bit
167213
// vectors for memory access and 4/2/1 vector sizes for computation.
168214
{
@@ -182,7 +228,7 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
182228
// Next run canonicalization to cast away leading size-1 dimensions. They
183229
// can be generated from vector unrolling and generally cause issues to
184230
// cancel corresponding read/write or insert/extract op pairs. This also
185-
// need to happen befor hositing, where we would make certain vectors loop
231+
// needs to happen before hoisting, where we would make certain vectors loop
186232
// carried. Once that's done, it's hard to handle the leading size-1
187233
// dimensions across regions.
188234
{
@@ -251,33 +297,39 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
251297
llvm::dbgs() << "\n\n";
252298
});
253299

254-
// Lower vector broadcast and contraction.
300+
// Lower vector broadcast/transpose and contraction.
255301
{
256302
RewritePatternSet patterns(context);
303+
auto options = vector::VectorTransformsOptions()
304+
.setVectorTransformsOptions(
305+
vector::VectorContractLowering::OuterProduct)
306+
.setVectorTransposeLowering(
307+
vector::VectorTransposeLowering::EltWise);
257308
vector::populateVectorBroadcastLoweringPatterns(patterns);
258-
vector::populateVectorContractLoweringPatterns(
259-
patterns,
260-
vector::VectorTransformsOptions().setVectorTransformsOptions(
261-
vector::VectorContractLowering::OuterProduct));
309+
vector::populateVectorContractLoweringPatterns(patterns, options);
310+
vector::populateVectorMultiReductionLoweringPatterns(
311+
patterns, vector::VectorMultiReductionLowering::InnerParallel);
312+
vector::populateVectorTransposeLoweringPatterns(patterns, options);
262313
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
263314
return signalPassFailure();
264315
}
265316
}
266317

267318
LLVM_DEBUG({
268-
llvm::dbgs() << "--- After lowering contract ops ---\n";
319+
llvm::dbgs() << "--- After lowering various vector ops ---\n";
269320
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
270321
llvm::dbgs() << "\n\n";
271322
});
272323

273-
// Cast away leading size-1 dimensions again.
324+
// Run all sorts of canonicalization patterns to clean up again.
274325
{
275326
RewritePatternSet patterns(context);
276-
// We need to pull in casting way leading one dims to allow cancelling
277-
// some read/write ops.
278327
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
328+
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
329+
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
279330
vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
280331
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
332+
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
281333
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
282334
return signalPassFailure();
283335
}

compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ iree_lit_test_suite(
4242
"tile_and_vectorize_conv.mlir",
4343
"tile_and_vectorize_matmul.mlir",
4444
"tile_and_vectorize_to_cooperative_ops.mlir",
45+
"tile_linalg_ops.mlir",
4546
"vector_to_cooperative_matrix.mlir",
4647
"vectorize_elementwise_ops.mlir",
4748
"vectorize_matmul.mlir",
4849
"vectorize_load_store.mlir",
50+
"vectorize_reduction.mlir",
4951
"vectorize_tensor_pad.mlir",
5052
],
5153
include = ["*.mlir"],

compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ iree_lit_test_suite(
3737
"tile_and_vectorize_conv.mlir"
3838
"tile_and_vectorize_matmul.mlir"
3939
"tile_and_vectorize_to_cooperative_ops.mlir"
40+
"tile_linalg_ops.mlir"
4041
"vector_to_cooperative_matrix.mlir"
4142
"vectorize_elementwise_ops.mlir"
4243
"vectorize_load_store.mlir"
4344
"vectorize_matmul.mlir"
45+
"vectorize_reduction.mlir"
4446
"vectorize_tensor_pad.mlir"
4547
TOOLS
4648
FileCheck

compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_linalg_ops.mlir

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,115 @@ hal.executable @dwconv_elementwise {
360360
// CHECK-SAME: translation_info = #[[TRANSLATION]]
361361
// CHECK: linalg.generic
362362
// CHECK-SAME: lowering_config = #[[CONFIG]]
363+
364+
365+
// -----
366+
367+
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
368+
#hal.descriptor_set.layout<0, bindings = [
369+
#hal.descriptor_set.binding<0, storage_buffer>,
370+
#hal.descriptor_set.binding<1, storage_buffer>
371+
]>
372+
]>
373+
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
374+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
375+
376+
hal.executable @outermost_reduction {
377+
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
378+
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
379+
max_compute_shared_memory_size = 32768 : i32,
380+
max_compute_workgroup_invocations = 512 : i32,
381+
max_compute_workgroup_size = dense<512> : vector<3xi32>,
382+
subgroup_size = 32 : i32}>
383+
}> {
384+
hal.executable.entry_point @outermost_reduction layout(#executable_layout)
385+
builtin.module {
386+
func.func @outermost_reduction() {
387+
%cst = arith.constant 0.000000e+00 : f32
388+
%c0 = arith.constant 0 : index
389+
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:4x2048x512xf32>
390+
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:2048x512xf32>
391+
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 2048, 512], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x2048x512xf32> -> tensor<4x2048x512xf32>
392+
%3 = linalg.init_tensor [2048, 512] : tensor<2048x512xf32>
393+
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
394+
%5 = linalg.generic {
395+
indexing_maps = [#map0, #map1],
396+
iterator_types = ["parallel", "parallel", "reduction"]
397+
} ins(%2 : tensor<4x2048x512xf32>) outs(%4 : tensor<2048x512xf32>) {
398+
^bb0(%arg0: f32, %arg1: f32):
399+
%6 = arith.addf %arg0, %arg1 : f32
400+
linalg.yield %6 : f32
401+
} -> tensor<2048x512xf32>
402+
flow.dispatch.tensor.store %5, %1, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1] : tensor<2048x512xf32> -> !flow.dispatch.tensor<writeonly:2048x512xf32>
403+
return
404+
}
405+
}
406+
}
407+
}
408+
409+
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128], [1, 4], [0, 0, 4]{{\]}}>
410+
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVVectorize>
411+
// CHECK-LABEL: hal.executable.entry_point public @outermost_reduction
412+
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
413+
// CHECK: linalg.generic
414+
// CHECK-SAME: lowering_config = #[[$CONFIG]]
415+
416+
// -----
417+
418+
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
419+
#hal.descriptor_set.layout<0, bindings = [
420+
#hal.descriptor_set.binding<0, storage_buffer>,
421+
#hal.descriptor_set.binding<1, storage_buffer>
422+
]>
423+
]>
424+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
425+
#map1 = affine_map<(d0, d1) -> (d0)>
426+
427+
hal.executable private @innermost_reduction {
428+
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
429+
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
430+
max_compute_shared_memory_size = 32768 : i32,
431+
max_compute_workgroup_invocations = 512 : i32,
432+
max_compute_workgroup_size = dense<512> : vector<3xi32>,
433+
subgroup_size = 32 : i32}>
434+
}> {
435+
hal.executable.entry_point public @innermost_reduction ordinal(0) layout(#executable_layout)
436+
builtin.module {
437+
func.func @innermost_reduction() {
438+
%cst = arith.constant -0.000000e+00 : f32
439+
%0 = hal.interface.constant.load[0] : i32
440+
%1 = hal.interface.constant.load[1] : i32
441+
%2 = hal.interface.constant.load[2] : i32
442+
%3 = arith.index_cast %0 {stream.alignment = 512 : index, stream.values = [0 : index, 394752 : index, 984064 : index]} : i32 to index
443+
%4 = arith.index_cast %1 {stream.alignment = 512 : index, stream.values = [0 : index, 196608 : index, 197120 : index]} : i32 to index
444+
%5 = arith.index_cast %2 {stream.alignment = 512 : index, stream.values = [512 : index, 197120 : index, 197632 : index]} : i32 to index
445+
%6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%3) alignment(64) : !flow.dispatch.tensor<readonly:128x384xf32>
446+
%7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%4) alignment(64) : !flow.dispatch.tensor<readonly:128xf32>
447+
%8 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%5) alignment(64) : !flow.dispatch.tensor<writeonly:128xf32>
448+
%9 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xf32> -> tensor<128x384xf32>
449+
%10 = flow.dispatch.tensor.load %7, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:128xf32> -> tensor<128xf32>
450+
%11 = linalg.init_tensor [128] : tensor<128xf32>
451+
%12 = linalg.fill ins(%cst : f32) outs(%11 : tensor<128xf32>) -> tensor<128xf32>
452+
%13 = linalg.generic {
453+
indexing_maps = [#map0, #map1, #map1],
454+
iterator_types = ["parallel", "reduction"]
455+
} ins(%9, %10 : tensor<128x384xf32>, tensor<128xf32>) outs(%12 : tensor<128xf32>) {
456+
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
457+
%14 = arith.subf %arg0, %arg1 : f32
458+
%15 = arith.mulf %14, %14 : f32
459+
%16 = arith.addf %15, %arg2 : f32
460+
linalg.yield %16 : f32
461+
} -> tensor<128xf32>
462+
flow.dispatch.tensor.store %13, %8, offsets = [0], sizes = [128], strides = [1] : tensor<128xf32> -> !flow.dispatch.tensor<writeonly:128xf32>
463+
return
464+
}
465+
}
466+
}
467+
}
468+
469+
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[128], [4], [0, 4]{{\]}}>
470+
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVVectorize>
471+
// CHECK-LABEL: hal.executable.entry_point public @innermost_reduction
472+
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
473+
// CHECK: linalg.generic
474+
// CHECK-SAME: lowering_config = #[[$CONFIG]]

0 commit comments

Comments
 (0)