|
25 | 25 | from hypothesis import given, settings, strategies as st, Verbosity
|
26 | 26 | from torch import distributed as dist
|
27 | 27 | from torch._dynamo.testing import reduce_to_scalar_loss
|
| 28 | +from torch.distributed import ProcessGroup
| 29 | +from torch.testing._internal.distributed.fake_pg import FakeStore
28 | 30 | from torchrec.distributed.embedding import EmbeddingCollectionSharder
|
29 | 31 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
|
30 | 32 | from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
|
@@ -499,6 +501,184 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
|
499 | 501 | ##### NUMERIC CHECK END #####
|
500 | 502 |
|
501 | 503 |
|
| 504 | +def _test_compile_fake_pg_fn(
| 505 | +    rank: int,
| 506 | +    world_size: int,
| 507 | +) -> None:
| 508 | +    sharding_type = ShardingType.TABLE_WISE.value
| 509 | +    input_type = _InputType.SINGLE_BATCH
| 510 | +    torch_compile_backend = "eager"
| 511 | +    config = _TestConfig()
| 512 | +    num_embeddings = 256
| 513 | +    # emb_dim must be % 4 == 0 for fbgemm
| 514 | +    emb_dim = 12
| 515 | +    batch_size = 10
| 516 | +    num_features: int = 5
| 517 | +
| 518 | +    num_float_features: int = 8
| 519 | +    num_weighted_features: int = 1
| 520 | +
| 521 | +    device: torch.device = torch.device("cuda")
| 522 | +    store = FakeStore()
| 523 | +    dist.init_process_group(backend="fake", rank=rank, world_size=world_size, store=store)
| 524 | +    pg: ProcessGroup = dist.distributed_c10d._get_default_group()
| 525 | +
| 526 | +    topology: Topology = Topology(world_size=world_size, compute_device="cuda")
| 527 | +    mi = TestModelInfo(
| 528 | +        dense_device=device,
| 529 | +        sparse_device=device,
| 530 | +        num_features=num_features,
| 531 | +        num_float_features=num_float_features,
| 532 | +        num_weighted_features=num_weighted_features,
| 533 | +        topology=topology,
| 534 | +    )
| 535 | +
| 536 | +    mi.planner = EmbeddingShardingPlanner(
| 537 | +        topology=topology,
| 538 | +        batch_size=batch_size,
| 539 | +        enumerator=EmbeddingEnumerator(
| 540 | +            topology=topology,
| 541 | +            batch_size=batch_size,
| 542 | +            estimator=[
| 543 | +                EmbeddingPerfEstimator(topology=topology),
| 544 | +                EmbeddingStorageEstimator(topology=topology),
| 545 | +            ],
| 546 | +        ),
| 547 | +    )
| 548 | +
| 549 | +    mi.tables = [
| 550 | +        EmbeddingBagConfig(
| 551 | +            num_embeddings=num_embeddings,
| 552 | +            embedding_dim=emb_dim,
| 553 | +            name="table_" + str(i),
| 554 | +            feature_names=["feature_" + str(i)],
| 555 | +        )
| 556 | +        for i in range(mi.num_features)
| 557 | +    ]
| 558 | +
| 559 | +    mi.weighted_tables = [
| 560 | +        EmbeddingBagConfig(
| 561 | +            num_embeddings=num_embeddings,
| 562 | +            embedding_dim=emb_dim,
| 563 | +            name="weighted_table_" + str(i),
| 564 | +            feature_names=["weighted_feature_" + str(i)],
| 565 | +        )
| 566 | +        for i in range(mi.num_weighted_features)
| 567 | +    ]
| 568 | +
| 569 | +    mi.model = _gen_model(_ModelType.EBC, mi)
| 570 | +    mi.model.training = True
| 571 | +
| 572 | +    model = mi.model
| 573 | +
| 574 | +    planner = EmbeddingShardingPlanner(
| 575 | +        topology=Topology(world_size, device.type),
| 576 | +        constraints=None,
| 577 | +    )
| 578 | +
| 579 | +    sharders = [
| 580 | +        EBCSharderFixedShardingType(sharding_type),
| 581 | +        ECSharderFixedShardingType(sharding_type),
| 582 | +    ]
| 583 | +
| 584 | +    plan: ShardingPlan = planner.plan(model, sharders)  # pyre-ignore
| 585 | +
| 586 | +    def _dmp(m: torch.nn.Module) -> DistributedModelParallel:  # pyre-ignore
| 587 | +        return DistributedModelParallel(
| 588 | +            m,
| 589 | +            env=ShardingEnv(world_size, rank, pg),
| 590 | +            plan=plan,
| 591 | +            sharders=sharders,
| 592 | +            device=device,
| 593 | +            init_data_parallel=False,
| 594 | +        )
| 595 | +
| 596 | +    dmp = _dmp(model)
| 597 | +    dmp_compile = _dmp(model)
| 598 | +
| 599 | +    # TODO: Fix some data dependent failures on subsequent inputs
| 600 | +    n_extra_numerics_checks = config.n_extra_numerics_checks_inputs
| 601 | +    ins = []
| 602 | +
| 603 | +    for _ in range(1 + n_extra_numerics_checks):
| 604 | +        if input_type == _InputType.VARIABLE_BATCH:
| 605 | +            (
| 606 | +                _,
| 607 | +                local_model_inputs,
| 608 | +            ) = ModelInput.generate_variable_batch_input(
| 609 | +                average_batch_size=batch_size,
| 610 | +                world_size=world_size,
| 611 | +                num_float_features=num_float_features,
| 612 | +                # pyre-ignore
| 613 | +                tables=mi.tables,
| 614 | +            )
| 615 | +        else:
| 616 | +            (
| 617 | +                _,
| 618 | +                local_model_inputs,
| 619 | +            ) = ModelInput.generate(
| 620 | +                batch_size=batch_size,
| 621 | +                world_size=world_size,
| 622 | +                num_float_features=num_float_features,
| 623 | +                tables=mi.tables,
| 624 | +                weighted_tables=mi.weighted_tables,
| 625 | +                variable_batch_size=False,
| 626 | +            )
| 627 | +        ins.append(local_model_inputs)
| 628 | +
| 629 | +    local_model_input = ins[0][rank].to(device)
| 630 | +
| 631 | +    kjt = local_model_input.idlist_features
| 632 | +    ff = local_model_input.float_features
| 633 | +    ff.requires_grad = True
| 634 | +    kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=True)
| 635 | +
| 636 | +    compile_input_ff = ff.clone().detach()
| 637 | +    compile_input_ff.requires_grad = True
| 638 | +
| 639 | +    torchrec.distributed.comm_ops.set_use_sync_collectives(True)
| 640 | +    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
| 641 | +
| 642 | +    dmp.train(True)
| 643 | +    dmp_compile.train(True)
| 644 | +
| 645 | +    def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
| 646 | +        tbe = dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module
| 647 | +        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
| 648 | +        return tbe.weights_dev.clone().detach()
| 649 | +
| 650 | +    original_weights = get_weights(dmp)
| 651 | +    original_weights.zero_()
| 652 | +    original_compile_weights = get_weights(dmp_compile)
| 653 | +    original_compile_weights.zero_()
| 654 | +
| 655 | +    eager_out = dmp(kjt_ft, ff)
| 656 | +    reduce_to_scalar_loss(eager_out).backward()
| 657 | +
| 658 | +    if torch_compile_backend is None:
| 659 | +        return
| 660 | +
| 661 | +    ##### COMPILE #####
| 662 | +    with unittest.mock.patch(
| 663 | +        "torch._dynamo.config.skip_torchrec",
| 664 | +        False,
| 665 | +    ):
| 666 | +        torch._dynamo.config.capture_scalar_outputs = True
| 667 | +        torch._dynamo.config.capture_dynamic_output_shape_ops = True
| 668 | +        torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = True
| 669 | +
| 670 | +        opt_fn = torch.compile(
| 671 | +            dmp_compile,
| 672 | +            backend=torch_compile_backend,
| 673 | +            fullgraph=True,
| 674 | +        )
| 675 | +        compile_out = opt_fn(
| 676 | +            kjt_for_pt2_tracing(kjt, convert_to_vb=True), compile_input_ff
| 677 | +        )
| 678 | +        torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
| 679 | +    ##### COMPILE END #####
| 680 | +
| 681 | +
502 | 682 | class TestPt2Train(MultiProcessTestBase):
|
503 | 683 |     def disable_cuda_tf32(self) -> bool:

504 | 684 |         return True
|
@@ -580,3 +760,17 @@ def test_compile_multiprocess(
|
580 | 760 |             config=config,

581 | 761 |             torch_compile_backend=compile_backend,

582 | 762 |         )
|
| 763 | +
| 764 | +    # pyre-ignore
| 765 | +    @unittest.skipIf(
| 766 | +        torch.cuda.device_count() < 1,
| 767 | +        "Not enough GPUs, this test requires one GPU",
| 768 | +    )
| 769 | +    @settings(verbosity=Verbosity.verbose, deadline=None)
| 770 | +    def test_compile_multiprocess_fake_pg(
| 771 | +        self,
| 772 | +    ) -> None:
| 773 | +        _test_compile_fake_pg_fn(
| 774 | +            rank=0,
| 775 | +            world_size=2,
| 776 | +        )
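
Note on the fake process group used in this change (a minimal sketch under stated assumptions, not part of the commit): the new test sets up PyTorch's "fake" distributed backend via FakeStore, which satisfies a world_size=2 topology without spawning extra processes or performing real communication, so the table-wise sharding and torch.compile paths can be exercised on a single GPU in a single process. The helper name init_fake_pg below is hypothetical, and FakeStore comes from a torch-internal testing module whose API may change between releases.

    import torch.distributed as dist
    from torch.distributed import ProcessGroup
    from torch.testing._internal.distributed.fake_pg import FakeStore  # internal testing API

    def init_fake_pg(rank: int = 0, world_size: int = 2) -> ProcessGroup:
        # Importing fake_pg is expected to register the "fake" backend (as the test
        # above relies on); the store only has to exist, since no real rendezvous
        # or collectives take place.
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=rank, world_size=world_size, store=store
        )
        return dist.distributed_c10d._get_default_group()

    # Usage mirroring _test_compile_fake_pg_fn: create the group once, pass it to
    # ShardingEnv(world_size, rank, pg), and call dist.destroy_process_group()
    # when the test is done.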