Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions torchrec/distributed/fbgemm_qcomm_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ class QCommsConfig:
fp8_quantize_dim: Optional[int] = None
fp8_quantize_dim_bwd: Optional[int] = None
fp8_bwd_uses_143: Optional[bool] = False
fp8_output_dtype: Optional[SparseType] = None
mx4_quantize_dim: Optional[int] = None
mx4_quantize_dim_bwd: Optional[int] = None
mx4_rounding_mode: Optional[RoundingMode] = None
output_dtype: Optional[SparseType] = (
None # Unified output dtype for both FP8 and MX4
)

def __post_init__(self) -> None:
if (
Expand Down Expand Up @@ -136,9 +138,9 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
"row_dim": row_dim,
"rounding_mode": rounding_mode,
}
# kwargs approach for bwd compatibility (D86890315)
if qcomms_config.fp8_output_dtype is not None:
forward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype
# Pass output_dtype via kwargs only when it is set, for backward compatibility (D87826479)
if qcomms_config.output_dtype is not None:
forward_kwargs["output_dtype"] = qcomms_config.output_dtype

codecs.forward = cast(
QuantizedCommCodec[QuantizationContext],
Expand All @@ -161,9 +163,9 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
"row_dim": row_dim_bwd,
"rounding_mode": rounding_mode,
}
# kwargs approach for bwd compatibility (D86890315)
if qcomms_config.fp8_output_dtype is not None:
backward_kwargs["fp8_output_dtype"] = qcomms_config.fp8_output_dtype
# Pass output_dtype via kwargs only when it is set, for backward compatibility (D87826479)
if qcomms_config.output_dtype is not None:
backward_kwargs["output_dtype"] = qcomms_config.output_dtype

codecs.backward = cast(
QuantizedCommCodec[QuantizationContext],
Expand Down
Loading