Skip to content

Commit b75063b

Browse files
optimisea authored and facebook-github-bot committed
expose rounding_mode in quantization for performance (#3368)
Summary:
X-link: facebookresearch/FBGEMM#1884
X-link: pytorch/FBGEMM#4862
Pull Request resolved: #3368

Expose the rounding_mode for mx4 as it could impact the QPS. Previous work was done here. D62466094

```
class RoundingMode(IntEnum):
    """Rounding options for quantization."""

    nearest = 0
    floor = 1
    even = 2
    stochastic = 3
    ceil = 4
```

https://fburl.com/code/8prz4mem

Reviewed By: victor-eds

Differential Revision: D82001579

fbshipit-source-id: 872cd8ba62292b95e568ece47ac09052f28ca59e
1 parent 8e7fd24 commit b75063b

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchrec/distributed/fbgemm_qcomm_codec.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
QuantizationContext,
2222
QuantizedCommCodec as FbgemmQuantizedCommCodec,
2323
)
24+
from fbgemm_gpu.quantize_utils import RoundingMode
2425
from fbgemm_gpu.split_embedding_configs import SparseType
2526
from torchrec.distributed.types import CommOp, QuantizedCommCodec, QuantizedCommCodecs
2627

@@ -70,6 +71,7 @@ class QCommsConfig:
7071
fp8_bwd_uses_143: Optional[bool] = False
7172
mx4_quantize_dim: Optional[int] = None
7273
mx4_quantize_dim_bwd: Optional[int] = None
74+
mx4_rounding_mode: Optional[RoundingMode] = None
7375

7476
def __post_init__(self) -> None:
7577
if (
@@ -119,10 +121,12 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
119121
codecs = QuantizedCommCodecs()
120122
if qcomms_config is not None:
121123
row_dim = None
124+
rounding_mode = None
122125
if qcomms_config.forward_precision == CommType.FP8:
123126
row_dim = qcomms_config.fp8_quantize_dim
124127
elif qcomms_config.forward_precision == CommType.MX4:
125128
row_dim = qcomms_config.mx4_quantize_dim
129+
rounding_mode = qcomms_config.mx4_rounding_mode
126130
codecs.forward = cast(
127131
QuantizedCommCodec[QuantizationContext],
128132
FbgemmQuantizedCommCodec(
@@ -132,6 +136,7 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
132136
loss_scale=qcomms_config.forward_loss_scale,
133137
is_fwd=True,
134138
row_dim=row_dim,
139+
rounding_mode=rounding_mode,
135140
),
136141
)
137142
row_dim_bwd = None
@@ -151,6 +156,7 @@ def get_qcomm_codecs(qcomms_config: Optional[QCommsConfig]) -> QuantizedCommCode
151156
), # if fp8_bwd_uses_143 is True, bwd will use 1-4-3
152157
# if fp8_bwd_uses_143 is False/None, bwd will use 1-5-2
153158
row_dim=row_dim_bwd,
159+
rounding_mode=rounding_mode,
154160
),
155161
)
156162
return codecs

0 commit comments

Comments
 (0)