
Add support for dynamic strides #891

Draft

suryajasper wants to merge 3 commits into iree-org:main from suryajasper:dynamic-strides

Conversation


@suryajasper (Contributor) commented on Feb 17, 2026

Problem

Wave treats all tensor arguments as if they were contiguous. For a non-contiguous input such as A = torch.randn((M, K * 4))[:, :K], the shape is (M, K) but the strides are (K * 4, 1). Because the Wave runtime passes only a pointer to the linearized tensor data, with no stride information, such buffers are addressed as if they were densely packed, producing numerically incorrect results.
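
A minimal reproduction of the layout mismatch (sizes chosen arbitrarily for illustration):

    import torch

    M, K = 64, 32
    A = torch.randn((M, K * 4))[:, :K]  # a slice of a wider buffer

    print(A.shape)            # torch.Size([64, 32])
    print(A.stride())         # (128, 1) -- row stride is K * 4, not K
    print(A.is_contiguous())  # False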

Solution

This PR adds universal support for non-contiguous input and output tensor layouts through the Wave runtime. Wave automatically threads one stride per dimension per buffer through the pipeline and uses them in the kernel IR, so loads and stores address the correct layout:

  • The kernel signature gains extra index arguments, one stride per dimension of each buffer.
  • The emitter builds a memref.reinterpret_cast with those stride values and a layout such as strided<[?, 1]> (dynamic leading strides, unit stride on the last dimension to satisfy vector.load).
  • During runtime invocation, Wave reads the tensor metadata and passes the appropriate stride for each dimension; the kernarg buffer in the C++ runtime is extended with a strides section after pointers, scalars, and dynamic dims (see the sketch below).
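
A rough sketch of the runtime-side stride collection (the function name and structure are illustrative assumptions, not the actual Wave implementation):

    import torch

    def collect_stride_args(buffers: list[torch.Tensor]) -> list[int]:
        # One stride per dimension per buffer, in binding order. These values
        # are appended to the kernarg buffer after pointers, scalars, and
        # dynamic dims.
        stride_args: list[int] = []
        for t in buffers:
            stride_args.extend(t.stride())  # element strides, one per dimension
        return stride_args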

Validation

Added new test: dynamic_strides_test.py

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
- Removed compile option (enabled by default)
- Added lit test
- Added new pytests for non-contiguous output & strided w/ offsets

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>

@require_e2e
@require_cdna_3_or_4
@pytest.mark.xfail(reason="Dynamic strides are not supported in the ASM backend yet")
Can we set dynamic_stride=False instead of xfailing?

Comment on lines +156 to +162
if self.options.wave_runtime:
    stride_arg_count = sum(
        len(b.kernel_buffer_type.symbolic_shape)
        for b in self.root_sig.sig.kernel_buffer_bindings
    )
    if stride_arg_count > 0:
        arg_types += [IndexType.get()] * stride_arg_count
does dynamic stride not work with non wave_runtime?

Comment on lines +220 to +223
if rank == 1:
    stride_vals = [arith_d.constant(IndexType.get(), 1)]
    static_strides = [1]
    layout = StridedLayoutAttr.get(offset=dyn_val, strides=[1])
Does the latter case not work when rank==1?

Comment on lines +100 to +101
for d in range(arg_tensor.dim()):
    stride_values.append(arg_tensor.stride(d))
is it possible to do .stride() instead of doing for d in range(arg_tensor.dim())?
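
For reference, a one-line sketch of that simplification (assuming arg_tensor is a torch.Tensor, whose .stride() with no argument returns all strides at once as a tuple):

    stride_values.extend(arg_tensor.stride())  # one element stride per dimension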
