From ac679b2e76f7946552e9cac2de48fec07acc5083 Mon Sep 17 00:00:00 2001 From: Armand Sauzay Date: Tue, 9 Dec 2025 11:53:59 -0800 Subject: [PATCH] =?UTF-8?q?Enable=20direct=20MX4=E2=86=92BF16=20dequantiza?= =?UTF-8?q?tion=20to=20reduce=20memory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2202 Add an output_dtype parameter to the MX4 dequantization stack to support direct conversion to BF16/FP16, avoiding the expensive FP32 intermediate step. Differential Revision: D87826479 --- torchrec/distributed/fbgemm_qcomm_codec.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torchrec/distributed/fbgemm_qcomm_codec.py b/torchrec/distributed/fbgemm_qcomm_codec.py index 4f191f8cc..933475728 100644 --- a/torchrec/distributed/fbgemm_qcomm_codec.py +++ b/torchrec/distributed/fbgemm_qcomm_codec.py @@ -69,10 +69,12 @@ class QCommsConfig: fp8_quantize_dim: Optional[int] = None fp8_quantize_dim_bwd: Optional[int] = None fp8_bwd_uses_143: Optional[bool] = False - fp8_output_dtype: Optional[SparseType] = None mx4_quantize_dim: Optional[int] = None mx4_quantize_dim_bwd: Optional[int] = None mx4_rounding_mode: Optional[RoundingMode] = None + output_dtype: Optional[SparseType] = ( + None # Unified output dtype for both FP8 and MX4 + ) def __post_init__(self) -> None: if ( @@ -136,9 +138,9 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode "row_dim": row_dim, "rounding_mode": rounding_mode, } - # kwargs approach for bwd compatibility (D86890315) - if qcomms_config.fp8_output_dtype is not None: - forward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype + # kwargs approach for bwd compatibility (D87826479) + if qcomms_config.output_dtype is not None: + forward_kwargs["output_dtype"] = qcomms_config.output_dtype codecs.forward = cast( QuantizedCommCodec[QuantizationContext], @@ -161,9 +163,9 @@ def
get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode "row_dim": row_dim_bwd, "rounding_mode": rounding_mode, } - # kwargs approach for bwd compatibility (D86890315) - if qcomms_config.fp8_output_dtype is not None: - backward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype + # kwargs approach for bwd compatibility (D87826479) + if qcomms_config.output_dtype is not None: + backward_kwargs["output_dtype"] = qcomms_config.output_dtype codecs.backward = cast( QuantizedCommCodec[QuantizationContext],