Conversation
and test for matrix add where it is inferred from writes Signed-off-by: Alex Zinenko <git@ozinenko.com>
| // TODO: pywave just ignores this not sure if we want to, including the | ||
| // case below where there may be zero constraints. Interestingly, it | ||
| // asserts if trailing dimensions are not found when computing the | ||
| // stride... |
There was a problem hiding this comment.
Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from write?
There was a problem hiding this comment.
If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.
| emitError() << "expected a single workgroup constraint for dimension " | ||
| << tensorType.getShape()[i] | ||
| << " used in the write operation without explicit " | ||
| "`elements_per_thread`"; | ||
| return failure(); |
There was a problem hiding this comment.
Ditto, but in absence of a workgroup constraint?
It feels like we need to set it to start=0, and likely size=1 and stride=1 but not sure
There was a problem hiding this comment.
I think the code does make some assumptions like that where it falls back to start = 0, size and stride of 1, but I think we shouldn't allow that and instead be more explicit.
| // TODO: in pywave, we always do `startExpr % threadsPerWave` where | ||
| // threadsPerWave == 1 for workgroup dims other than X, making it | ||
| // always zero. It mentions an assumption about the (64, 1, 1) thread | ||
| // shape, but it is unclear whether that assumption always holds. | ||
| // It looks like the intention for this was to express lane ID rather | ||
| // than thread ID, but it is unclear how it accounts for multiple | ||
| // wavefronts running in parallel. |
There was a problem hiding this comment.
The comment in the original source (
wave/wave_lang/kernel/wave/constraints.py
Lines 498 to 501 in 601ab68
There was a problem hiding this comment.
This comes up in the SIMT context (no MMA, you can also see this in the example for the atomic case). If you look at the original code, what was happening was that because we dont have an MMA, the default pattern for SIMT is a thread linear pattern and so for the atomicAdd we were getting a dependence on x and y, even though that shouldn't be the case for the example. So this was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.
There was a problem hiding this comment.
Okay, in absence of further comments (would have been appreciated), my investigation turns up the following: this is indeed laneId as I suspected and was intentionally added in 2070bcf. However, this has an implicit assumption that a WaveConstraint is present on the same dimension and contributing a component that involves wave_id, which is floordiv(threadId, threadPerWave), to the start expression. In absence of a WaveConstraint, it appears that the start expression will simply be incorrect for the multi-wave-along-X case.
There was a problem hiding this comment.
For other dimensions, I suppose the intent is that they are going to be expanded at which point they should just start with 0 and then the expansion will update them. No need to obfuscate that via modulo operations.
| SmallVector<unsigned> nonThreadLikePositions = | ||
| getPositionsOfSymbols(getNonThreadLikeIndexSymbols(ctx), allSymbols); | ||
| if (isIndexExprMapFunctionOf(difference, threadLikePositions) && | ||
| if (lhsPriority == rhsPriority && |
There was a problem hiding this comment.
the comparison above should have handled this
| // if one has higher priority than another, take the thread-dependent part of | ||
| // the higher-priority one, handle the rest; otherwise handle everything | ||
| SmallVector<AffineExpr> localReplacements = symReplacements; | ||
| for (unsigned position : threadLikePositions) | ||
| localReplacements[position] = zeroExpr; | ||
| SmallVector<Attribute> lhsSymbolsFiltered = llvm::to_vector(lhsSymbols); | ||
| if (lhsPriority > rhsPriority) { | ||
| // nullify thread-like part of RHS | ||
| rhs = rhs.replaceDimsAndSymbols(/*dimReplacements=*/{}, localReplacements, | ||
| /*numResultDims=*/0, | ||
| /*numResultSyms=*/rhs.getNumSymbols()); | ||
| } else if (rhsPriority > lhsPriority) { | ||
| // nullify thread-like part of LHS | ||
| lhs = lhs.replaceDimsAndSymbols(/*dimReplacements=*/{}, localReplacements, | ||
| /*numResultDims=*/0, | ||
| /*numResultSyms=*/lhs.getNumSymbols()); | ||
| lhsSymbolsFiltered = | ||
| llvm::filter_to_vector(lhsSymbolsFiltered, [&](Attribute symbol) { | ||
| return !llvm::is_contained( | ||
| threadLikeIndexSymbols, | ||
| llvm::cast<wave::WaveIndexSymbolAttr>(symbol).getValue()); | ||
| }); | ||
| } | ||
|
|
| pair.first, | ||
| pair.second); | ||
| })), | ||
| std::max(lhs.getPriority(), rhs.getPriority())); |
There was a problem hiding this comment.
do we need priority per-symbol?..
| << " used in the write operation without explicit " | ||
| "`elements_per_thread`"; |
There was a problem hiding this comment.
Nit: at this point, we haven't checked for elements_per_thread yet...
| // XXX: don't report this error immediately since we may be able to proceed | ||
| // without it, e.g., when index expressions can be propagate from non-write | ||
| // operations to this one. This may be a questionable design choice carried | ||
| // over from the initial Python prototype. |
There was a problem hiding this comment.
We still use these in the stride computation below... But it may be wrong
| // TODO: in pywave, we always do `startExpr % threadsPerWave` where | ||
| // threadsPerWave == 1 for workgroup dims other than X, making it | ||
| // always zero. It mentions an assumption about the (64, 1, 1) thread | ||
| // shape, but it is unclear whether that assumption always holds. | ||
| // It looks like the intention for this was to express lane ID rather | ||
| // than thread ID, but it is unclear how it accounts for multiple | ||
| // wavefronts running in parallel. |
There was a problem hiding this comment.
Okay, in absence of further comments (would have been appreciated), my investigation turns up the following: this is indeed laneId as I suspected and was intentionally added in 2070bcf. However, this has an implicit assumption that a WaveConstraint is present on the same dimension and contributing a component that involves wave_id, which is floordiv(threadId, threadPerWave), to the start expression. In absence of a WaveConstraint, it appears that the start expression will simply be incorrect for the multi-wave-along-X case.
| // TODO: in pywave, we always do `startExpr % threadsPerWave` where | ||
| // threadsPerWave == 1 for workgroup dims other than X, making it | ||
| // always zero. It mentions an assumption about the (64, 1, 1) thread | ||
| // shape, but it is unclear whether that assumption always holds. | ||
| // It looks like the intention for this was to express lane ID rather | ||
| // than thread ID, but it is unclear how it accounts for multiple | ||
| // wavefronts running in parallel. |
There was a problem hiding this comment.
For other dimensions, I suppose the intent is that they are going to be expanded at which point they should just start with 0 and then the expansion will update them. No need to obfuscate that via modulo operations.
| int64_t stride = 1; | ||
| for (int64_t j = i + 1; j < e; ++j) { | ||
| Attribute vectorShape = hardwareConstraint.getVectorShapes().get( | ||
| tensorType.getShape()[j].getName()); | ||
| if (!vectorShape) { | ||
| emitError() << "couldn't find vector shape for dimension " | ||
| << tensorType.getShape()[j]; | ||
| return failure(); | ||
| } | ||
| stride *= cast<IntegerAttr>(vectorShape).getValue().getSExtValue(); | ||
| } |
There was a problem hiding this comment.
I have a strong suspicion that usage of strides is inconsistent: here we use linear row-major stride (changed in 1ddf92d), but for MMAs these remain per-dimension strides. That being said, they don't seem to affect code generation at all, which makes me wonder why do we even use them (except for mma-style strides that may be involved in strided write splitting)....
and test for matrix add where it is inferred from writes