Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion wave_lang/kernel/compiler/wave_codegen/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
from ...wave.compile_options import WaveCompileOptions
from ...wave.constraints import Constraint, HardwareConstraint, TilingConstraint
from ...wave.utils.general_utils import get_hardware_constraint
from ...wave.utils.symbol_utils import subs_idxc, is_literal
from ...wave.utils.symbol_utils import (
decompose_affine_by_uniformity,
subs_idxc,
is_literal,
)

logger = get_logger("wave.ops_location_check")

Expand Down Expand Up @@ -598,6 +602,103 @@ def add_emitter_subs(
return dynamics


def get_uniformity_classes(
    emitter: WaveEmitter,
) -> list[set]:
    """Collect symbol sets ordered from most-uniform to least-uniform.

    The result is suitable for :func:`decompose_affine_by_uniformity`
    or :func:`gen_sympy_index_decomposed`.

    At most two classes are produced:
    * ``{WORKGROUP_0, WORKGROUP_1, WORKGROUP_2}`` — always first;
    * the emitter's induction-variable symbols, when any exist.

    Thread-ID symbols and dynamic values form the implicit remainder
    (most divergent) class and are never listed explicitly.
    """
    workgroup_syms = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2}
    # get_induction_vars_and_syms() returns (vars, syms); only the
    # symbols are relevant for uniformity classification.
    induction_syms = set(emitter.get_induction_vars_and_syms()[1])
    if induction_syms:
        return [workgroup_syms, induction_syms]
    return [workgroup_syms]


def gen_sympy_index_decomposed(
    emitter: WaveEmitter,
    expr: "sympy.Expr",
    dynamic_values: dict | None = None,
    uniform_sym_classes: list[set] | None = None,
) -> tuple[Value, list[Value]]:
    """Lower a sympy expression to MLIR with automatic uniformity decomposition.

    Decomposes *expr* into additive components by uniformity class, emits
    each component via :func:`gen_sympy_index`, and combines them so that
    uniform (SGPR-eligible) contributions are separate ``arith.addi`` ops.
    This lets the AMDGPU backend keep uniform parts in SGPRs and
    potentially fold them into hardware instruction fields (e.g. soffset).

    Args:
        emitter: The current wave emitter (provides symbol bindings).
        expr: Sympy expression to lower.
        dynamic_values: Extra symbol-to-Value mappings. ``None`` (the
            default) is treated as an empty mapping; a ``None`` sentinel
            is used instead of a mutable ``{}`` default.
        uniform_sym_classes: Override for uniformity classes. When
            falsy (``None`` or empty), uses :func:`get_uniformity_classes`.

    Returns:
        ``(combined, components)`` where *combined* is the final MLIR
        Value (sum of all components) and *components* is a list of
        per-class Values (one per class + remainder).
    """
    subs = add_emitter_subs(
        emitter, dynamic_values if dynamic_values is not None else {}
    )
    classes = uniform_sym_classes or get_uniformity_classes(emitter)
    parts = decompose_affine_by_uniformity(expr, classes)

    component_values = [gen_sympy_index(subs, p) for p in parts]

    # Combine: sum uniform components first (SGPR + SGPR stays SGPR),
    # then add the divergent remainder last (VGPR + SGPR).
    overflow_flags = (
        arith_d.IntegerOverflowFlags.nsw | arith_d.IntegerOverflowFlags.nuw
    )
    # All components except the last correspond to the uniform classes;
    # the last is always the divergent remainder.
    uniform_sum = None
    for cv in component_values[:-1]:
        if _is_zero(cv):
            # Skip constant-zero contributions so no dead adds are emitted.
            continue
        if uniform_sum is None:
            uniform_sum = cv
        else:
            uniform_sum = arith_d.addi(uniform_sum, cv, overflow_flags=overflow_flags)

    remainder = component_values[-1]

    if uniform_sum is None:
        combined = remainder
    elif _is_zero(remainder):
        combined = uniform_sum
    else:
        combined = arith_d.addi(remainder, uniform_sum, overflow_flags=overflow_flags)

    return combined, component_values


def _is_zero(val: Value) -> bool:
    """Check whether *val* was produced by an ``arith.constant`` of integer 0."""
    owner = getattr(val, "owner", None)
    defining_op = getattr(owner, "opview", None)
    if not isinstance(defining_op, arith_d.ConstantOp):
        return False
    attr = defining_op.attributes["value"]
    return isinstance(attr, IntegerAttr) and int(attr) == 0


# Module-level feature toggles, parsed once at import time.  The environment
# values must be integer strings (e.g. "0"/"1"); anything else raises
# ValueError in int().
_emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0)))
_use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1)))

