9898
9999logger : logging .Logger = logging .getLogger (__name__ )
100100
# Keys that `_populate_res_params` looks up in `fused_params`.
# Comma-separated table names; the value is split on "," and compared against
# the names of the config's embedding tables.
RES_ENABLED_TABLES_STR = "res_enabled_tables"
# Value is copied into `RESParams.res_store_shards`.
RES_STORE_SHARDS_STR = "res_store_shards"
# Flag; when it is missing, None or False, `_populate_res_params` reports raw
# embedding streaming as disabled.
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
104+
105+
def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
    """
    Build the ``RESParams`` used for raw embedding streaming.

    Only the params available in ``fused_params`` and the TBE config are
    populated here. As a side effect, the ``res_store_shards`` and
    ``res_enabled_tables`` keys are removed from ``config.fused_params`` (they
    are not constructor arguments); ``enable_raw_embedding_streaming`` is only
    read and stays in the dict.

    Args:
        config: grouped embedding config whose ``fused_params`` and
            ``embedding_tables`` are read.

    Returns:
        A ``(enabled, res_params)`` tuple. ``enabled`` is ``False`` when the
        ``enable_raw_embedding_streaming`` fused param is missing, ``None`` or
        ``False``, or when ``res_enabled_tables`` lists at least one table name
        and none of the listed names belongs to this config. In those cases
        ``res_params`` is returned without ``table_offsets`` populated.
        Otherwise the value of the ``enable_raw_embedding_streaming`` fused
        param is returned unchanged together with the populated params.
    """
    res_params: RESParams = RESParams()
    fused_params = config.fused_params or {}

    # read and clean up the fused_params that are not in the constructor
    if RES_STORE_SHARDS_STR in fused_params:
        res_params.res_store_shards = fused_params.pop(RES_STORE_SHARDS_STR)

    res_enabled_tables: Optional[List[str]] = None
    raw_enabled_tables = fused_params.pop(RES_ENABLED_TABLES_STR, None)
    if raw_enabled_tables is not None:
        # "".split(",") is [""], never []; without dropping blank entries an
        # empty setting would look like a filter that matches no table and
        # would disable streaming. Stripping also lets "a, b" match table "b".
        res_enabled_tables = [
            name.strip() for name in raw_enabled_tables.split(",") if name.strip()
        ]

    enable_raw_embedding_streaming: Optional[bool] = fused_params.get(
        ENABLE_RAW_EMBEDDING_STREAMING_STR
    )
    if (
        enable_raw_embedding_streaming is None
        or enable_raw_embedding_streaming is False
    ):
        return (False, res_params)

    res_params.table_names = [table.name for table in config.embedding_tables]
    # An empty or absent filter means no restriction; a non-empty filter must
    # name at least one table of this config for streaming to stay enabled.
    if res_enabled_tables:
        if not set(res_enabled_tables) & set(res_params.table_names):
            logger.info(
                f"No table is enabled for raw embedding streaming, "
                f"raw embedding streaming is disabled, { res_enabled_tables = } { res_params .table_names = } "
            )
            return (False, res_params)

    res_params.table_offsets = []
    for emb_tbl in config.embedding_tables:
        local_metadata = emb_tbl.local_metadata
        # Tables without shard offsets contribute no entry, so table_offsets
        # can end up shorter than table_names.
        if (
            local_metadata is not None
            and local_metadata.shard_offsets is not None
            and len(local_metadata.shard_offsets) >= 1
        ):
            res_params.table_offsets.append(local_metadata.shard_offsets[0])
    return (enable_raw_embedding_streaming, res_params)
152+
101153
102154def _populate_ssd_tbe_params (config : GroupedEmbeddingConfig ) -> Dict [str , Any ]:
103155 """
@@ -186,22 +238,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
186238 ssd_tbe_params ["cache_sets" ] = int (max_cache_sets )
187239 ssd_tbe_params ["table_names" ] = [table .name for table in config .embedding_tables ]
188240
189- # populate res_params, which is used for raw embedding streaming
190- # here only populates the params available in fused_params and TBE configs
191- res_params : RESParams = RESParams ()
192- res_params .table_names = [table .name for table in config .embedding_tables ]
193- res_params .table_offsets = []
194- for emb_tbl in config .embedding_tables :
195- local_metadata = emb_tbl .local_metadata
196- if (
197- local_metadata is not None
198- and local_metadata .shard_offsets is not None
199- and len (local_metadata .shard_offsets ) >= 1
200- ):
201- res_params .table_offsets .append (local_metadata .shard_offsets [0 ])
202- if "res_store_shards" in fused_params :
203- res_params .res_store_shards = fused_params .get ("res_store_shards" )
241+ enable_res , res_params = _populate_res_params (config )
204242 ssd_tbe_params ["res_params" ] = res_params
243+ ssd_tbe_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
205244
206245 return ssd_tbe_params
207246
@@ -2190,6 +2229,9 @@ def __init__(
21902229 if "cache_precision" not in fused_params :
21912230 fused_params ["cache_precision" ] = weights_precision
21922231
2232+ enable_res , res_params = _populate_res_params (config )
2233+ fused_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
2234+
21932235 self ._emb_module : SplitTableBatchedEmbeddingBagsCodegen = (
21942236 SplitTableBatchedEmbeddingBagsCodegen (
21952237 embedding_specs = list (
@@ -2208,6 +2250,7 @@ def __init__(
22082250 self ._col_offset ,
22092251 )
22102252 ),
2253+ res_params = res_params ,
22112254 ** fused_params ,
22122255 )
22132256 )
@@ -3041,6 +3084,10 @@ def __init__(
30413084 fused_params ["cache_precision" ] = weights_precision
30423085 if weights_precision == SparseType .NFP8 :
30433086 fused_params ["cache_precision" ] = SparseType .FP16
3087+
3088+ enable_res , res_params = _populate_res_params (config )
3089+ fused_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
3090+
30443091 self ._emb_module : SplitTableBatchedEmbeddingBagsCodegen = (
30453092 SplitTableBatchedEmbeddingBagsCodegen (
30463093 embedding_specs = list (
@@ -3059,6 +3106,7 @@ def __init__(
30593106 self ._col_offset ,
30603107 )
30613108 ),
3109+ res_params = res_params ,
30623110 ** fused_params ,
30633111 )
30643112 )
0 commit comments