Conversation

Review updated until commit 14b113a

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Parameter Validation

!test
Greptile Summary

This PR implements a TMA (Tensor Memory Accelerator) based transpose auto-scheduler for NVIDIA GPUs.

Previous review threads have identified several edge case handling issues in the heuristics calculation that should be addressed.

Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
    B -->|Yes| C[transpose::tma::getTransposeHeuristics]
    B -->|No| D[transpose::non_tma::getTransposeHeuristics]
    C --> E{tparams != nullptr?}
    E -->|No| D
    E -->|Yes| F[Return TMA params]
    D --> F
    F --> G[TransposeScheduler::schedule]
    G --> H[transpose::tma::scheduleTranspose]
    C --> I[Calculate max_input_dtype_size]
    I --> J[Compute tile_size2 = 128 / dtype_size]
    J --> K[Compute tile_size1 based on n_input]
    K --> L[Compute chunks_per_thread = tile_size1 * 8 / target_bdimx]
    L --> M[Compute elements_per_chunk = 16 / dtype_size]
    H --> N[Cache inputs/outputs]
    N --> O[Setup TMA load ops]
    O --> P[TMA tiling with tile_size1/tile_size2]
    P --> Q[Apply XOR swizzle to input smem]
    Q --> R[Split by elements_per_chunk]
    R --> S[Split by chunks_per_thread]
    S --> T[Parallelize with TIDx and Unroll]
```

Last reviewed commit: 14b113a
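The heuristic arithmetic in the flowchart can be condensed into a standalone sketch. This is a reconstruction from the review, not the PR's actual code: the exact derivation of `tile_size1` from `n_input` is not spelled out above, so it is taken as an input, and `TmaTransposeParams`, `computeParams`, and `kSwizzleRowBytes` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical reconstruction of the heuristic steps in the flowchart.
struct TmaTransposeParams {
  int64_t tile_size1 = 0;
  int64_t tile_size2 = 0;
  int64_t chunks_per_thread = 0;
  int64_t elements_per_chunk = 0;
};

constexpr int64_t kSwizzleRowBytes = 128; // 128-byte swizzled TMA row
constexpr int64_t kBytesPerChunk = 16;    // each of the 8 chunks is 16 bytes

TmaTransposeParams computeParams(
    int64_t max_input_dtype_size, int64_t n_input, int64_t tile_size1) {
  TmaTransposeParams p;
  p.tile_size1 = tile_size1;
  // tile_size2 = 128 / dtype_size
  p.tile_size2 = kSwizzleRowBytes / max_input_dtype_size;
  // single-input fusions target a wider block
  const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
  // chunks_per_thread = tile_size1 * 8 / target_bdimx
  p.chunks_per_thread = p.tile_size1 * 8 / target_bdimx;
  // elements_per_chunk = 16 / dtype_size
  p.elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
  return p;
}
```

Every division here is integer division, which is exactly where the zero-valued edge cases flagged in the threads below come from.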
Additional Comments (2)

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

!test
```cpp
// the inner most dim, range is [1, 8]
// bdimx = tile_size1 * 8 / chunks_per_thread
const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
```

Potential division by zero if target_bdimx is zero (though this cannot happen given the current constants, 256 and 128).
Additional Comments (2)
…user into llu/transpose_tma_auto2
```cpp
// the inner most dim, range is [1, 8]
// bdimx = tile_size1 * 8 / chunks_per_thread
const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
```

If tile_size1 is 0 (which could happen if max_input_dtype_size > 128), chunks_per_thread will be 0, causing issues with the split at line 217.

```suggestion
NVF_CHECK(tparams->tile_size1 > 0, "tile_size1 must be positive");
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
```
```cpp
if (max_output_dims == 0 && max_input_dims == 0) {
  return;
}
```

Consider adding validation or logging when both max_output_dims and max_input_dims are 0 before the early return.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
!test

!test

1 similar comment

!test
```diff
 //! These can be set through the `NVFUSER_DUMP` environment variable
 //!
-enum class DebugDumpOption {
+enum class DebugDumpOption : std::uint8_t {
```
```diff
     std::back_inserter(option_values),
     [](const auto& kv) { return kv.first; });
-std::sort(option_values.begin(), option_values.end());
+std::ranges::sort(option_values);
```
```diff
-  options_ = other.options_;
-}
+Options(const Options& other)
+    : options_([&other]() {
```
```cpp
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
```

If chunks_per_thread is 0 (which can happen when tile_size1 * 8 < target_bdimx, e.g., when the data type is large), the split at line 218 will fail.

```suggestion
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
if (tparams->chunks_per_thread == 0) {
  return nullptr; // Fall back to non-TMA scheduler
}
```
```cpp
// bdimx = tile_size1 * 8 / chunks_per_thread
const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
```

If elements_per_chunk is 0 (which can happen when max_input_dtype_size > 16, e.g., for complex or very large data types), the split at line 216 will fail.

```suggestion
tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
if (tparams->elements_per_chunk == 0) {
  return nullptr; // Fall back to non-TMA scheduler
}
```
```cpp
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
  max_input_dtype_size = std::max(
      max_input_dtype_size,
      dataTypeSizeByte(valueOrError(inp->getDataType())));
  n_input++;
}
```

If no inputs exist (n_input == 0), the heuristics are still computed and returned, potentially leading to incorrect scheduling.

```suggestion
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
  max_input_dtype_size = std::max(
      max_input_dtype_size,
      dataTypeSizeByte(valueOrError(inp->getDataType())));
  n_input++;
}
if (n_input == 0) {
  return nullptr; // No inputs to transpose
}
```
```cpp
// [Tunable] In 128-bytes swizzled tma load, inner most dim is split into 8
// chunks each with 16 bytes. Each thread may handle multiple chunks along
// the inner most dim, range is [1, 8]
// bdimx = tile_size1 * 8 / chunks_per_thread
```

The comment formula is backwards: it should be chunks_per_thread = tile_size1 * 8 / bdimx, not bdimx = tile_size1 * 8 / chunks_per_thread, since the code derives chunks_per_thread from a fixed target_bdimx.

```suggestion
// chunks_per_thread = tile_size1 * 8 / target_bdimx
```
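Strictly speaking, the two formulas are rearrangements of the same identity, bdimx * chunks_per_thread == tile_size1 * 8, and agree whenever the division is exact; the objection is about which quantity the code actually derives. A quick check with illustrative values (both helper names are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// When the division is exact, deriving either quantity from the other
// round-trips through the identity bdimx * chunks == tile_size1 * 8.
int64_t chunksFromBdimx(int64_t tile_size1, int64_t bdimx) {
  return tile_size1 * 8 / bdimx;
}
int64_t bdimxFromChunks(int64_t tile_size1, int64_t chunks) {
  return tile_size1 * 8 / chunks;
}
```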
The heuristics are a basic first version; current performance numbers are in this doc.