Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/wave/ir_design_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,38 @@ node directly. The MLIR side does not have this problem: `makeIsolated`
walks the region and discovers all outer references regardless of how they
are represented. As a result, the MLIR-imported trace may list more
captures than the source trace.


IndexMapping
------------

The `mapping` attribute on read and write operations on the Python side allows
for separate mappings for "inputs" (memory operand for reads, value operand for
writes) and "outputs" (value operand for reads, memory operand for writes). Each
of this is a dictionary from symbol names used to index the corresponding tensor
to a full-fledged sympy expression that may involve, in addition to the usual
symbols, special placeholder `iterator` symbols that refer to
positionally-indexed iterators of the notional iteration space that surrounds
the op. The order of elements in the dictionary is load-bearing, though its
exact meaning is not properly documented. It does not necessarily match the
order of symbols in the shape. None of this has verification logic and
unsupported cases just hit assertions or other exceptions inside the compilation
flow.

The simultaneous presence of both "inputs" and "outputs" mapping means that one
of them may be kept as identity, i.e., the symbols are mapped to positional
iterators where the position matches the position of the symbol in the
corresponding shape. For reads, this is the "outputs" mapping and, for writes,
this is the "inputs" mapping. There is currently no enforcement that it is
indeed the case, only a verbalized implicit assumption. This redundancy allows
one to (almost always) map every symbol to a single positional iterator. When a
more complex expression is used, additional logic attempts to extract the single
iterator that is used in it. This in turn allows to compute a permutation of
dimensions during _code generation_ of reads and writes: index expressions that
appear in a specific order are mapped to positional iterations with the same
position. Then this mapping is used to update the mapping from the memory shape
dimensions to co-indexed iterators, potentially resulting in a permuted index
expression list. For example, given a memory shape `[A, B, C, D]` and a mapping
`{A: i0, B: i3, C: i2, D: i1}` first creates a map `{i0: index[A], i1: index[B],
i2: index[C], i3: index[D]}` and then obtains the permuted index map
`{A: index[A], B: index[D], C: index[C], D: index[B]}`.
188 changes: 188 additions & 0 deletions lit_tests/kernel/wave/mlir_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,3 +1360,191 @@ def permute_kernel(

# CHECK: wave.write %[[PERMUTE]], %[[ARG1]]
# CHECK-SAME: !wave.tensor<[@N, @M] of f16, <register>>, !wave.tensor<[@N, @M] of f16, <global>>


@run_test
def mlir_converter_read_with_mapping():
"""Test MLIR converter with read operation using cyclic permutation mapping."""

constraints = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0),
tkw.WorkgroupConstraint(N, BLOCK_N, 1),
tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2)),
tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2)),
tkw.HardwareConstraint(
threads_per_wave=64,
vector_shapes={M: 64, N: 64, K: 32},
),
]

@wave.wave(constraints)
def read_with_mapping_kernel(
a: Memory[M, N, K, ADDRESS_SPACE_A, tkl.f16],
b: Memory[N, K, M, ADDRESS_SPACE_C, tkl.f16],
):
# Create a cyclic permutation mapping for read: (d0, d1, d2) -> (d1, d2, d0)
# Memory has shape [M, N, K], register will have shape [N, K, M]
# This is a non-self-inverse permutation (requires 3 applications to return to identity)
# inputs = memory dimensions, outputs = register dimensions
i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.iterator(2)
cyclic_mapping = tkw.IndexMapping(
num_iterators=3,
inputs={
M: k,
N: i,
K: j,
}, # Memory[M,N,K]: permutation maps (d0,d1,d2) -> (d1,d2,d0)
outputs={
N: i,
K: j,
M: k,
}, # Register[N,K,M]: N→iter(0), K→iter(1), M→iter(2)
)

# Read with cyclic permutation mapping
a_reg = wave.read(a, mapping=cyclic_mapping)
# Write to permuted memory
wave.write(a_reg, b)

# Set parameters for compilation
subs = {
ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE,
ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
BLOCK_M: 64,
BLOCK_N: 64,
M: 128,
N: 128,
K: 128,
}

