Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>

- Removed compile option (enabled by default)
- Added lit test
- Added new pytests for non-contiguous output & strided w/ offsets
raikonenfnu reviewed on Feb 17, 2026
@require_e2e
@require_cdna_3_or_4
@pytest.mark.xfail(reason="Dynamic strides are not supported in the ASM backend yet")
Contributor
Can we set dynamic_stride=False instead of xfailing?
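A rough sketch of what the suggestion could look like, assuming the test is parametrized over a dynamic_stride flag; the test name and signature below are placeholders, not the PR's actual code:

```python
import pytest

# require_e2e / require_cdna_3_or_4 are the test-suite decorators quoted above.
# Instead of xfailing the whole test, only run the case the ASM backend supports.
@require_e2e
@require_cdna_3_or_4
@pytest.mark.parametrize("dynamic_stride", [False])
def test_strided_read(dynamic_stride):
    ...
```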
Comment on lines +156 to +162
if self.options.wave_runtime:
    stride_arg_count = sum(
        len(b.kernel_buffer_type.symbolic_shape)
        for b in self.root_sig.sig.kernel_buffer_bindings
    )
    if stride_arg_count > 0:
        arg_types += [IndexType.get()] * stride_arg_count
Contributor
Does dynamic stride not work with the non-wave_runtime path?
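For illustration, here is what that count works out to with hypothetical buffer shapes (not taken from the PR): one extra index argument is appended per dimension of each kernel buffer.

```python
# Hypothetical symbolic shapes for two kernel buffers: (M, K) and (B, M, N).
symbolic_shapes = [("M", "K"), ("B", "M", "N")]

# Mirrors the stride_arg_count computation above: one index argument per dimension.
stride_arg_count = sum(len(shape) for shape in symbolic_shapes)
assert stride_arg_count == 5  # 2 + 3 extra index arguments on the kernel signature
```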
Comment on lines +220 to +223
if rank == 1:
    stride_vals = [arith_d.constant(IndexType.get(), 1)]
    static_strides = [1]
    layout = StridedLayoutAttr.get(offset=dyn_val, strides=[1])
Contributor
Does the latter case not work when rank==1?
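For context, a hedged sketch using the upstream MLIR Python bindings (not the PR's code) showing that a rank-1 layout with a dynamic offset and unit stride can also be expressed through the same general StridedLayoutAttr path, which is presumably what the question is getting at:

```python
# Sketch with the upstream MLIR Python bindings; wave may expose these differently.
from mlir.ir import Context, ShapedType, StridedLayoutAttr

with Context():
    dynamic = ShapedType.get_dynamic_stride_or_offset()  # the '?' marker
    # Rank-1 case: dynamic offset, unit stride -- the same attribute the
    # rank == 1 branch above builds explicitly.
    layout = StridedLayoutAttr.get(offset=dynamic, strides=[1])
    print(layout)  # e.g. strided<[1], offset: ?>
```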
Comment on lines +100 to +101
for d in range(arg_tensor.dim()):
    stride_values.append(arg_tensor.stride(d))
Contributor
Is it possible to call .stride() once instead of looping with for d in range(arg_tensor.dim())?
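For reference, torch.Tensor.stride() with no argument returns the strides of all dimensions as a tuple, so the loop can be collapsed into a single extend (variable names follow the snippet above):

```python
import torch

arg_tensor = torch.randn(4, 8)[:, :5]     # non-contiguous view
assert arg_tensor.stride() == (8, 1)      # one stride per dimension, as a tuple

stride_values = []
stride_values.extend(arg_tensor.stride())  # equivalent to the per-dimension loop
```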
Problem
Wave treats all tensor arguments as if they are contiguous. For a non-contiguous input tensor such as A = torch.randn((M, K * 4))[:, :K], the shape is (M, K) but the strides are (K * 4, 1). Since the wave runtime only passes a pointer to the linearized tensor data in memory, with no information about the strides, the input buffers are addressed incorrectly, resulting in numerical inaccuracies.
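A minimal reproduction of the mismatch, with concrete sizes chosen for illustration:

```python
import torch

M, K = 4, 8
A = torch.randn((M, K * 4))[:, :K]  # non-contiguous view into a wider buffer

print(A.shape)            # torch.Size([4, 8]) -- looks like an (M, K) tensor
print(A.stride())         # (32, 1) -- i.e. (K * 4, 1), not the contiguous (K, 1)
print(A.is_contiguous())  # False
```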
Solution

This PR adds universal support for non-contiguous input & output tensor layouts through the wave runtime. Wave automatically passes one stride per dimension per buffer through the pipeline and uses them in the kernel IR so that loads and stores use the correct layout.
Validation
Added a new test: dynamic_strides_test.py