
Conversation

@nastya236 (Contributor) commented on Jan 20, 2026

Add a per-tensor fp32 scale for nvfp4 quantization.
Python API changes (a usage sketch follows this list):

  • quantize also returns the tensor_scale when mode == nvfp4
  • dequantize takes the tensor_scale as an input when mode == nvfp4
  • qqmm takes 4 arrays when the weights are provided already quantized and mode == nvfp4: the quantized weights, the scales, and the tensor_scale
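
A rough usage sketch of the API described above. The exact argument order, keyword names, and whether qqmm is exposed directly under mx are assumptions for illustration, not the final signatures:

```python
import mlx.core as mx

x = mx.random.normal((128, 256))
w = mx.random.normal((256, 512))

# quantize also returns the per-tensor fp32 scale in nvfp4 mode (assumed signature)
w_q, scales, tensor_scale = mx.quantize(w, mode="nvfp4")

# dequantize takes tensor_scale back as an input in nvfp4 mode (assumed signature)
w_hat = mx.dequantize(w_q, scales, tensor_scale, mode="nvfp4")

# qqmm with pre-quantized weights: quantized weights, scales, and tensor_scale (assumed signature)
y = mx.qqmm(x, w_q, scales, tensor_scale, mode="nvfp4")
```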

Changes to the quantization (see the sketch after this list):

  • Quantize::eval_gpu (quantize): two kernels are launched: 1) an all_reduce with an AbsMax reduction, 2) fp_quantize, which uses the computed absmax
  • Quantize::eval_gpu (dequantize): takes the absmax as an input and uses it in fp_dequantize
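
For intuition, a NumPy-level sketch of what the absmax pass produces. Only the two-pass structure (a global AbsMax reduction, then fp_quantize consuming it) comes from the description above; the constants (448, 6) and the exact scale formula follow the usual NVFP4 recipe and are an assumption here:

```python
import numpy as np

FP8_E4M3_MAX = 448.0   # assumed max of the fp8 block-scale format
FP4_E2M1_MAX = 6.0     # assumed max of the fp4 element format

def nvfp4_tensor_scale(w):
    # Pass 1: the all_reduce / AbsMax kernel computes the global absolute max.
    absmax = np.max(np.abs(w))
    # Per-tensor fp32 scale derived from it (assumed NVFP4 recipe).
    return absmax / (FP8_E4M3_MAX * FP4_E2M1_MAX)

w = np.random.randn(256, 512).astype(np.float32)
tensor_scale = nvfp4_tensor_scale(w)
# Pass 2: fp_quantize would use this absmax-derived scale when computing
# the per-block fp8 scales and the fp4 values (not reproduced here).
```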

Changes to qqmm (see the sketch after this list):

  • the inputs are quantized as described above
  • tensor_amax_x and tensor_amax_w are used to compute alpha for the BLAS operation. beta = 0 is also allocated on device, because cublasLt requires the alpha and beta pointers to both be on device or both on host
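
A sketch of the alpha computation, assuming alpha is the product of the two per-tensor dequant scales derived from tensor_amax_x and tensor_amax_w. Only the use of the two amax values and the device-side beta = 0 come from the description above; the formula and constants are assumptions:

```python
import numpy as np

FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0

def qqmm_alpha(tensor_amax_x, tensor_amax_w):
    # Per-tensor dequant scale for each operand (assumed NVFP4 recipe).
    scale_x = tensor_amax_x / (FP8_E4M3_MAX * FP4_E2M1_MAX)
    scale_w = tensor_amax_w / (FP8_E4M3_MAX * FP4_E2M1_MAX)
    # alpha passed to the cublasLt matmul; beta is a separate zero scalar that
    # must live in the same memory space (device or host) as alpha.
    return scale_x * scale_w

alpha = qqmm_alpha(np.float32(3.2), np.float32(1.7))
beta = np.float32(0.0)  # stand-in for the device-allocated zero described in the PR
```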

It still looks ugly.

@nastya236 closed this on Jan 20, 2026
@nastya236 reopened this on Jan 20, 2026
@nastya236 changed the title from "Tensor scale nvfp4" to "[WIP] Tensor scale nvfp4" on Jan 20, 2026
* w_q (array): The quantized version of ``w``
* scales (array): The quantization scales
* tensor_scale (array): The per-tensor float32 absolute max
A Member commented on the docstring quoted above:
This is a bit of a nit, but I don't really love the name tensor_scale. Partly because we call them array in MLX. What do you think about global_scale, or maybe array_scale?
