[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform #2411
base: main
Conversation
Greptile Summary

This PR implements grouped NVFP4 quantization kernels with Hadamard Transform (RHT) fusion for MOE workloads. The implementation provides three kernel variants.

Key Changes:
RNG State Management:
Test Coverage:
Performance Notes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyAPI as Python API<br/>(split_quantize)
    participant CSrc as C++ Bindings<br/>(cast.cpp)
    participant Alloc as Bulk Allocator
    participant RNG as RNG State Setup
    participant Kernel as CUDA Kernels
    User->>PyAPI: split_quantize(tensor, splits, quantizers)
    PyAPI->>CSrc: Convert tensors & quantizers
    alt NVFP4 Quantizers
        CSrc->>Alloc: bulk_allocate_nvfp4_tensors()
        Alloc-->>CSrc: Contiguous FP4 data + FP8 scales + amax
        alt 128-aligned splits + RHT enabled
            CSrc->>RNG: setup RNG states (shared for row+col)
            RNG-->>CSrc: Single RNG state tensor
            CSrc->>Kernel: nvte_group_hadamard_transform_cast_fusion<br/>(fully fused row+col)
            Kernel->>Kernel: Rowwise quantization
            Kernel->>Kernel: RHT + Colwise quantization
        else Unaligned or no RHT
            alt Need separate RNG
                CSrc->>RNG: setup separate row/col RNG states
                RNG-->>CSrc: Two RNG state tensors
            end
            CSrc->>Kernel: nvte_group_nvfp4_quantize_with_amax<br/>(rowwise)
            CSrc->>Kernel: nvte_group_hadamard_transform_cast_fusion_columnwise<br/>(colwise RHT)
        end
    else Other Quantizers
        CSrc->>Kernel: multi_tensor_quantize_impl<br/>(unfused path)
    end
    Kernel-->>CSrc: Quantized tensors
    CSrc-->>PyAPI: NVFP4Tensor objects
    PyAPI-->>User: List of quantized tensors
```
16 files reviewed, 3 comments
transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu (outdated; resolved)
```cpp
if (m == 8192 && n == 5120) {
  k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
  k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
  k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
  k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
  k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
  k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
  k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
  k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
  k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
  k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
  k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
  k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
  k_tile_size = 512;
}
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: hardcoded tile sizes for specific tensor shapes
This switch statement hardcodes k_tile_size for specific (M, N) combinations. This approach:
- Doesn't scale to new shapes
- Requires manual tuning for each workload
- Falls back to generic logic that may be suboptimal
Consider implementing a heuristic-based selector that accounts for SM count, memory bandwidth, and tensor dimensions, or move this tuning to a configuration file.
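As an illustration of the heuristic-based selector suggested above, here is a minimal sketch (hypothetical helper, not code from this PR) that derives k_tile_size from the SM count and the problem size rather than from a shape table:

```cpp
#include <algorithm>
#include <cuda_runtime.h>

// Hypothetical heuristic sketch: pick k_tile_size from SM count and problem
// size instead of hardcoding per-(M, N) cases.
int pick_k_tile_size(int m, int n, int device_id) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);  // error handling elided
  const int sm_count = prop.multiProcessorCount;

  // Budget roughly 4 waves of CTAs over the SMs and derive how much of the
  // reduced dimension each CTA should own, assuming 128-wide output tiles.
  const long long total_work = static_cast<long long>(m) * n;
  const long long target_ctas = 4LL * sm_count;
  const long long k = std::max(1LL, total_work / (target_ctas * 128LL));

  // Snap to the largest tile size the kernel instantiates that does not
  // exceed the derived budget.
  const int candidates[] = {256, 512, 1024, 1280};
  int best = candidates[0];
  for (int c : candidates) {
    if (c <= k) best = c;
  }
  return best;
}
```

The cutoff constants here are placeholders; the point is that a rule of this shape extends to new (M, N) combinations without further code changes.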
transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu (outdated; resolved)
20 files reviewed, 6 comments
transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu (outdated; resolved)
(Same hardcoded k_tile_size selection as shown above.)
style: hardcoded tile sizes don't scale to new shapes
This switch statement hardcodes k_tile_size for specific (M, N) combinations. Consider implementing a heuristic-based selector that accounts for SM count, memory bandwidth, and tensor dimensions, or move this tuning to a configuration file so it can be extended without code changes.
transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu (outdated; resolved)
```cpp
// TODO(zhongbo): the math operations are very expensive
// since the kernel is persistent, we can have a cache for all the possible scaling factors
if (tensor_id != new_tensor_id) {
  global_amax_val = *global_amax_ptr;
  global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
  global_decode_scale = 1.0f / global_encode_scale;
  tCgC = thr_mma_epilogue.partition_C(gCs_mn[new_tensor_id]);
  gSFC_mn = gSFCs_mn[new_tensor_id];
  tensor_id = new_tensor_id;
}
```
style: math operations for recomputing scaling factors are expensive in persistent kernel
Consider caching pre-computed scaling factors for all possible tensor IDs at kernel start, since num_tensors is known and typically small (≤64). This would avoid recomputing ComputeGlobalEncodeScaleFP4 on every tensor boundary crossing.
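A minimal sketch of what that caching could look like (hypothetical kernel and formula, not the PR's implementation; only ComputeGlobalEncodeScaleFP4 is named after the code quoted above):

```cpp
#include <cuda_runtime.h>

// Stand-in for the helper used in the PR; the formula here is illustrative.
__device__ inline float ComputeGlobalEncodeScaleFP4(float amax) {
  return (amax > 0.f) ? (448.f * 6.f) / amax : 1.f;
}

constexpr int kMaxTensors = 64;  // assumed upper bound on tensors per group

__global__ void persistent_group_quantize_sketch(
    const float* const* global_amax_ptrs, int num_tensors) {
  // Precompute per-tensor scales once per CTA instead of on every
  // tensor-boundary crossing inside the persistent loop.
  __shared__ float s_encode_scale[kMaxTensors];
  __shared__ float s_decode_scale[kMaxTensors];

  // One-time setup; assumes blockDim.x >= num_tensors.
  if (threadIdx.x < num_tensors) {
    const float enc = ComputeGlobalEncodeScaleFP4(*global_amax_ptrs[threadIdx.x]);
    s_encode_scale[threadIdx.x] = enc;
    s_decode_scale[threadIdx.x] = 1.0f / enc;
  }
  __syncthreads();

  // Inside the persistent tile loop (elided), a tensor switch reduces to:
  //   if (tensor_id != new_tensor_id) {
  //     global_encode_scale = s_encode_scale[new_tensor_id];
  //     global_decode_scale = s_decode_scale[new_tensor_id];
  //     tensor_id = new_tensor_id;
  //   }
}
```

With this layout a boundary crossing costs two shared-memory loads instead of a global load plus a division.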
20 files reviewed, 8 comments
transformer_engine/common/hadamard_transform/group_hadamard_transform.cu (resolved)
transformer_engine/common/hadamard_transform/group_hadamard_transform.cu (outdated; resolved)
```cpp
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
    kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
// check overflow
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
```
logic: Overflow check only validates >= 0 but doesn't catch actual int32_t overflow. Consider checking against INT_MAX or using size_t arithmetic. Should this check for actual overflow against INT_MAX rather than just non-negative values?
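A minimal sketch of the explicit check (hypothetical helper; the real code would use NVTE_CHECK instead of throwing):

```cpp
#include <climits>
#include <cstdint>
#include <stdexcept>

// Accumulate the prefix sum in 64-bit and verify it still fits in int32_t
// before narrowing it back into split_sections_range.
inline int32_t add_section_checked(int32_t running_total, int32_t section) {
  const int64_t next =
      static_cast<int64_t>(running_total) + static_cast<int64_t>(section);
  if (next > INT_MAX) {  // NVTE_CHECK(next <= INT_MAX, ...) in the real code
    throw std::overflow_error("split_sections_range overflows int32_t");
  }
  return static_cast<int32_t>(next);
}

// Usage in the accumulation loop (index names follow the snippet above):
//   kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
//       add_section_checked(kernel_args.split_sections_range[kernel_args.num_tensors],
//                           split_sections[i]);
```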
Additional Comments (4)
- transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, lines 79-92 (link) — style: boundary guard prevents out-of-bounds but relies on implicit assumption
  The early return at lines 85-87 prevents out-of-bounds access when `offset >= boundary`, but this assumes the host code correctly sets `split_sections_range[num_tensors]` as a sentinel value. While this works, the implicit dependency makes the code fragile to future changes.
- transformer_engine/pytorch/csrc/extensions/cast.cpp, lines 810-811 (link) — logic: overflow check validates non-negative but doesn't prevent int32_t overflow
  `NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, ...)` only checks for negative values after overflow has occurred. With cumulative sums, overflow to negative is possible, but overflow within the positive range (wrapping past INT_MAX) won't be caught. Should this validate against INT_MAX explicitly, or is the tensor size constrained elsewhere to make this impossible?
- transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh, lines 293-296 (link) — style: tensor ID is fetched every stage even when unchanged
  `GetTensorId` is called for every stage to check whether a tensor boundary is crossed. While the check `if (new_tensor_id != tensor_id)` prevents unnecessary updates, the function call and boundary scan happen unconditionally. Could this be optimized by pre-computing which stages cross tensor boundaries, or would the complexity outweigh the benefit?
- transformer_engine/pytorch/csrc/extensions/cast.cpp, lines 813-815 (link) — style: document the performance implications of 128-alignment
  The `all_aligned_token_dim` check determines whether the fully-fused high-performance kernel is used. Users should be aware when they're not getting optimal performance due to alignment. Should there be a runtime warning or logging when falling back to the unfused path due to alignment? A minimal sketch of such a warning follows this list.
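A minimal sketch of the one-time warning mentioned in the last item above (hypothetical helper; all_aligned_token_dim is the flag named in the comment, and the logging call is illustrative rather than an existing TE API):

```cpp
#include <cstdio>

// Warn once per process when the unfused path is taken because the token
// splits are not 128-aligned.
void maybe_warn_unaligned_fallback(bool all_aligned_token_dim) {
  static bool warned = false;
  if (!all_aligned_token_dim && !warned) {
    warned = true;
    std::fprintf(stderr,
                 "[TE] split_quantize: token splits are not 128-aligned; "
                 "falling back to the unfused RHT + quantize path.\n");
  }
}
```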
17 files reviewed, 4 comments
Additional Comments (4)
- transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, lines 75-86 (link) — logic: no bounds check after loop termination
  The while loop exits when `split_sections_range[tensor_id + 1] <= offset` is false, but there's no verification that `tensor_id < num_tensors` before returning. If the host sets an incorrect `split_sections_range` or the offset is corrupted, this could return an invalid `tensor_id`. A minimal sketch of a bounded scan follows this list.
- transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu, lines 1001-1027 (link) — style: hardcoded tile size tuning doesn't scale
  The switch-case hardcodes optimal `k_tile_size` for specific tensor shapes. This requires manual updates for new workloads and doesn't handle intermediate sizes well. The fallback logic at lines 1025-1027 only handles small tensors. Consider moving the tuning parameters to a config file, or implementing auto-tuning based on SM count and tensor dimensions.
- transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh, lines 810-811 (link) — logic: overflow check doesn't validate actual overflow
  The check only ensures non-negative values, but `split_sections_range[i+1]` could legitimately overflow int32_t with large tensor sizes or many splits. Use `size_t` for accumulation or check against `INT_MAX`.
- transformer_engine/pytorch/csrc/extensions/cast.cpp, lines 712-760 (link) — style: resource management for RNG states could leak on exceptions
  The `StochasticRngStateResources` struct holds tensors and vectors but lacks explicit cleanup. If an exception occurs after partial initialization, resources may not be properly released. Consider adding RAII wrappers or using smart pointers consistently.
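A minimal sketch of the bounded scan described in the first item above (illustrative device helper; the real GetTensorId lives in the fused kernel sources):

```cpp
#include <cstdint>

// Scan split_sections_range for the tensor that owns `offset`, but bound the
// walk by num_tensors so a corrupted offset or sentinel cannot run past the
// end of the table.
__device__ inline int GetTensorIdChecked(const int32_t* split_sections_range,
                                         int num_tensors, int32_t offset,
                                         int tensor_id_hint) {
  int tensor_id = tensor_id_hint;
  while (tensor_id + 1 < num_tensors &&
         split_sections_range[tensor_id + 1] <= offset) {
    ++tensor_id;
  }
  return tensor_id;
}
```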
20 files reviewed, 4 comments
...former_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu (resolved)
Additional Comments (1)
- transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, lines 1289-1292 (link) — logic: potential memory leak if kernel launch fails
  If `NVTE_CHECK` throws/aborts at line 1290, the workspace allocated at line 1281 won't be freed. Move the free before the status check, or use an RAII wrapper; a minimal sketch follows.
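A minimal sketch of the RAII alternative (assuming the workspace is a plain cudaMalloc allocation; the wrapper type is illustrative, not an existing TE utility):

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include <memory>

// Free device memory automatically when the handle goes out of scope, even
// if a later NVTE_CHECK throws before an explicit free is reached.
struct CudaFreeDeleter {
  void operator()(void* p) const { cudaFree(p); }
};
using DeviceWorkspace = std::unique_ptr<void, CudaFreeDeleter>;

DeviceWorkspace alloc_workspace(std::size_t bytes) {
  void* ptr = nullptr;
  cudaMalloc(&ptr, bytes);  // error handling elided
  return DeviceWorkspace(ptr);
}
```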
19 files reviewed, 1 comment
/te-ci L1
Description
Status:
- `NVTE_RHT_CAST_FUSION_USE_FAST_MATH` / `NVTE_USE_FAST_MATH` env var to control this behavior.

Notes:
Why is zero padding bad?
Zero padding forces a full read and write of the padded tensor, and it often stems from the scaling-factor layout of quantized recipes like mxfp8.
How to remove the padding overhead?
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: