
Commit 9e02ba2

TroyGarden authored and facebook-github-bot committed
enable unit test gpu (#2228)
Summary:

# context
* use FakeProcessGroup to mimic the multi-process tests
* can use `_test_compile_fake_pg_fn` as the single-process VB compile test

```
from torchrec.distributed.tests.test_pt2_multiprocess import _test_compile_fake_pg_fn

_test_compile_fake_pg_fn(
    rank=0,
    world_size=2,
)
```

reference: D59637444

NOTE: right now this is only tested for EBC; not sure about other sparse modules such as PEA or VLE, though adding similar changes for them shouldn't be too hard.

Differential Revision: D51095381
1 parent 5f8a495 commit 9e02ba2
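For context on the approach: the new test creates a single-process process group with the `fake` backend, so collectives can be traced without multiple ranks or real communication. Below is a minimal sketch of that setup, mirroring the `FakeStore` / `init_process_group(backend="fake", ...)` calls added in `test_pt2_multiprocess.py`; the assert is illustrative only, and `FakeStore` / the `fake` backend are PyTorch-internal test utilities that may vary across versions.

```python
# Minimal sketch (not part of this commit): set up a single-process "fake"
# process group the way the new test does. Importing fake_pg registers the
# "fake" backend; no peer processes or real communication are involved.
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
# Rank 0 pretends to be part of a world of size 2.
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
pg = dist.distributed_c10d._get_default_group()
assert pg.size() == 2 and dist.get_rank(pg) == 0
```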

4 files changed: +258 -21 lines

torchrec/distributed/comm_ops.py

+19 -8

@@ -559,14 +559,25 @@ def variable_batch_all2all_pooled_sync(
         ]
 
     with record_function("## alltoall_fwd_single ##"):
-        sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
-            sharded_input_embeddings,
-            output_split_sizes,
-            input_split_sizes,
-            pg_name(pg),
-            pg.size(),
-            get_gradient_division(),
-        )
+        if pg._get_backend_name() == "fake":
+            sharded_output_embeddings = torch.empty(
+                sum(output_split_sizes),
+                device=sharded_input_embeddings.device,
+                dtype=sharded_input_embeddings.dtype,
+            )
+            s0 = sharded_output_embeddings.size(0)
+            # Bad assumption that our rank GE than other
+            torch._check(s0 <= sharded_input_embeddings.size(0))
+            sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
+        else:
+            sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+                sharded_input_embeddings,
+                output_split_sizes,
+                input_split_sizes,
+                pg_name(pg),
+                pg.size(),
+                get_gradient_division(),
+            )
 
     if a2ai.codecs is not None:
         codecs = none_throws(a2ai.codecs)

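The hunk above follows a simple pattern: when the process group reports the `fake` backend, the real collective is skipped and a correctly shaped output is filled from the rank's own input, so downstream shapes stay consistent for tracing. A standalone sketch of that pattern under the same split-size semantics as the diff; `fake_or_all_to_all` is a made-up helper name, and the sketch uses plain `torch.distributed.all_to_all_single` rather than the TorchRec custom op.

```python
# Illustrative sketch (not TorchRec code) of the fake-backend bypass pattern.
from typing import List

import torch
import torch.distributed as dist


def fake_or_all_to_all(
    pg: dist.ProcessGroup,
    inp: torch.Tensor,
    output_split_sizes: List[int],
    input_split_sizes: List[int],
) -> torch.Tensor:
    out = torch.empty(sum(output_split_sizes), device=inp.device, dtype=inp.dtype)
    if pg._get_backend_name() == "fake":
        # No peers exist: fill the output from this rank's own input so the
        # shape is right without any real communication.
        s0 = out.size(0)
        torch._check(s0 <= inp.size(0))  # same assumption as in the diff
        out.copy_(inp[:s0])
        return out
    # Real backends still run the true collective.
    dist.all_to_all_single(out, inp, output_split_sizes, input_split_sizes, group=pg)
    return out
```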
torchrec/distributed/dist_data.py

+42 -12

@@ -239,12 +239,25 @@ def __init__(
             # https://github.com/pytorch/pytorch/issues/122788
             with record_function("## all2all_data:kjt splits ##"):
                 input_tensor = torch.stack(input_tensors, dim=1).flatten()
-                self._output_tensor = dist._functional_collectives.all_to_all_single(
-                    input_tensor,
-                    output_split_sizes=None,
-                    input_split_sizes=None,
-                    group=pg,
-                )
+                if pg._get_backend_name() == "fake":
+                    self._output_tensor = torch.empty(
+                        [self.num_workers * len(input_tensors)],
+                        device=input_tensors[0].device,
+                        dtype=input_tensors[0].dtype,
+                    )
+
+                    self._output_tensor = input_tensor[
+                        : input_tensor.size(0) // 2
+                    ].repeat(2)
+                else:
+                    self._output_tensor = (
+                        dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split_sizes=None,
+                            input_split_sizes=None,
+                            group=pg,
+                        )
+                    )
             # To avoid hasattr in _wait_impl to check self._splits_awaitable
             # pyre-ignore
             self._splits_awaitable = None
@@ -342,6 +355,7 @@ def __init__(
         self._output_tensors: List[torch.Tensor] = []
         self._awaitables: List[dist.Work] = []
         self._world_size: int = self._pg.size()
+        rank = dist.get_rank(self._pg)
 
         for input_split, output_split, input_tensor, label in zip(
             input_splits,
@@ -353,12 +367,28 @@ def __init__(
                 # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
                 # https://github.com/pytorch/pytorch/issues/122788
                 with record_function(f"## all2all_data:kjt {label} ##"):
-                    output_tensor = dist._functional_collectives.all_to_all_single(
-                        input_tensor,
-                        output_split,
-                        input_split,
-                        pg,
-                    )
+                    if self._pg._get_backend_name() == "fake":
+                        output_tensor = torch.empty(
+                            sum(output_split),
+                            device=self._device,
+                            dtype=input_tensor.dtype,
+                        )
+                        _l = sum(output_split[:rank])
+                        _r = _l + output_split[rank]
+                        torch._check(_r < input_tensor.size(0))
+                        torch._check(_l < input_tensor.size(0))
+                        torch._check(_l <= _r)
+                        torch._check(2 * (_r - _l) == output_tensor.size(0))
+                        output_tensor.copy_(
+                            input_tensor[_l:_r].repeat(self._world_size)
+                        )
+                    else:
+                        output_tensor = dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split,
+                            input_split,
+                            pg,
+                        )
                     self._output_tensors.append(output_tensor)
             else:
                 output_tensor = torch.empty(

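To make the fake path in the last hunk concrete, here is a tiny worked example of the slice-and-repeat logic with made-up values (`world_size=2`, `rank=0`, `output_split=[3, 3]`): the rank's own slice of its send buffer stands in for what every peer would have sent.

```python
# Worked example of the fake-backend slicing above (illustrative only).
import torch

world_size, rank = 2, 0
output_split = [3, 3]
input_tensor = torch.arange(6)  # this rank's concatenated send buffer

_l = sum(output_split[:rank])   # 0: start of this rank's own slice
_r = _l + output_split[rank]    # 3: end of this rank's own slice
output_tensor = torch.empty(sum(output_split), dtype=input_tensor.dtype)
output_tensor.copy_(input_tensor[_l:_r].repeat(world_size))
print(output_tensor)            # tensor([0, 1, 2, 0, 1, 2])
# Every peer is simulated as sending exactly what this rank would send,
# so the output has the expected length without any real communication.
```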
torchrec/distributed/embeddingbag.py

+3 -1

@@ -657,7 +657,9 @@ def __init__(
                 broadcast_buffers=True,
                 static_graph=True,
             )
-        self._initialize_torch_state()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state()
 
         # TODO[zainhuda]: support module device coming from CPU
         if module.device not in ["meta", "cpu"] and module.device.type not in [

torchrec/distributed/tests/test_pt2_multiprocess.py

+194

@@ -25,6 +25,8 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import distributed as dist
 from torch._dynamo.testing import reduce_to_scalar_loss
+from torch.distributed import ProcessGroup
+from torch.testing._internal.distributed.fake_pg import FakeStore
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -499,6 +501,184 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
     ##### NUMERIC CHECK END #####
 
 
+def _test_compile_fake_pg_fn(
+    rank: int,
+    world_size: int,
+) -> None:
+    sharding_type = ShardingType.TABLE_WISE.value
+    input_type = _InputType.SINGLE_BATCH
+    torch_compile_backend = "eager"
+    config = _TestConfig()
+    num_embeddings = 256
+    # emb_dim must be % 4 == 0 for fbgemm
+    emb_dim = 12
+    batch_size = 10
+    num_features: int = 5
+
+    num_float_features: int = 8
+    num_weighted_features: int = 1
+
+    device: torch.Device = torch.device("cuda")
+    store = FakeStore()
+    dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+    pg: ProcessGroup = dist.distributed_c10d._get_default_group()
+
+    topology: Topology = Topology(world_size=world_size, compute_device="cuda")
+    mi = TestModelInfo(
+        dense_device=device,
+        sparse_device=device,
+        num_features=num_features,
+        num_float_features=num_float_features,
+        num_weighted_features=num_weighted_features,
+        topology=topology,
+    )
+
+    mi.planner = EmbeddingShardingPlanner(
+        topology=topology,
+        batch_size=batch_size,
+        enumerator=EmbeddingEnumerator(
+            topology=topology,
+            batch_size=batch_size,
+            estimator=[
+                EmbeddingPerfEstimator(topology=topology),
+                EmbeddingStorageEstimator(topology=topology),
+            ],
+        ),
+    )
+
+    mi.tables = [
+        EmbeddingBagConfig(
+            num_embeddings=num_embeddings,
+            embedding_dim=emb_dim,
+            name="table_" + str(i),
+            feature_names=["feature_" + str(i)],
+        )
+        for i in range(mi.num_features)
+    ]
+
+    mi.weighted_tables = [
+        EmbeddingBagConfig(
+            num_embeddings=num_embeddings,
+            embedding_dim=emb_dim,
+            name="weighted_table_" + str(i),
+            feature_names=["weighted_feature_" + str(i)],
+        )
+        for i in range(mi.num_weighted_features)
+    ]
+
+    mi.model = _gen_model(_ModelType.EBC, mi)
+    mi.model.training = True
+
+    model = mi.model
+
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(world_size, device.type),
+        constraints=None,
+    )
+
+    sharders = [
+        EBCSharderFixedShardingType(sharding_type),
+        ECSharderFixedShardingType(sharding_type),
+    ]
+
+    plan: ShardingPlan = planner.plan(model, sharders)  # pyre-ignore
+
+    def _dmp(m: torch.nn.Module) -> DistributedModelParallel:  # pyre-ignore
+        return DistributedModelParallel(
+            m,
+            env=ShardingEnv(world_size, rank, pg),
+            plan=plan,
+            sharders=sharders,
+            device=device,
+            init_data_parallel=False,
+        )
+
+    dmp = _dmp(model)
+    dmp_compile = _dmp(model)
+
+    # TODO: Fix some data dependent failures on subsequent inputs
+    n_extra_numerics_checks = config.n_extra_numerics_checks_inputs
+    ins = []
+
+    for _ in range(1 + n_extra_numerics_checks):
+        if input_type == _InputType.VARIABLE_BATCH:
+            (
+                _,
+                local_model_inputs,
+            ) = ModelInput.generate_variable_batch_input(
+                average_batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=num_float_features,
+                # pyre-ignore
+                tables=mi.tables,
+            )
+        else:
+            (
+                _,
+                local_model_inputs,
+            ) = ModelInput.generate(
+                batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=num_float_features,
+                tables=mi.tables,
+                weighted_tables=mi.weighted_tables,
+                variable_batch_size=False,
+            )
+        ins.append(local_model_inputs)
+
+    local_model_input = ins[0][rank].to(device)
+
+    kjt = local_model_input.idlist_features
+    ff = local_model_input.float_features
+    ff.requires_grad = True
+    kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True)
+
+    compile_input_ff = ff.clone().detach()
+    compile_input_ff.requires_grad = True
+
+    torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
+
+    dmp.train(True)
+    dmp_compile.train(True)
+
+    def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
+        tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module
+        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+        return tbe.weights_dev.clone().detach()
+
+    original_weights = get_weights(dmp)
+    original_weights.zero_()
+    original_compile_weights = get_weights(dmp_compile)
+    original_compile_weights.zero_()
+
+    eager_out = dmp(kjt_ft, ff)
+    reduce_to_scalar_loss(eager_out).backward()
+
+    if torch_compile_backend is None:
+        return
+
+    ##### COMPILE #####
+    with unittest.mock.patch(
+        "torch._dynamo.config.skip_torchrec",
+        False,
+    ):
+        torch._dynamo.config.capture_scalar_outputs = True
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True
+
+        opt_fn = torch.compile(
+            dmp_compile,
+            backend=torch_compile_backend,
+            fullgraph=True,
+        )
+        compile_out = opt_fn(
+            kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff
+        )
+        torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+    ##### COMPILE END #####
+
+
 class TestPt2Train(MultiProcessTestBase):
     def disable_cuda_tf32(self) -> bool:
         return True
@@ -580,3 +760,17 @@ def test_compile_multiprocess(
             config=config,
             torch_compile_backend=compile_backend,
         )
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() < 1,
+        "Not enough GPUs, this test requires one GPU",
+    )
+    @settings(verbosity=Verbosity.verbose, deadline=None)
+    def test_compile_multiprocess_fake_pg(
+        self,
+    ) -> None:
+        _test_compile_fake_pg_fn(
+            rank=0,
+            world_size=2,
+        )
