|
14 | 14 | import logging
|
15 | 15 | import tempfile
|
16 | 16 | from dataclasses import dataclass
|
| 17 | +from math import sqrt |
17 | 18 | from typing import (
|
18 | 19 | Any,
|
19 | 20 | cast,
|
@@ -192,6 +193,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
|
192 | 193 | )
|
193 | 194 |
|
194 | 195 | # populate init min and max
|
| 196 | + if config.is_using_virtual_table: |
| 197 | + _generate_init_range_for_virtual_tables(ssd_tbe_params, config) |
| 198 | + |
195 | 199 | if (
|
196 | 200 | "ssd_uniform_init_lower" not in ssd_tbe_params
|
197 | 201 | or "ssd_uniform_init_upper" not in ssd_tbe_params
|
@@ -245,6 +249,50 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
|
245 | 249 | return ssd_tbe_params
|
246 | 250 |
|
247 | 251 |
|
| 252 | +def _generate_init_range_for_virtual_tables( |
| 253 | + tbe_params: Dict[str, Any], |
| 254 | + config: GroupedEmbeddingConfig, |
| 255 | +) -> None: |
| 256 | + """ |
| 257 | + Generate uniform init range for zero collision TBE based |
| 258 | + """ |
| 259 | + # populate init min and max |
| 260 | + if ( |
| 261 | + "ssd_uniform_init_lower" not in tbe_params |
| 262 | + or "ssd_uniform_init_upper" not in tbe_params |
| 263 | + ): |
| 264 | + # Right now we do not support a per table init max and min. To use |
| 265 | + # per table init max and min, either we allow it in SSD TBE, or we |
| 266 | + # create one SSD TBE per table. |
| 267 | + weights_precision = data_type_to_sparse_type(config.data_type) |
| 268 | + |
| 269 | + # For Float32: use mathematically correct values, for Half: use safe range |
| 270 | + max_size = 4_000_000_000 # 4B virtual embeddings |
| 271 | + default_init_range = ( |
| 272 | + (-sqrt(1 / max_size), sqrt(1 / max_size)) |
| 273 | + if weights_precision.as_dtype() == torch.float32 |
| 274 | + else (-0.001, 0.001) |
| 275 | + ) |
| 276 | + |
| 277 | + def get_init_value( |
| 278 | + table_init_val: Optional[float], default_value: float |
| 279 | + ) -> float: |
| 280 | + return table_init_val if table_init_val is not None else default_value |
| 281 | + |
| 282 | + init_mins = [ |
| 283 | + get_init_value(table.weight_init_min, default_init_range[0]) |
| 284 | + for table in config.embedding_tables |
| 285 | + ] |
| 286 | + init_maxs = [ |
| 287 | + get_init_value(table.weight_init_max, default_init_range[1]) |
| 288 | + for table in config.embedding_tables |
| 289 | + ] |
| 290 | + |
| 291 | + num_tables = len(config.embedding_tables) |
| 292 | + tbe_params["ssd_uniform_init_lower"] = sum(init_mins) / num_tables |
| 293 | + tbe_params["ssd_uniform_init_upper"] = sum(init_maxs) / num_tables |
| 294 | + |
| 295 | + |
248 | 296 | def _populate_zero_collision_tbe_params(
|
249 | 297 | tbe_params: Dict[str, Any],
|
250 | 298 | sharded_local_buckets: List[Tuple[int, int, int]],
|
@@ -1872,6 +1920,9 @@ def __init__(
|
1872 | 1920 | assert (
|
1873 | 1921 | len({table.embedding_dim for table in config.embedding_tables}) == 1
|
1874 | 1922 | ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
|
| 1923 | + assert ( |
| 1924 | + config.is_using_virtual_table |
| 1925 | + ), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables" |
1875 | 1926 | for table in config.embedding_tables:
|
1876 | 1927 | assert table.local_cols % 4 == 0, (
|
1877 | 1928 | f"table {table.name} has local_cols={table.local_cols} "
|
@@ -2751,6 +2802,9 @@ def __init__(
|
2751 | 2802 | assert (
|
2752 | 2803 | len({table.embedding_dim for table in config.embedding_tables}) == 1
|
2753 | 2804 | ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
|
| 2805 | + assert ( |
| 2806 | + config.is_using_virtual_table |
| 2807 | + ), "Try to create ZeroCollisionKeyValueEmbeddingBag for non virtual tables" |
2754 | 2808 |
|
2755 | 2809 | for table in config.embedding_tables:
|
2756 | 2810 | assert table.local_cols % 4 == 0, (
|
|
0 commit comments