Skip to content

Commit 2503b81

Browse files
emlinfacebook-github-bot
authored andcommitted
fix default uniform init range for kvzch (#3336)
Summary: since zch v.Next is using a very large virtual table size, 2^50, the default uniform init value becomes very small, and when the weight dtype is half, those value essentially becomes 0. We have observed the weight init value is all 0 from the debug log: https://fburl.com/mlhub/aea9mbzf {F1981621246} Differential Revision: D81296621
1 parent 85ec396 commit 2503b81

File tree

1 file changed

+53
-3
lines changed

1 file changed

+53
-3
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import tempfile
1616
from dataclasses import dataclass
17+
from math import sqrt
1718
from typing import (
1819
Any,
1920
cast,
@@ -151,7 +152,9 @@ def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParam
151152
return (enable_raw_embedding_streaming, res_params)
152153

153154

154-
def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
155+
def _populate_ssd_tbe_params(
156+
config: GroupedEmbeddingConfig, is_kvzch: bool = False
157+
) -> Dict[str, Any]:
155158
"""
156159
Construct SSD TBE params dict from config and fused params dict.
157160
"""
@@ -192,6 +195,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
192195
)
193196

194197
# populate init min and max
198+
if is_kvzch:
199+
_generate_init_range_for_kvzch(ssd_tbe_params, config)
200+
195201
if (
196202
"ssd_uniform_init_lower" not in ssd_tbe_params
197203
or "ssd_uniform_init_upper" not in ssd_tbe_params
@@ -245,6 +251,50 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
245251
return ssd_tbe_params
246252

247253

254+
def _generate_init_range_for_kvzch(
255+
tbe_params: Dict[str, Any],
256+
config: GroupedEmbeddingConfig,
257+
) -> None:
258+
"""
259+
Generate uniform init range for zero collision TBE based
260+
"""
261+
# populate init min and max
262+
if (
263+
"ssd_uniform_init_lower" not in tbe_params
264+
or "ssd_uniform_init_upper" not in tbe_params
265+
):
266+
# Right now we do not support a per table init max and min. To use
267+
# per table init max and min, either we allow it in SSD TBE, or we
268+
# create one SSD TBE per table.
269+
weights_precision = data_type_to_sparse_type(config.data_type)
270+
271+
# For Float32: use mathematically correct values, for Half: use safe range
272+
max_size = 4_000_000_000 # 4B virtual embeddings
273+
default_init_range = (
274+
(-sqrt(1 / max_size), sqrt(1 / max_size))
275+
if weights_precision.as_dtype() == torch.float32
276+
else (-0.001, 0.001)
277+
)
278+
279+
def get_init_value(
280+
table_init_val: Optional[float], default_value: float
281+
) -> float:
282+
return table_init_val if table_init_val is not None else default_value
283+
284+
init_mins = [
285+
get_init_value(table.weight_init_min, default_init_range[0])
286+
for table in config.embedding_tables
287+
]
288+
init_maxs = [
289+
get_init_value(table.weight_init_max, default_init_range[1])
290+
for table in config.embedding_tables
291+
]
292+
293+
num_tables = len(config.embedding_tables)
294+
tbe_params["ssd_uniform_init_lower"] = sum(init_mins) / num_tables
295+
tbe_params["ssd_uniform_init_upper"] = sum(init_maxs) / num_tables
296+
297+
248298
def _populate_zero_collision_tbe_params(
249299
tbe_params: Dict[str, Any],
250300
sharded_local_buckets: List[Tuple[int, int, int]],
@@ -1878,7 +1928,7 @@ def __init__(
18781928
"not divisible by 4. "
18791929
)
18801930

1881-
ssd_tbe_params = _populate_ssd_tbe_params(config)
1931+
ssd_tbe_params = _populate_ssd_tbe_params(config, is_kvzch=True)
18821932
self._bucket_spec: List[Tuple[int, int, int]] = (
18831933
_get_sharded_local_buckets_for_zero_collision(
18841934
self._config.embedding_tables, self._pg
@@ -2758,7 +2808,7 @@ def __init__(
27582808
"not divisible by 4. "
27592809
)
27602810

2761-
ssd_tbe_params = _populate_ssd_tbe_params(config)
2811+
ssd_tbe_params = _populate_ssd_tbe_params(config, is_kvzch=True)
27622812
self._bucket_spec: List[Tuple[int, int, int]] = (
27632813
_get_sharded_local_buckets_for_zero_collision(
27642814
self._config.embedding_tables, self._pg

0 commit comments

Comments
 (0)