Expand Down
175 changes: 140 additions & 35 deletions wave_lang/kernel/compiler/wave_codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index
from ...wave.utils.mapping_utils import transform_index_on_mapping
from ...wave.utils.symbol_utils import safe_subs
from ...wave.utils.symbol_utils import decompose_affine_by_uniformity, safe_subs
from .emitter import (
WaveEmitter,
add_emitter_subs,
Expand All @@ -71,8 +71,10 @@
cast_py_value,
cast_vector,
gen_sympy_index,
gen_sympy_index_decomposed,
get_constant_attr,
get_type_or_element_type,
get_uniformity_classes,
handle_op,
)

Expand Down Expand Up @@ -105,26 +107,17 @@ def _simplify(expr):
return sympy.simplify(expr)


def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
"""
Split index expr into thread-dependent and thread-independent parts
"""
subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0}
# Replace all wg symbols with 0s to get thread-dependent index.
# All dynamic values will also be part of thread-index.
thread_dependent_index = safe_subs(src, subs_wg)
_WG_SYMS = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2}

# Compute thread-independent index as `orig_index - thread_dependent_index`
# All thread symbols and dynamic should cancel-out in the result.
thread_independent_index = _simplify(src - thread_dependent_index)
if thread_independent_index.free_symbols - set(subs_wg.keys()):
# If we have any symbols besides wg symbols, means some thread or
# dynamic symbols were not canceled out, use the entire index as
# thread dependent index.
thread_independent_index = sympy.sympify(0)
thread_dependent_index = src

return thread_independent_index, thread_dependent_index
def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
    """Split *src* into (workgroup-uniform, thread-dependent) parts.

    Delegates to :func:`decompose_affine_by_uniformity` with the
    workgroup symbols as the single uniformity class; the remainder
    part is the thread-dependent index.
    """
    wg_part, thread_part = decompose_affine_by_uniformity(src, [_WG_SYMS])[:2]
    return wg_part, thread_part


def _extract0(src):
def _build_start_indices(
    emitter: WaveEmitter,
    src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
    dynamic_values: dict[IndexExpr, Any] = {},
    uniform_sym_classes: list[set] | None = None,
) -> tuple:
    """Build MLIR index values with N-way uniformity decomposition.

    When *uniform_sym_classes* is ``None`` or empty (default), performs
    the legacy two-way split (workgroup / thread) and returns a 3-tuple::

        (full_indices, wg_indices, thread_indices)

    When *uniform_sym_classes* is a non-empty list of ``n`` symbol sets
    (ordered most-uniform first, e.g. ``[induction_var_syms]``), prepends
    the workgroup class automatically and returns an ``(n + 3)``-tuple::

        (full_indices, wg_indices, class_0_indices, ..., thread_indices)
    """
    start_indices = _get_start_indices(src_indices)
    subs = add_emitter_subs(emitter, dynamic_values)
    indices = [gen_sympy_index(subs, i) for i in start_indices]

    # The workgroup class is always the most-uniform class; caller-supplied
    # classes follow, and the final part is the (most divergent) remainder.
    classes = [_WG_SYMS] + (uniform_sym_classes or [])
    decomposed = [
        decompose_affine_by_uniformity(i, classes) for i in start_indices
    ]

    n_parts = len(classes) + 1  # one per class + remainder
    parts: list[list[OpResult]] = []
    for k in range(n_parts):
        parts.append([gen_sympy_index(subs, d[k]) for d in decomposed])

    if not uniform_sym_classes:
        return indices, parts[0], parts[1]

    return (indices,) + tuple(parts)


def _compute_linear_offset(
    indices: list[Value | int],
    strides: list[Value],
) -> Value | None:
    """Fold per-dimension index values and strides into one scalar offset.

    Computes ``sum(index_i * stride_i)``; integer indices are first
    materialized as ``arith.constant`` index values.  Returns ``None``
    when *indices* is empty.
    """
    flags = arith_d.IntegerOverflowFlags.nsw
    total = None
    for dim_index, dim_stride in zip(indices, strides):
        if isinstance(dim_index, int):
            dim_index = arith_d.constant(IndexType.get(), dim_index)
        term = arith_d.muli(dim_index, dim_stride, overflow_flags=flags)
        total = term if total is None else arith_d.addi(
            total, term, overflow_flags=flags
        )
    return total


