
Commit c5144fd

joshuadeng authored and facebook-github-bot committed
revert sharding grouping logic for vbe (#2216)
Summary: Pull Request resolved: #2216. This reverts the sharding grouping logic in the EBC/VLE modules that supported specific UVM caching + prefetch pipeline use cases by circumventing VBE TBE output concatenation. Since concatenation is implemented in the preceding diff, this diff cleans up the logic left behind from grouping shardings by UVM caching kernel conditions to avoid VBE TBE output concatenation.

Reviewed By: dstaay-fb, levythu

Differential Revision: D58989195
1 parent b99849f commit c5144fd
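
In effect, the revert drops the get_sharding_group indirection (which by this point simply returned param_sharding.sharding_type) and has the renamed create_sharding_infos_by_sharding bucket sharding infos directly by sharding type. A minimal sketch of that grouping pattern, using simplified stand-in types rather than the real TorchRec classes:

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ParamSharding:          # stand-in for torchrec's ParameterSharding
    sharding_type: str        # e.g. "table_wise", "row_wise"


@dataclass
class ShardingInfo:           # stand-in for EmbeddingShardingInfo
    table_name: str
    param_sharding: ParamSharding


def group_by_sharding_type(
    infos: List[ShardingInfo],
) -> Dict[str, List[ShardingInfo]]:
    # After the revert, infos are keyed purely by sharding type; the previous
    # UVM-caching/prefetch-aware group key is gone.
    grouped: Dict[str, List[ShardingInfo]] = defaultdict(list)
    for info in infos:
        grouped[info.param_sharding.sharding_type].append(info)
    return grouped


infos = [
    ShardingInfo("t0", ParamSharding("table_wise")),
    ShardingInfo("t1", ParamSharding("row_wise")),
    ShardingInfo("t2", ParamSharding("table_wise")),
]
print({k: [i.table_name for i in v] for k, v in group_by_sharding_type(infos).items()})
# {'table_wise': ['t0', 't2'], 'row_wise': ['t1']}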

File tree

2 files changed: +16, -24 lines


torchrec/distributed/embeddingbag.py (+10, -18)
@@ -13,7 +13,6 @@
 from functools import partial
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     Iterator,
@@ -39,7 +38,6 @@
     EmbeddingShardingInfo,
     KJTListSplitsAwaitable,
     Multistreamable,
-    USE_ONE_TBE_PER_TABLE,
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -77,7 +75,6 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
-    BaseEmbeddingConfig,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
@@ -200,15 +197,7 @@ def create_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")


-def get_sharding_group(
-    config: BaseEmbeddingConfig,
-    param_sharding: ParameterSharding,
-    fused_params: Optional[Dict[str, Any]] = None,
-) -> str:
-    return param_sharding.sharding_type
-
-
-def create_sharding_infos_by_group(
+def create_sharding_infos_by_sharding(
     module: EmbeddingBagCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     prefix: str,
@@ -229,7 +218,9 @@ def create_sharding_infos_by_group(
             else:
                 shared_feature[feature_name] = True

-    group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+        defaultdict(list)
+    )

    # state_dict returns parameter.Tensor, which loses parameter level attributes
    parameter_by_name = dict(module.named_parameters())
@@ -283,7 +274,6 @@ def create_sharding_infos_by_group(
        )
        per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)

-        group = get_sharding_group(config, parameter_sharding, fused_params)
        sharding_info = EmbeddingShardingInfo(
            embedding_config=EmbeddingTableConfig(
                num_embeddings=config.num_embeddings,
@@ -303,8 +293,10 @@ def create_sharding_infos_by_group(
            param=param,
            fused_params=per_table_fused_params,
        )
-        group_to_sharding_infos[group].append(sharding_info)
-    return group_to_sharding_infos
+        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+            sharding_info
+        )
+    return sharding_type_to_sharding_infos


 def create_sharding_infos_by_sharding_device_group(
@@ -581,7 +573,7 @@ def __init__(
        )
        self._env = env

-        group_to_sharding_infos = create_sharding_infos_by_group(
+        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
            module,
            table_name_to_parameter_sharding,
            "embedding_bags.",
@@ -602,7 +594,7 @@ def __init__(
                permute_embeddings=True,
                qcomm_codecs_registry=self.qcomm_codecs_registry,
            )
-            for embedding_configs in group_to_sharding_infos.values()
+            for embedding_configs in sharding_type_to_sharding_infos.values()
        ]

        self._is_weighted: bool = module.is_weighted()
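
As the final two hunks above show, ShardedEmbeddingBagCollection.__init__ then consumes the returned mapping by building one sharding object per sharding type. A rough, self-contained sketch of that consumption pattern; create_sharding and the placeholder types below are hypothetical stand-ins, not the real create_embedding_bag_sharding API:

from typing import Dict, List

ShardingInfo = str       # placeholder for EmbeddingShardingInfo
Sharding = List[str]     # placeholder for the per-sharding-type object


def create_sharding(embedding_configs: List[ShardingInfo]) -> Sharding:
    # stand-in for create_embedding_bag_sharding(...)
    return embedding_configs


sharding_type_to_sharding_infos: Dict[str, List[ShardingInfo]] = {
    "table_wise": ["t0", "t2"],
    "row_wise": ["t1"],
}

# Mirrors the updated __init__: iterate over values(), one sharding per type.
shardings = [
    create_sharding(embedding_configs)
    for embedding_configs in sharding_type_to_sharding_infos.values()
]
print(shardings)  # [['t0', 't2'], ['t1']]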

torchrec/distributed/planner/tests/test_embeddingbag_utils.py (+6, -6)
@@ -11,7 +11,7 @@
 import unittest

 from torchrec.distributed.embeddingbag import (
-    create_sharding_infos_by_group,
+    create_sharding_infos_by_sharding,
     EmbeddingBagCollectionSharder,
 )
 from torchrec.distributed.planner import (
@@ -79,7 +79,7 @@ def setUp(self) -> None:
        )
        self.expected_plan = planner.plan(self.model, [self.sharder])  # pyre-ignore[6]

-        self.expected_sharding_infos = create_sharding_infos_by_group(
+        self.expected_sharding_infos = create_sharding_infos_by_sharding(
            self.model,
            self.expected_plan.get_plan_for_module(""),  # pyre-ignore[6]
            prefix="embedding_bags.",
@@ -93,7 +93,7 @@ def test_create_sharding_infos_by_group_override(self) -> None:

        # with sharder fused params that will get overridden
        sharder_fused_params = {"enforce_hbm": False}
-        overriden_sharding_infos = create_sharding_infos_by_group(
+        overriden_sharding_infos = create_sharding_infos_by_sharding(
            self.model,
            self.expected_plan.get_plan_for_module(""),
            prefix="embedding_bags.",
@@ -106,7 +106,7 @@ def test_create_sharding_infos_by_group_override(self) -> None:

        # with sharder fused params that won't get overridden
        sharder_fused_params = {"ABC": True}
-        not_overriden_sharding_infos = create_sharding_infos_by_group(
+        not_overriden_sharding_infos = create_sharding_infos_by_sharding(
            self.model,
            self.expected_plan.get_plan_for_module(""),
            prefix="embedding_bags.",
@@ -141,7 +141,7 @@ def test_create_sharding_infos_by_group_combine(self) -> None:
        # provide that two fused params from sharder
        sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False}

-        combined_sharding_infos = create_sharding_infos_by_group(
+        combined_sharding_infos = create_sharding_infos_by_sharding(
            self.model,
            new_plan.get_plan_for_module(""),  # pyre-ignore[6]
            prefix="embedding_bags.",
@@ -156,7 +156,7 @@ def test_create_sharding_infos_by_group_combine(self) -> None:

        # provide that two fused params from sharder wrongly
        sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True}
-        wrong_combined_sharding_infos = create_sharding_infos_by_group(
+        wrong_combined_sharding_infos = create_sharding_infos_by_sharding(
            self.model,
            new_plan.get_plan_for_module(""),  # pyre-ignore[6]
            prefix="embedding_bags.",
