Skip to content

Commit

Permalink
[mlir][linalg][transform][python] Extend mix-in for Vectorize
Browse files Browse the repository at this point in the history
Extends the existing mix-in for VectorizeOp with support for the missing unit attributes.

Also fixes the unintuitive implementation where
`structured.VectorizeOp(target=target, vectorize_padding=False)` still resulted in the creation of the UnitAttr `vectorize_padding`.

Reviewed By: ingomueller-net

Differential Revision: https://reviews.llvm.org/D158726
  • Loading branch information
ingomueller-net committed Aug 28, 2023
1 parent fff1830 commit a470df3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
10 changes: 7 additions & 3 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,16 +783,20 @@ def __init__(
self,
target: Union[Operation, Value],
*,
vectorize_padding: Union[bool, BoolAttr] = False,
disable_multi_reduction_to_contract_patterns: bool = False,
disable_transfer_permutation_map_lowering_patterns: bool = False,
vectorize_nd_extract: bool = False,
vectorize_padding: bool = False,
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = UnitAttr.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip,
Expand Down
40 changes: 36 additions & 4 deletions mlir/test/python/dialects/transform_structured_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,17 +560,49 @@ def testTileToForallMapping():


@run
def testVectorize():
def testVectorizeAllAttrs():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
)
with InsertionPoint(sequence.body):
structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
structured.VectorizeOp(
sequence.bodyTarget,
disable_multi_reduction_to_contract_patterns=True,
disable_transfer_permutation_map_lowering_patterns=True,
vectorize_nd_extract=True,
vectorize_padding=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: testVectorizeAllAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK-SAME: disable_multi_reduction_to_contract_patterns
# CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
# CHECK-SAME: vectorize_nd_extract
# CHECK-SAME: vectorize_padding


@run
def testVectorizeNoAttrs():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
)
with InsertionPoint(sequence.body):
structured.VectorizeOp(
sequence.bodyTarget,
disable_multi_reduction_to_contract_patterns=False,
disable_transfer_permutation_map_lowering_patterns=False,
vectorize_nd_extract=False,
vectorize_padding=False,
)
transform.YieldOp()
# CHECK-LABEL: TEST: testVectorize
# CHECK-LABEL: TEST: testVectorizeNoAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
# CHECK: {vectorize_padding}
# CHECK-NOT: disable_multi_reduction_to_contract_patterns
# CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
# CHECK-NOT: vectorize_nd_extract
# CHECK-NOT: vectorize_padding


@run
Expand Down

0 comments on commit a470df3

Please sign in to comment.