# Compile the kernel to get the trace
options = WaveCompileOptions(
subs=subs,
compile_to_mlir=True,
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
enforce_locations=False,
)
options = set_default_run_config(options)

compiled_kernel = wave_compile(options, read_with_mapping_kernel)
trace = compiled_kernel.get_compiled_graph()
kernel_constraints = read_with_mapping_kernel.constraints

# Use the mlir_converter to emit wave MLIR dialect
mlir_output, diagnostics, _ = emit_wave_dialect(trace, kernel_constraints, options)

if diagnostics:
for diagnostic in diagnostics:
print(diagnostic, file=sys.stderr)
assert (
len(diagnostics) == 0
), "dialect emission should create valid IR, therefore diagnostics should be empty"

# Print to stdout for FileCheck
print(mlir_output)

# CHECK-LABEL: mlir_converter_read_with_mapping
# CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N, @K] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @K, @M] of f16, <global>>)

# CHECK: %[[READ:.*]] = wave.read %[[ARG0]]
# CHECK-SAME: mapping = #wave.expr_list<[](d0, d1, d2) -> (d1, d2, d0)>
# CHECK-SAME: (!wave.tensor<[@M, @N, @K] of f16, <global>>) -> !wave.tensor<[@N, @K, @M] of f16, <register>>

# CHECK: wave.write %[[READ]], %[[ARG1]]
# CHECK-SAME: !wave.tensor<[@N, @K, @M] of f16, <register>>, !wave.tensor<[@N, @K, @M] of f16, <global>>


@run_test
def mlir_converter_write_with_mapping():
"""Test MLIR converter with write operation using cyclic permutation mapping."""

constraints = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0),
tkw.WorkgroupConstraint(N, BLOCK_N, 1),
tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2)),
tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2)),
tkw.HardwareConstraint(
threads_per_wave=64,
vector_shapes={M: 64, N: 64, K: 32},
),
]

@wave.wave(constraints)
def write_with_mapping_kernel(
a: Memory[N, K, M, ADDRESS_SPACE_A, tkl.f16],
b: Memory[M, N, K, ADDRESS_SPACE_C, tkl.f16],
):
# Create a cyclic permutation mapping for write: (d0, d1, d2) -> (d1, d2, d0)
# Register has shape [N, K, M], memory has shape [M, N, K]
# This is a non-self-inverse permutation (requires 3 applications to return to identity)
# inputs = register dimensions, outputs = memory dimensions
i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.iterator(2)
cyclic_mapping = tkw.IndexMapping(
num_iterators=3,
inputs={
N: i,
K: j,
M: k,
}, # Register[N,K,M]: N→iter(0), K→iter(1), M→iter(2)
outputs={
M: k,
N: i,
K: j,
}, # Memory[M,N,K]: permutation maps (d0,d1,d2) -> (d1,d2,d0)
)

# Read from memory (no mapping)
a_reg = wave.read(a)
# Write with cyclic permutation mapping
wave.write(a_reg, b, mapping=cyclic_mapping)

# Set parameters for compilation
subs = {
ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE,
ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
BLOCK_M: 64,
BLOCK_N: 64,
M: 128,
N: 128,
K: 128,
}

# Compile the kernel to get the trace
options = WaveCompileOptions(
subs=subs,
compile_to_mlir=True,
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
enforce_locations=False,
)
options = set_default_run_config(options)

compiled_kernel = wave_compile(options, write_with_mapping_kernel)
trace = compiled_kernel.get_compiled_graph()
kernel_constraints = write_with_mapping_kernel.constraints

# Use the mlir_converter to emit wave MLIR dialect
mlir_output, diagnostics, _ = emit_wave_dialect(trace, kernel_constraints, options)

if diagnostics:
for diagnostic in diagnostics:
print(diagnostic, file=sys.stderr)
assert (
len(diagnostics) == 0
), "dialect emission should create valid IR, therefore diagnostics should be empty"

# Print to stdout for FileCheck
print(mlir_output)

# CHECK-LABEL: mlir_converter_write_with_mapping
# CHECK: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@N, @K, @M] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@M, @N, @K] of f16, <global>>)