def _apply_uniform_offsets(
    offset_th: Value,
    uniform_parts: list[list[Value]],
    strides: list[Value],
) -> Value:
    """Add uniform (SGPR-eligible) contributions to the thread offset.

    Each entry of *uniform_parts* holds the per-dimension index values of
    one uniformity class.  Every class is linearized against *strides*
    and added to *offset_th* as its own ``arith.addi`` so the backend can
    keep it in SGPRs (e.g. fold into a ``buffer_load`` soffset).
    """
    flags = arith_d.IntegerOverflowFlags.nsw
    result = offset_th
    for class_indices in uniform_parts:
        class_offset = _compute_linear_offset(class_indices, strides)
        if class_offset is None:
            continue
        # Skip provably-zero contributions so no dead adds are emitted.
        if _get_constant_value(class_offset) == 0:
            continue
        result = arith_d.addi(result, class_offset, overflow_flags=flags)
    return result


def _get_symbolic_shape(node: fx.Node) -> tuple[IndexExpr]:
Expand Down Expand Up @@ -469,6 +529,7 @@ def _create_vec_read_write(
memory: CustomOp,
mask: Optional[Value],
node_index: Optional[IndexSequence] = None,
uniform_parts: list[list[Value]] | None = None,
) -> Optional[Value]:
is_read = value is None
uint32 = IntegerType.get_signless(32)
Expand Down Expand Up @@ -512,6 +573,8 @@ def extract(vec, ind):
mem, start_indices_wg, start_indices_th, strides
)
mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter)
if uniform_parts:
offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides)
if linearize_shared_mem:
mem = _linearize_shared_mem(mem)
linearized_index = {
Expand Down Expand Up @@ -548,6 +611,8 @@ def extract(vec, ind):
mem, start_indices_wg, start_indices_th, strides
)
mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter)
if uniform_parts:
offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides)

indices = [offset_th] if buffer_ops_enabled else start_indices

Expand Down Expand Up @@ -711,9 +776,21 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
else:
mask = _build_mask(emitter, index, elements_per_thread, bounds)

start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index, dynamic_vals_map_start
)
induction_vars = set(emitter.get_induction_vars_and_syms()[1])
uniform_parts: list[list[Value]] = []
if induction_vars:
start_indices, start_indices_wg, *uniform_parts, start_indices_th = (
_build_start_indices(
emitter,
index,
dynamic_vals_map_start,
uniform_sym_classes=[induction_vars],
)
)
else:
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index, dynamic_vals_map_start
)

use_llvm_load = flags != MemoryAccessFlags.NONE
if use_llvm_load:
Expand All @@ -738,6 +815,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
get_custom(memory),
mask,
node_index=index,
uniform_parts=uniform_parts or None,
)

emitter.bind_node_proxy(node, IRProxyValue(result))
Expand Down Expand Up @@ -802,9 +880,21 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
else:
mask = _build_mask(emitter, index, elements_per_thread, bounds)

start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index, dynamic_vals_map_start
)
induction_vars = set(emitter.get_induction_vars_and_syms()[1])
uniform_parts: list[list[Value]] = []
if induction_vars:
start_indices, start_indices_wg, *uniform_parts, start_indices_th = (
_build_start_indices(
emitter,
index,
dynamic_vals_map_start,
uniform_sym_classes=[induction_vars],
)
)
else:
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index, dynamic_vals_map_start
)

use_llvm_store = flags != MemoryAccessFlags.NONE
if use_llvm_store:
Expand All @@ -825,6 +915,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
get_custom(memory),
mask,
node_index=index,
uniform_parts=uniform_parts or None,
)


Expand Down Expand Up @@ -1070,13 +1161,24 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):

store_type = VectorType.get((elements_per_thread,), element_type)

src_index, src_index_wg, src_index_th = _build_start_indices(
emitter, new_src_idx, src_dynamic_vals_map_start
)
induction_vars = set(emitter.get_induction_vars_and_syms()[1])

ip = InsertionPoint.current
uniform_parts: list[list[Value]] = []
if induction_vars:
src_index, src_index_wg, *uniform_parts, src_index_th = (
_build_start_indices(
emitter,
new_src_idx,
src_dynamic_vals_map_start,
uniform_sym_classes=[induction_vars],
)
)
else:
src_index, src_index_wg, src_index_th = _build_start_indices(
emitter, new_src_idx, src_dynamic_vals_map_start
)

induction_vars = set(emitter.get_induction_vars_and_syms()[1])
ip = InsertionPoint.current

# Hoist to the function level, if not using induction variables.
if not any(
Expand Down Expand Up @@ -1105,6 +1207,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides)
src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter)

if uniform_parts:
offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides)

# We previously checked mask is same for all elements, so we can use
# elements_per_thread=1 to build the mask.
mask = _build_mask(
Expand Down
Loading
Loading