WIP: priority-based index propagation#734

Draft
ftynse wants to merge 1 commit into main from users/ftynse/priority-index-expr

Conversation

ftynse (Contributor) commented Jan 14, 2026

and test for matrix add where it is inferred from writes

Signed-off-by: Alex Zinenko <git@ozinenko.com>
Comment on lines +1757 to +1760
// TODO: pywave just ignores this; not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
Contributor Author:

Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from writes?

Collaborator:

If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.

Comment on lines +1782 to +1786
emitError() << "expected a single workgroup constraint for dimension "
<< tensorType.getShape()[i]
<< " used in the write operation without explicit "
"`elements_per_thread`";
return failure();
Contributor Author:

Ditto, but in the absence of a workgroup constraint?

It feels like we need to set it to start=0, and likely size=1 and stride=1, but I'm not sure.

Collaborator:

I think the code does make some assumptions like that, where it falls back to start = 0 and size and stride of 1, but I don't think we should allow that; we should instead be more explicit.
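The degenerate fallback discussed here could be sketched as follows (hypothetical Python; `IndexSequence` is an illustrative name, not necessarily the project's actual type):

```python
from dataclasses import dataclass


@dataclass
class IndexSequence:
    # Degenerate fallback discussed above: contributes nothing to the
    # index (start at 0, a single element, unit stride).
    start: int = 0
    size: int = 1
    stride: int = 1


# A missing constraint would silently produce this sequence.
fallback = IndexSequence()
```

Being explicit instead would mean rejecting the input (as the `emitError` below does) rather than defaulting to this sequence.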

Comment on lines +1861 to +1867
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
Contributor Author:

The comment in the original source says something about preventing double counting of thread ID:

# We have an assumption that the thread dimensions in each wave is of shape (64,1,1).
# In cases other than dimension 0, we also calculate the modulus of thread_id with the
# number of threads in that dimension to prevent double counting of thread ID in thread
# independent index.

but I can't infer where and why it would be counted twice. The support for it was added in a commit for atomics, 4eeee9a, which doesn't provide an explanation either.
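The behavior the quoted comment describes can be sketched as follows (hypothetical Python mirroring the pywave comment; the (64, 1, 1) shape is the assumption stated there):

```python
# Assumed per-wave thread shape from the quoted pywave comment.
THREADS_PER_WAVE = (64, 1, 1)


def adjusted_start(start_expr: int, dim: int) -> int:
    # For dimensions 1 and 2, THREADS_PER_WAVE[dim] == 1, so the modulo
    # always yields 0; for dimension 0 it reduces a thread id to a lane id.
    return start_expr % THREADS_PER_WAVE[dim]
```

So for any dimension other than X the thread-id contribution is zeroed out entirely, which is what makes the "double counting" rationale hard to reconstruct.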

Collaborator:

This comes up in the SIMT context (no MMA; you can also see this in the example for the atomic case). If you look at the original code, because we don't have an MMA, the default pattern for SIMT is a thread-linear pattern, so for the atomicAdd we were getting a dependence on x and y, even though that shouldn't be the case for the example. This was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.

Contributor Author:

Okay, in the absence of further comments (which would have been appreciated), my investigation turns up the following: this is indeed laneId, as I suspected, and it was intentionally added in 2070bcf. However, this has an implicit assumption that a WaveConstraint is present on the same dimension, contributing to the start expression a component that involves wave_id, which is floordiv(threadId, threadsPerWave). In the absence of a WaveConstraint, it appears that the start expression will simply be incorrect for the multi-wave-along-X case.
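Under that reading, the decomposition would look like this (hypothetical sketch; the wavefront size of 64 along X is the assumption from the thread above, and `elems_per_wave` is an illustrative parameter):

```python
THREADS_PER_WAVE = 64  # assumed wavefront size along X


def lane_id(thread_id: int) -> int:
    # What `startExpr % threadsPerWave` computes along X.
    return thread_id % THREADS_PER_WAVE


def wave_id(thread_id: int) -> int:
    # The component a WaveConstraint would contribute.
    return thread_id // THREADS_PER_WAVE


def start_with_wave(thread_id: int, elems_per_wave: int) -> int:
    # thread_id is only fully recovered when both components enter the
    # start expression; without the wave_id term (no WaveConstraint),
    # all waves along X collapse onto the same start indices.
    return wave_id(thread_id) * elems_per_wave + lane_id(thread_id)
```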

Contributor Author:

For other dimensions, I suppose the intent is that they are going to be expanded, at which point they should just start at 0 and the expansion will update them. No need to obfuscate that via modulo operations.

SmallVector<unsigned> nonThreadLikePositions =
getPositionsOfSymbols(getNonThreadLikeIndexSymbols(ctx), allSymbols);
if (isIndexExprMapFunctionOf(difference, threadLikePositions) &&
if (lhsPriority == rhsPriority &&
Contributor Author:

The comparison above should have handled this.

Comment on lines +770 to +793
// if one has higher priority than another, take the thread-dependent part of
// the higher-priority one, handle the rest; otherwise handle everything
SmallVector<AffineExpr> localReplacements = symReplacements;
for (unsigned position : threadLikePositions)
localReplacements[position] = zeroExpr;
SmallVector<Attribute> lhsSymbolsFiltered = llvm::to_vector(lhsSymbols);
if (lhsPriority > rhsPriority) {
// nullify thread-like part of RHS
rhs = rhs.replaceDimsAndSymbols(/*dimReplacements=*/{}, localReplacements,
/*numResultDims=*/0,
/*numResultSyms=*/rhs.getNumSymbols());
} else if (rhsPriority > lhsPriority) {
// nullify thread-like part of LHS
lhs = lhs.replaceDimsAndSymbols(/*dimReplacements=*/{}, localReplacements,
/*numResultDims=*/0,
/*numResultSyms=*/lhs.getNumSymbols());
lhsSymbolsFiltered =
llvm::filter_to_vector(lhsSymbolsFiltered, [&](Attribute symbol) {
return !llvm::is_contained(
threadLikeIndexSymbols,
llvm::cast<wave::WaveIndexSymbolAttr>(symbol).getValue());
});
}

Contributor Author:

dead code

pair.first,
pair.second);
})),
std::max(lhs.getPriority(), rhs.getPriority()));
Contributor Author:

do we need priority per-symbol?..

Comment on lines +1764 to +1765
<< " used in the write operation without explicit "
"`elements_per_thread`";
Contributor Author:

Nit: at this point, we haven't checked for elements_per_thread yet...

Comment on lines +1742 to +1745
// XXX: don't report this error immediately since we may be able to proceed
// without it, e.g., when index expressions can be propagated from non-write
// operations to this one. This may be a questionable design choice carried
// over from the initial Python prototype.
Contributor Author:

We still use these in the stride computation below... But it may be wrong.

Comment on lines +1823 to +1833
int64_t stride = 1;
for (int64_t j = i + 1; j < e; ++j) {
Attribute vectorShape = hardwareConstraint.getVectorShapes().get(
tensorType.getShape()[j].getName());
if (!vectorShape) {
emitError() << "couldn't find vector shape for dimension "
<< tensorType.getShape()[j];
return failure();
}
stride *= cast<IntegerAttr>(vectorShape).getValue().getSExtValue();
}
Contributor Author:

I have a strong suspicion that the usage of strides is inconsistent: here we use a linear row-major stride (changed in 1ddf92d), but for MMAs these remain per-dimension strides. That being said, they don't seem to affect code generation at all, which makes me wonder why we even use them (except for MMA-style strides that may be involved in strided write splitting).
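For reference, the linear row-major stride computed by the excerpt above amounts to the following (hypothetical Python rendering of the same loop):

```python
def row_major_strides(vector_shapes: list[int]) -> list[int]:
    # Stride of dimension i = product of the vector shapes of all
    # trailing dimensions, matching the inner loop over j in [i+1, e).
    strides = []
    for i in range(len(vector_shapes)):
        stride = 1
        for v in vector_shapes[i + 1:]:
            stride *= v
        strides.append(stride)
    return strides
```

Per-dimension (MMA-style) strides, by contrast, would not multiply across trailing dimensions, which is the inconsistency suspected above.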