# CHECK: %[[READ:.*]] = wave.read %[[ARG0]]
# CHECK-SAME: (!wave.tensor<[@N, @K, @M] of f16, <global>>) -> !wave.tensor<[@N, @K, @M] of f16, <register>>

# CHECK: wave.write %[[READ]], %[[ARG1]]
# CHECK-SAME: mapping = #wave.expr_list<[](d0, d1, d2) -> (d1, d2, d0)>
# CHECK-SAME: !wave.tensor<[@N, @K, @M] of f16, <register>>, !wave.tensor<[@M, @N, @K] of f16, <global>>
8 changes: 6 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@ llvm::FailureOr<mlir::ChangeResult>
propagateShapeInformation(wave::WaveTensorType from, wave::WaveTensorType &to,
llvm::StringRef fromName, llvm::StringRef toName,
llvm::raw_ostream &errs);
llvm::FailureOr<mlir::ChangeResult>
propagateShapeInformation(llvm::ArrayRef<wave::WaveSymbolAttr> from,
wave::WaveTensorType &to, llvm::StringRef fromName,
llvm::StringRef toName, llvm::raw_ostream &errs);

// Propagate shape information from `source` to `target` and drop the `n`
// `source` dims. Expects both to be fully-specified tensor types. If
// propagation discovers a type conflict, prints the error message to the
// `errs` stream and returns failure. Otherwise returns a tag indicating whether
// the target type changed.
// `errs` stream and returns failure. Otherwise returns a tag indicating
// whether the target type changed.
llvm::FailureOr<mlir::ChangeResult> propagateShapeDropTrailingDims(
wave::WaveTensorType source, wave::WaveTensorType &target,
llvm::StringRef sourceName, llvm::StringRef targetName, unsigned n,
Expand Down
24 changes: 19 additions & 5 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -499,14 +499,18 @@ def ReshapeOp : WaveOp<"reshape", [
}

def ReadOp : WaveOp<"read", [
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
let summary = "Reads from memory";
let description = [{
Moves data from a memory-resident tensor to a register-resident tensor
preserving the shape.

If mapping is provided, it indicates how dimensions of the register-resident
value tensor map to those of the memory-resident tensor. Currently, this must
be a permutation of dimensions such that the symbolic shapes of the tensors
match.
}];

let arguments = !con((ins
Expand All @@ -516,7 +520,10 @@ def ReadOp : WaveOp<"read", [
Arg<OptionalAttr<WaveSymbolMappingToNResultExprListAttr<1>>,
"Bound expressions for symbolic dimensions that need masking">:$bounds,
Arg<OptionalAttr<WaveSymbolArrayAttr>,
"Ordered dimension symbols from memory type shape">:$ordered_syms
"Ordered dimension symbols from memory type shape">:$ordered_syms,
Arg<OptionalAttr<WaveExprListAttr>,
"Indicates how value dimensions are remapped into memory dimensions">
:$mapping
), commonArguments);

let results = (outs
Expand Down Expand Up @@ -565,13 +572,17 @@ def RegisterOp : WaveOp<"register", [
def WriteOp : WaveOp<"write", [
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>,
RequiresSidewaysBackwardPropagationOpTrait]> {
let summary = "Writes into memory";
let description = [{
Moves data from a register-resident tensor into a memory-resident tensor
preserving the shape.

If mapping is provided, it indicates how dimensions of the register-resident
value tensor map to those of the memory-resident tensor. Currently, this must
be a permutation of dimensions such that the symbolic shapes of the tensors
match.
}];

let arguments = !con((ins
Expand All @@ -582,7 +593,10 @@ def WriteOp : WaveOp<"write", [
Arg<OptionalAttr<WaveSymbolMappingToNResultExprListAttr<1>>,
"Bound expressions for symbolic dimensions that need masking">:$bounds,
Arg<OptionalAttr<WaveSymbolArrayAttr>,
"Ordered dimension symbols from memory type shape">:$ordered_syms
"Ordered dimension symbols from memory type shape">:$ordered_syms,
Arg<OptionalAttr<WaveExprListAttr>,
"Indicates how value dimensions are remapped into memory dimensions">
:$mapping
), commonArguments);

let assemblyFormat =
Expand Down
Loading
Loading