
Conversation

Collaborator

@zhongbozhu zhongbozhu commented Nov 21, 2025

Description

  • This PR depends on the merge of this PR: [PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel #2351, which adds a grouped kernel for generating NVFP4 global amax values.
  • This PR builds grouped-quantize kernels. Two grouped quantize kernels are needed for NVFP4: the regular quantize without the Hadamard Transform, and a Hadamard Transform + FP4 cast fusion. This PR prioritizes the Hadamard variant; the other may be supported here or in a future PR.

Status:

  • Grouped Rowwise-only quantize without Hadamard (@Oleg-Goncharov to add a transpose mode to it in case RHT gets disabled)
  • Grouped Columnwise RHT-cast fusion: numerically correct
  • Grouped Columnwise RHT-cast fusion: optimize the kernel for better memory SOL (the current memory SOL is low). The bottleneck has been root-caused to running high-precision math in the epilogue of the RHT GEMM. Fast math boosts performance but has a minor numerical impact, because x / y != x * (1/y) even for FP32, while divides are slower. Use the NVTE_USE_FAST_MATH env var (formerly NVTE_RHT_CAST_FUSION_USE_FAST_MATH) to control this behavior; see the fast-math sketch after this list.
  • Further fuse the rowwise quantize, columnwise RHT transform, and columnwise cast-transpose into one kernel. This new kernel targets SOL performance and additionally assumes 128-multiple padding in the token dimension. The current NVFP4 padding is a 64 multiple, so we may have to increase it to 128 and then remove the padding overhead by other means to reach overall SOL performance; the padding overhead can be eliminated with the tricks in the next section. A Megatron PR is also needed to bump up this padding.
  • Make row-quant and col-quant use two different random numbers, as PR 2487 already does in the unfused path; @negvet to review.
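
A minimal standalone sketch of the fast-math trade-off mentioned above (illustrative only; not the kernel's actual epilogue code):

#include <cstdio>

// Illustrative only: with fast math, x / y is replaced by x * (1.0f / y).
// The reciprocal is rounded once before the multiply, so the result can differ
// from the correctly rounded quotient by an ULP or so.
__global__ void div_vs_recip(float x, float y, float *out) {
  out[0] = x / y;           // precise path: IEEE division
  out[1] = x * (1.0f / y);  // fast path: reciprocal, then multiply
}

int main() {
  float *out = nullptr;
  cudaMallocManaged(&out, 2 * sizeof(float));
  div_vs_recip<<<1, 1>>>(7.0f, 3.0f, out);
  cudaDeviceSynchronize();
  std::printf("div = %.9g, recip-mul = %.9g\n", out[0], out[1]);
  cudaFree(out);
  return 0;
}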

Notes:

  • Why is zero padding bad?
    Zero padding requires a full read and write of the tensor in memory, and the padding requirement often stems from the scaling-factor layout of quantization recipes like MXFP8.

  • How to remove the padding overhead?

  1. Router padding: have the router directly output a multiple-of-x number of tokens per expert, so that zero padding becomes a no-op (see the sketch after this list). This method has been integrated into Megatron-core here.
  2. Router padding to a 128 multiple sounds scary and raises concerns about whether it incurs numerical issues. A better solution is to fuse the zero padding into the token-permute kernel. This is still ongoing work, but Megatron-core with the HybridEP token dispatcher backend already supports padding to a multiple here.
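
For reference, a minimal sketch of the per-expert padding arithmetic (an illustrative helper only; the real integration lives in the Megatron-core router / token-permute path):

#include <cstdint>
#include <vector>

// Round each expert's token count up to a multiple of `alignment` (128 for the
// fully-fused kernel) so the grouped NVFP4 kernels see aligned splits and no
// separate zero-padding pass over the activations is needed.
std::vector<int64_t> pad_splits(const std::vector<int64_t> &tokens_per_expert,
                                int64_t alignment = 128) {
  std::vector<int64_t> padded;
  padded.reserve(tokens_per_expert.size());
  for (const int64_t t : tokens_per_expert) {
    padded.push_back((t + alignment - 1) / alignment * alignment);
  }
  return padded;
}

// Example: {300, 0, 129} -> {384, 0, 256} with alignment = 128.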

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zhongbozhu zhongbozhu self-assigned this Nov 21, 2025
Contributor

greptile-apps bot commented Nov 21, 2025

Greptile Summary

This PR implements grouped NVFP4 quantization kernels with Hadamard Transform (RHT) fusion for MOE workloads. The implementation provides three kernel variants:

Key Changes:

  • Fully-fused kernel (group_row_cast_col_hadamard_transform_cast_fusion.cu): Combines rowwise quantization, columnwise RHT, and columnwise quantization into a single persistent kernel. Requires 128-multiple padding in token dimension for optimal memory SOL performance.
  • Columnwise-only RHT fusion (group_hadamard_transform_cast_fusion.cu): Handles RHT + columnwise quantization with configurable fast math mode via NVTE_USE_FAST_MATH env var. Fast math trades minor numerical precision for better performance.
  • Rowwise-only quantization (group_quantize_transpose_nvfp4.cuh): Grouped rowwise quantization without RHT, supporting multiple splits with proper tensor ID lookup.

RNG State Management:
The PR correctly addresses the critical issue of using different random seeds for rowwise vs columnwise quantization. When the fully-fused kernel cannot be used (non-128-aligned splits), separate RNG states are generated to ensure rowwise and colwise use independent random numbers.
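
For intuition, a curand-based sketch of independent streams (illustrative only; the actual implementation sets up its RNG state tensors on the C++ side, as shown in the sequence diagram below):

#include <curand_kernel.h>

// Two Philox states with the same seed but different subsequence ids yield
// independent random streams, so rowwise and columnwise stochastic rounding
// never consume the same random numbers.
__global__ void init_rng_states(unsigned long long seed,
                                curandStatePhilox4_32_10_t *row_state,
                                curandStatePhilox4_32_10_t *col_state) {
  curand_init(seed, /*subsequence=*/0, /*offset=*/0, row_state);
  curand_init(seed, /*subsequence=*/1, /*offset=*/0, col_state);
}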

Test Coverage:
Comprehensive tests validate correctness against reference implementation across multiple edge cases: zero tokens at various positions, uneven splits, RHT enabled/disabled, and random sign mask configurations.

Performance Notes:

  • Hardcoded tile size selection based on specific (M, N) shapes in group_hadamard_transform_cast_fusion.cu:938-964
  • Pre-RHT amax path disabled pending verification (line 866 in cast.cpp)
  • 128-multiple padding requirement for full fusion enables better memory SOL but may add overhead that needs mitigation via router padding or fused token permute

Confidence Score: 4/5

  • PR is safe to merge with minor considerations - well-tested core functionality with clear performance/accuracy tradeoffs documented
  • Complex GPU kernel implementation with proper synchronization, memory management, and numerical validation. RNG state handling correctly addresses randomness independence. Comprehensive test coverage validates correctness. Score reduced from 5 due to: (1) hardcoded tile size tuning that doesn't scale, (2) disabled pre-RHT amax path pending verification, (3) fast math mode's numerical impact needs monitoring in production
  • Pay close attention to group_hadamard_transform_cast_fusion.cu if adding new tensor shapes (requires manual tile size tuning). Monitor numerical stability when NVTE_USE_FAST_MATH=1

Important Files Changed

  • transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu: Fully-fused grouped kernel for rowwise + columnwise RHT quantization with a persistent kernel scheduler. Complex but well-structured implementation with proper RNG handling.
  • transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu: Grouped columnwise RHT-cast fusion kernel using persistent tile processing. Implements a fast-math option for performance at slight numerical cost.
  • transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh: Grouped rowwise quantization kernel with TMA operations. Handles multi-tensor args and separate RNG states correctly.
  • transformer_engine/pytorch/csrc/extensions/cast.cpp: Python bindings for split quantization with proper RNG state management. Correctly routes to fused vs unfused paths based on alignment and quantizer type.
  • tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: Comprehensive test coverage for grouped quantization including edge cases (zero tokens, uneven splits, RHT on/off, random sign masks).

Sequence Diagram

sequenceDiagram
    participant User
    participant PyAPI as Python API<br/>(split_quantize)
    participant CSrc as C++ Bindings<br/>(cast.cpp)
    participant Alloc as Bulk Allocator
    participant RNG as RNG State Setup
    participant Kernel as CUDA Kernels
    
    User->>PyAPI: split_quantize(tensor, splits, quantizers)
    PyAPI->>CSrc: Convert tensors & quantizers
    
    alt NVFP4 Quantizers
        CSrc->>Alloc: bulk_allocate_nvfp4_tensors()
        Alloc-->>CSrc: Contiguous FP4 data + FP8 scales + amax
        
        alt 128-aligned splits + RHT enabled
            CSrc->>RNG: setup RNG states (shared for row+col)
            RNG-->>CSrc: Single RNG state tensor
            CSrc->>Kernel: nvte_group_hadamard_transform_cast_fusion<br/>(fully fused row+col)
            Kernel->>Kernel: Rowwise quantization
            Kernel->>Kernel: RHT + Colwise quantization
        else Unaligned or no RHT
            alt Need separate RNG
                CSrc->>RNG: setup separate row/col RNG states
                RNG-->>CSrc: Two RNG state tensors
            end
            CSrc->>Kernel: nvte_group_nvfp4_quantize_with_amax<br/>(rowwise)
            CSrc->>Kernel: nvte_group_hadamard_transform_cast_fusion_columnwise<br/>(colwise RHT)
        end
    else Other Quantizers
        CSrc->>Kernel: multi_tensor_quantize_impl<br/>(unfused path)
    end
    
    Kernel-->>CSrc: Quantized tensors
    CSrc-->>PyAPI: NVFP4Tensor objects
    PyAPI-->>User: List of quantized tensors


@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 68564c7 to 39a8251 Compare November 22, 2025 01:22

@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 39a8251 to 999a9be Compare November 22, 2025 01:33

@ptrendx ptrendx added the MoE label Nov 22, 2025
@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 999a9be to ff268af Compare November 22, 2025 01:57
@greptile-apps greptile-apps bot left a comment

16 files reviewed, 3 comments

Comment on lines +979 to +1028
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}

style: hardcoded tile sizes for specific tensor shapes

This switch statement hardcodes k_tile_size for specific (M, N) combinations. This approach:

  • Doesn't scale to new shapes
  • Requires manual tuning for each workload
  • Falls back to generic logic that may be suboptimal

Consider implementing a heuristic-based selector that accounts for SM count, memory bandwidth, and tensor dimensions, or move this tuning to a configuration file.
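
A rough sketch of what such a heuristic could look like (a hypothetical choose_k_tile_size helper; thresholds are illustrative and would need benchmarking against the hand-tuned table above):

// Hypothetical sketch, not the kernel's actual logic: size k-tiles so the
// persistent kernel gets a few waves of tiles per SM, and fall back to small
// tiles for small problems. Candidate sizes mirror the hand-tuned table.
inline int choose_k_tile_size(long long m, long long n, int num_sms) {
  if (m < 1024 || n < 1024) return 256;            // small problems
  const long long tiles_at_512 = (m / 128) * ((n + 511) / 512);
  if (tiles_at_512 >= 8LL * num_sms) return 1024;  // ample parallelism
  if (tiles_at_512 >= 2LL * num_sms) return 512;
  return 256;                                      // keep all SMs busy
}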


@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 992400c to fcf32d6 Compare November 25, 2025 06:50
@greptile-apps greptile-apps bot left a comment

20 files reviewed, 6 comments

Comment on lines +979 to +1028
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}

style: hardcoded tile sizes don't scale to new shapes

This switch statement hardcodes k_tile_size for specific (M, N) combinations. Consider implementing a heuristic-based selector that accounts for SM count, memory bandwidth, and tensor dimensions, or move this tuning to a configuration file so it can be extended without code changes.

Comment on lines 590 to 617
// TODO(zhongbo): the math operations are very expensive
// since the kernel is persistent, we can have a cache for all the possible scaling factors
if (tensor_id != new_tensor_id) {
global_amax_val = *global_amax_ptr;
global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
global_decode_scale = 1.0f / global_encode_scale;
tCgC = thr_mma_epilogue.partition_C(gCs_mn[new_tensor_id]);
gSFC_mn = gSFCs_mn[new_tensor_id];
tensor_id = new_tensor_id;
}

style: math operations for recomputing scaling factors are expensive in persistent kernel

Consider caching pre-computed scaling factors for all possible tensor IDs at kernel start, since num_tensors is known and typically small (≤64). This would avoid recomputing ComputeGlobalEncodeScaleFP4 on every tensor boundary crossing.
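
A device-side sketch of the suggested caching (a fragment only; names such as kMaxTensors and global_amax_ptrs are illustrative, not the kernel's actual variables):

// Hypothetical sketch: compute encode/decode scales for every tensor once per
// CTA (assumes num_tensors <= kMaxTensors <= blockDim.x), instead of on every
// tensor-boundary crossing inside the persistent loop.
__shared__ float s_encode_scale[kMaxTensors];
__shared__ float s_decode_scale[kMaxTensors];
if (threadIdx.x < num_tensors) {
  const float amax = global_amax_ptrs[threadIdx.x][0];
  const float enc = ComputeGlobalEncodeScaleFP4(amax);
  s_encode_scale[threadIdx.x] = enc;
  s_decode_scale[threadIdx.x] = 1.0f / enc;
}
__syncthreads();
// On a tensor switch, read the cached values instead of recomputing:
//   global_encode_scale = s_encode_scale[new_tensor_id];
//   global_decode_scale = s_decode_scale[new_tensor_id];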


@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from f55aea4 to 73954e8 Compare November 26, 2025 05:50
@greptile-apps greptile-apps bot left a comment

20 files reviewed, 8 comments

kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
// check overflow
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,

logic: Overflow check only validates >= 0 but doesn't catch actual int32_t overflow. Consider checking against INT_MAX or using size_t arithmetic. Should this check for actual overflow against INT_MAX rather than just non-negative values?
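
A small sketch of the suggested fix, following the names in the quoted snippet (assumes split_sections_range stays int32_t; requires <cstdint> for int64_t and INT32_MAX):

// Accumulate in 64-bit, then validate against INT32_MAX before narrowing back
// into the int32_t split_sections_range array.
const int64_t next_range =
    static_cast<int64_t>(kernel_args.split_sections_range[kernel_args.num_tensors]) +
    static_cast<int64_t>(split_sections[i]);
NVTE_CHECK(next_range <= INT32_MAX, "split_sections_range overflows int32_t");
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
    static_cast<int32_t>(next_range);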

@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 73954e8 to 320cf62 Compare November 26, 2025 20:26

@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 62dc64b to f3f91de Compare December 2, 2025 05:43

@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from e0cbb9a to e3320f4 Compare December 3, 2025 19:47

@zhongbozhu zhongbozhu force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from d9bd13d to c08ba2e Compare December 3, 2025 23:39

@timmoon10 timmoon10 force-pushed the zhongbo/multi_rht_cast_colwise_fuse branch from 6fac043 to 0d245ae Compare December 18, 2025 07:35
@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, line 79-92 (link)

    style: boundary guard prevents out-of-bounds but relies on implicit assumption

    The early return at line 85-87 prevents out-of-bounds access when offset >= boundary, but this assumes the host code correctly sets split_sections_range[num_tensors] as a sentinel value. While this works, the implicit dependency makes the code fragile to future changes.


  2. transformer_engine/pytorch/csrc/extensions/cast.cpp, line 810-811 (link)

    logic: overflow check validates non-negative but doesn't prevent int32_t overflow

    NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, ...) only checks for negative values after overflow has occurred. With cumulative sums, overflow to negative is possible, but overflow within positive range (wrapping past INT_MAX) won't be caught. Should this validate against INT_MAX explicitly, or is the tensor size constrained elsewhere to make this impossible?

  3. transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh, line 293-296 (link)

    style: tensor ID is fetched every stage even when unchanged

    GetTensorId is called for every stage to check if tensor boundary is crossed. While the check if (new_tensor_id != tensor_id) prevents unnecessary updates, the function call and boundary scan happen unconditionally. Could this be optimized by pre-computing which stages cross tensor boundaries, or would the complexity outweigh the benefit?


  4. transformer_engine/pytorch/csrc/extensions/cast.cpp, line 813-815 (link)

    style: document the performance implications of 128-alignment

    The all_aligned_token_dim check determines whether the fully-fused high-performance kernel is used. Users should be aware when they're not getting optimal performance due to alignment. Should there be a runtime warning or logging when falling back to the unfused path due to alignment?

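Regarding item 4, a minimal sketch of a one-time warning (the flag name follows the review comment; the logging mechanism shown is illustrative, and TE may prefer its own facility):

#include <cstdio>

// Hypothetical helper: warn once when token splits are not 128-aligned and the
// unfused NVFP4 path is taken instead of the fully-fused kernel.
inline void warn_if_unaligned_fallback(bool all_aligned_token_dim) {
  static bool warned = false;
  if (!all_aligned_token_dim && !warned) {
    std::fprintf(stderr,
                 "split_quantize: token splits not 128-aligned; "
                 "falling back to the unfused NVFP4 path (lower performance).\n");
    warned = true;
  }
}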

17 files reviewed, 4 comments



@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, line 75-86 (link)

    logic: no bounds check after loop termination

    The while loop exits when split_sections_range[tensor_id + 1] <= offset is false, but there's no verification that tensor_id < num_tensors before returning. If host sets incorrect split_sections_range or offset is corrupted, this could return an invalid tensor_id.

  2. transformer_engine/common/hadamard_transform/group_hadamard_transform_cast_fusion.cu, line 1001-1027 (link)

    style: hardcoded tile size tuning doesn't scale

    The switch-case hardcodes optimal k_tile_size for specific tensor shapes. This requires manual updates for new workloads and doesn't handle intermediate sizes well. The fallback logic at lines 1025-1027 only handles small tensors.

    Consider: Move tuning parameters to a config file, or implement auto-tuning based on SM count and tensor dimensions.


  3. transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh, line 810-811 (link)

    logic: overflow check doesn't validate actual overflow

    The check only ensures non-negative, but split_sections_range[i+1] could legitimately overflow int32_t with large tensor sizes or many splits. Use size_t for accumulation or check against INT_MAX (see the sketch under the earlier overflow comment).

  4. transformer_engine/pytorch/csrc/extensions/cast.cpp, line 712-760 (link)

    style: resource management for RNG states could leak on exceptions

    The StochasticRngStateResources struct holds tensors and vectors but lacks explicit cleanup. If an exception occurs after partial initialization, resources may not be properly released. Consider adding RAII wrappers or using smart pointers consistently.
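
Regarding item 1, a minimal sketch of a bounds-clamped lookup (illustrative; the real GetTensorId signature and search strategy may differ):

// Hypothetical sketch: linear scan over the prefix sums in split_sections_range,
// clamped so a corrupted offset can never produce an index >= num_tensors.
__device__ inline int get_tensor_id_clamped(const int32_t *split_sections_range,
                                            int num_tensors, int32_t offset) {
  int tensor_id = 0;
  while (tensor_id + 1 < num_tensors &&
         split_sections_range[tensor_id + 1] <= offset) {
    ++tensor_id;
  }
  return tensor_id;  // always in [0, num_tensors) for num_tensors > 0
}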

20 files reviewed, 4 comments


Signed-off-by: Tim Moon <[email protected]>

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu, line 1289-1292 (link)

    logic: potential memory leak if kernel launch fails

    If NVTE_CHECK throws/aborts at line 1290, the workspace allocated at line 1281 won't be freed. Move the free before the status check, or use RAII wrapper.
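
A minimal sketch of the suggested RAII approach (a hypothetical WorkspaceGuard; the actual allocation and free calls in the launcher may differ):

#include <cuda_runtime.h>

// Hypothetical guard: the workspace is freed when the guard leaves scope, even
// if NVTE_CHECK throws after the kernel launch.
struct WorkspaceGuard {
  void *ptr = nullptr;
  cudaStream_t stream = nullptr;
  ~WorkspaceGuard() {
    if (ptr != nullptr) cudaFreeAsync(ptr, stream);
  }
};

// Usage sketch:
//   WorkspaceGuard workspace{nullptr, stream};
//   cudaMallocAsync(&workspace.ptr, workspace_bytes, stream);
//   launch_kernel(..., workspace.ptr, stream);
//   NVTE_CHECK(cudaGetLastError() == cudaSuccess, "kernel launch failed");
//   // workspace freed automatically here, even on failure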

19 files reviewed, 1 comment



@timmoon10
Collaborator

/te-ci L1

