Skip to content

Commit 6cd43b2

Browse files
emlinfacebook-github-bot
authored and committed
fix default uniform init range (#3336)
Summary: Pull Request resolved: #3336. Since zch v.Next uses a very large virtual table size (2^50), the default uniform init values become very small, and when the weight dtype is half, those values essentially become 0. We have observed in the debug log that the weight init values are all 0: https://fburl.com/mlhub/aea9mbzf {F1981621246} Reviewed By: kathyxuyy Differential Revision: D81296621 fbshipit-source-id: 46a4eb87d6df7a8073efdacffad355b66d683364
1 parent 85ec396 commit 6cd43b2

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import tempfile
1616
from dataclasses import dataclass
17+
from math import sqrt
1718
from typing import (
1819
Any,
1920
cast,
@@ -192,6 +193,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
192193
)
193194

194195
# populate init min and max
196+
if config.is_using_virtual_table:
197+
_generate_init_range_for_virtual_tables(ssd_tbe_params, config)
198+
195199
if (
196200
"ssd_uniform_init_lower" not in ssd_tbe_params
197201
or "ssd_uniform_init_upper" not in ssd_tbe_params
@@ -245,6 +249,50 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
245249
return ssd_tbe_params
246250

247251

252+
def _generate_init_range_for_virtual_tables(
    tbe_params: Dict[str, Any],
    config: GroupedEmbeddingConfig,
) -> None:
    """
    Populate the uniform init range for a virtual-table (zero collision) TBE.

    Writes ``ssd_uniform_init_lower`` and ``ssd_uniform_init_upper`` into
    ``tbe_params`` in place. If both keys are already present the dict is
    left untouched; if either one is missing, both are (re)computed.

    Every table contributes its own ``weight_init_min`` / ``weight_init_max``
    when set, otherwise a default bound picked from the weight precision.
    The per-table bounds are then averaged into a single range.

    Args:
        tbe_params: TBE keyword arguments; updated in place.
        config: grouped embedding config providing ``data_type`` and
            ``embedding_tables``.
    """
    if (
        "ssd_uniform_init_lower" in tbe_params
        and "ssd_uniform_init_upper" in tbe_params
    ):
        # An explicit range was already supplied; keep it.
        return

    # Right now we do not support a per table init max and min. To use
    # per table init max and min, either we allow it in SSD TBE, or we
    # create one SSD TBE per table.
    weights_precision = data_type_to_sparse_type(config.data_type)

    # For Float32: use mathematically correct values, for Half: use safe range
    if weights_precision.as_dtype() == torch.float32:
        max_size = 4_000_000_000  # 4B virtual embeddings
        default_upper = sqrt(1 / max_size)
    else:
        default_upper = 0.001
    default_lower = -default_upper

    tables = config.embedding_tables
    if not tables:
        # Nothing to average over: fall back to the defaults instead of
        # dividing by zero below.
        tbe_params["ssd_uniform_init_lower"] = default_lower
        tbe_params["ssd_uniform_init_upper"] = default_upper
        return

    # A table-level bound wins over the default only when it is explicitly
    # set; 0.0 is a legitimate bound, hence the `is not None` checks.
    init_mins = [
        table.weight_init_min if table.weight_init_min is not None else default_lower
        for table in tables
    ]
    init_maxs = [
        table.weight_init_max if table.weight_init_max is not None else default_upper
        for table in tables
    ]

    num_tables = len(tables)
    tbe_params["ssd_uniform_init_lower"] = sum(init_mins) / num_tables
    tbe_params["ssd_uniform_init_upper"] = sum(init_maxs) / num_tables
295+
248296
def _populate_zero_collision_tbe_params(
249297
tbe_params: Dict[str, Any],
250298
sharded_local_buckets: List[Tuple[int, int, int]],
@@ -1872,6 +1920,9 @@ def __init__(
18721920
assert (
18731921
len({table.embedding_dim for table in config.embedding_tables}) == 1
18741922
), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
1923+
assert (
1924+
config.is_using_virtual_table
1925+
), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
18751926
for table in config.embedding_tables:
18761927
assert table.local_cols % 4 == 0, (
18771928
f"table {table.name} has local_cols={table.local_cols} "
@@ -2751,6 +2802,9 @@ def __init__(
27512802
assert (
27522803
len({table.embedding_dim for table in config.embedding_tables}) == 1
27532804
), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
2805+
assert (
2806+
config.is_using_virtual_table
2807+
), "Try to create ZeroCollisionKeyValueEmbeddingBag for non virtual tables"
27542808

27552809
for table in config.embedding_tables:
27562810
assert table.local_cols % 4 == 0, (

0 commit comments

Comments
 (0)