@@ -13,7 +13,6 @@
 from functools import partial
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     Iterator,
@@ -39,7 +38,6 @@
     EmbeddingShardingInfo,
     KJTListSplitsAwaitable,
     Multistreamable,
-    USE_ONE_TBE_PER_TABLE,
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -77,7 +75,6 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
-    BaseEmbeddingConfig,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
@@ -200,15 +197,7 @@ def create_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")


-def get_sharding_group(
-    config: BaseEmbeddingConfig,
-    param_sharding: ParameterSharding,
-    fused_params: Optional[Dict[str, Any]] = None,
-) -> str:
-    return param_sharding.sharding_type
-
-
-def create_sharding_infos_by_group(
+def create_sharding_infos_by_sharding(
     module: EmbeddingBagCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     prefix: str,
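The removed `get_sharding_group` helper was a thin indirection: as the deleted lines show, it ignored `config` and `fused_params` and returned `param_sharding.sharding_type` unchanged, so the caller can read the attribute directly. A minimal sketch of that equivalence, using a hypothetical stand-in for torchrec's `ParameterSharding`:

```python
from dataclasses import dataclass


@dataclass
class ParameterSharding:  # hypothetical stand-in, not the torchrec class
    sharding_type: str


def get_sharding_group(param_sharding: ParameterSharding) -> str:
    # The deleted helper reduced to this single attribute access.
    return param_sharding.sharding_type


ps = ParameterSharding(sharding_type="row_wise")
assert get_sharding_group(ps) == ps.sharding_type
```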
@@ -229,7 +218,9 @@ def create_sharding_infos_by_group(
             else:
                 shared_feature[feature_name] = True

-    group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+        defaultdict(list)
+    )

     # state_dict returns parameter.Tensor, which loses parameter level attributes
     parameter_by_name = dict(module.named_parameters())
@@ -283,7 +274,6 @@ def create_sharding_infos_by_group(
             )
             per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)

-            group = get_sharding_group(config, parameter_sharding, fused_params)
             sharding_info = EmbeddingShardingInfo(
                 embedding_config=EmbeddingTableConfig(
                     num_embeddings=config.num_embeddings,
@@ -303,8 +293,10 @@ def create_sharding_infos_by_group(
                 param=param,
                 fused_params=per_table_fused_params,
             )
-            group_to_sharding_infos[group].append(sharding_info)
-    return group_to_sharding_infos
+            sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+                sharding_info
+            )
+    return sharding_type_to_sharding_infos


 def create_sharding_infos_by_sharding_device_group(
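Taken together, the hunks above restore a plain bucket-by-sharding-type accumulation: every `EmbeddingShardingInfo` lands in the list keyed by its parameter's sharding type. A self-contained sketch of that pattern, with a hypothetical `ShardingInfo` stand-in for `EmbeddingShardingInfo`:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ShardingInfo:  # hypothetical stand-in for EmbeddingShardingInfo
    table_name: str
    sharding_type: str


def group_by_sharding_type(
    infos: List[ShardingInfo],
) -> Dict[str, List[ShardingInfo]]:
    # Mirrors the restored accumulation: one bucket per sharding type,
    # keyed by parameter_sharding.sharding_type in the real code.
    grouped: Dict[str, List[ShardingInfo]] = defaultdict(list)
    for info in infos:
        grouped[info.sharding_type].append(info)
    return grouped


infos = [
    ShardingInfo("t0", "table_wise"),
    ShardingInfo("t1", "row_wise"),
    ShardingInfo("t2", "table_wise"),
]
grouped = group_by_sharding_type(infos)
assert sorted(grouped) == ["row_wise", "table_wise"]
assert [i.table_name for i in grouped["table_wise"]] == ["t0", "t2"]
```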
@@ -581,7 +573,7 @@ def __init__(
         )
         self._env = env

-        group_to_sharding_infos = create_sharding_infos_by_group(
+        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
             module,
             table_name_to_parameter_sharding,
             "embedding_bags.",
@@ -602,7 +594,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in group_to_sharding_infos.values()
+            for embedding_configs in sharding_type_to_sharding_infos.values()
         ]

         self._is_weighted: bool = module.is_weighted()
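On the consumer side, the constructor now builds one sharding implementation per entry in the returned dict. A hedged sketch of that fan-out, with `create_sharding` as a hypothetical stand-in for `create_embedding_bag_sharding` and strings standing in for the grouped sharding infos:

```python
from typing import Dict, List


def create_sharding(infos: List[str]) -> str:
    # hypothetical stand-in for create_embedding_bag_sharding(...)
    return f"sharding over {len(infos)} table(s)"


sharding_type_to_sharding_infos: Dict[str, List[str]] = {
    "table_wise": ["t0", "t2"],
    "row_wise": ["t1"],
}

# Mirrors the list comprehension in the constructor hunk: one
# EmbeddingSharding per sharding type, in dict insertion order.
embedding_shardings = [
    create_sharding(embedding_configs)
    for embedding_configs in sharding_type_to_sharding_infos.values()
]
assert len(embedding_shardings) == 2
```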