|
14 | 14 | import logging
|
15 | 15 | import tempfile
|
16 | 16 | from dataclasses import dataclass
|
| 17 | +from math import sqrt |
17 | 18 | from typing import (
|
18 | 19 | Any,
|
19 | 20 | cast,
|
@@ -151,7 +152,9 @@ def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParam
|
151 | 152 | return (enable_raw_embedding_streaming, res_params)
|
152 | 153 |
|
153 | 154 |
|
154 |
| -def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: |
| 155 | +def _populate_ssd_tbe_params( |
| 156 | + config: GroupedEmbeddingConfig, is_kvzch: bool = False |
| 157 | +) -> Dict[str, Any]: |
155 | 158 | """
|
156 | 159 | Construct SSD TBE params dict from config and fused params dict.
|
157 | 160 | """
|
@@ -192,6 +195,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
|
192 | 195 | )
|
193 | 196 |
|
194 | 197 | # populate init min and max
|
| 198 | + if is_kvzch: |
| 199 | + _generate_init_range_for_kvzch(ssd_tbe_params, config) |
| 200 | + |
195 | 201 | if (
|
196 | 202 | "ssd_uniform_init_lower" not in ssd_tbe_params
|
197 | 203 | or "ssd_uniform_init_upper" not in ssd_tbe_params
|
@@ -245,6 +251,50 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
|
245 | 251 | return ssd_tbe_params
|
246 | 252 |
|
247 | 253 |
|
def _generate_init_range_for_kvzch(
    tbe_params: Dict[str, Any],
    config: GroupedEmbeddingConfig,
) -> None:
    """
    Generate the uniform init range for a zero collision (KVZCH) SSD TBE.

    Nothing is changed when ``tbe_params`` already contains both
    ``"ssd_uniform_init_lower"`` and ``"ssd_uniform_init_upper"``. Otherwise
    both keys are written in place. Each table contributes its own
    ``weight_init_min`` / ``weight_init_max`` (or the precision-dependent
    default when that attribute is ``None``), and the per-table values are
    averaged into one lower and one upper bound.

    Args:
        tbe_params: SSD TBE params dict; updated in place.
        config: grouped embedding config. Its ``data_type`` selects the default
            range and its ``embedding_tables`` supply the per-table overrides.

    Returns:
        None. The result is stored in ``tbe_params``.
    """
    # Respect a range that is already fully specified.
    if (
        "ssd_uniform_init_lower" in tbe_params
        and "ssd_uniform_init_upper" in tbe_params
    ):
        return

    weights_precision = data_type_to_sparse_type(config.data_type)

    # float32 weights use +/- sqrt(1 / max_size); every other precision uses
    # the fixed +/- 0.001 range.
    if weights_precision.as_dtype() == torch.float32:
        max_size = 4_000_000_000  # 4B virtual embeddings
        bound = sqrt(1 / max_size)
    else:
        bound = 0.001
    default_lower, default_upper = -bound, bound

    tables = config.embedding_tables
    if not tables:
        # No tables to average over: use the defaults rather than dividing by
        # zero below.
        tbe_params["ssd_uniform_init_lower"] = default_lower
        tbe_params["ssd_uniform_init_upper"] = default_upper
        return

    # Right now we do not support a per table init max and min. To use per
    # table init max and min, either we allow it in SSD TBE, or we create one
    # SSD TBE per table. Until then the per-table values are averaged.
    init_mins = [
        default_lower if table.weight_init_min is None else table.weight_init_min
        for table in tables
    ]
    init_maxs = [
        default_upper if table.weight_init_max is None else table.weight_init_max
        for table in tables
    ]

    num_tables = len(tables)
    tbe_params["ssd_uniform_init_lower"] = sum(init_mins) / num_tables
    tbe_params["ssd_uniform_init_upper"] = sum(init_maxs) / num_tables
| 296 | + |
| 297 | + |
248 | 298 | def _populate_zero_collision_tbe_params(
|
249 | 299 | tbe_params: Dict[str, Any],
|
250 | 300 | sharded_local_buckets: List[Tuple[int, int, int]],
|
@@ -1878,7 +1928,7 @@ def __init__(
|
1878 | 1928 | "not divisible by 4. "
|
1879 | 1929 | )
|
1880 | 1930 |
|
1881 |
| - ssd_tbe_params = _populate_ssd_tbe_params(config) |
| 1931 | + ssd_tbe_params = _populate_ssd_tbe_params(config, is_kvzch=True) |
1882 | 1932 | self._bucket_spec: List[Tuple[int, int, int]] = (
|
1883 | 1933 | _get_sharded_local_buckets_for_zero_collision(
|
1884 | 1934 | self._config.embedding_tables, self._pg
|
@@ -2758,7 +2808,7 @@ def __init__(
|
2758 | 2808 | "not divisible by 4. "
|
2759 | 2809 | )
|
2760 | 2810 |
|
2761 |
| - ssd_tbe_params = _populate_ssd_tbe_params(config) |
| 2811 | + ssd_tbe_params = _populate_ssd_tbe_params(config, is_kvzch=True) |
2762 | 2812 | self._bucket_spec: List[Tuple[int, int, int]] = (
|
2763 | 2813 | _get_sharded_local_buckets_for_zero_collision(
|
2764 | 2814 | self._config.embedding_tables, self._pg
|
|
0 commit comments