
concat VBE embeddings in lookup module #2215

Closed
wants to merge 1 commit

Conversation

joshuadeng
Contributor

Summary:
Previously, to handle multiple VBE TBE outputs, each of which is a 1D tensor ordered by rank, we grouped sharding info such that only one TBE would be created per sharding module. This avoided the issue of concatenating multiple 1D tensors that are ordered by rank (not a problem in non-VBE, whose 2D output can be concatenated on dim 1).
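
To illustrate the mismatch, here is a minimal, self-contained sketch (not TorchRec code; shapes and per-rank batch sizes are assumed for illustration) of why 2D non-VBE outputs concatenate trivially on dim 1 while 1D rank-ordered VBE outputs do not:

```python
import torch

# Non-VBE: each TBE returns a 2D tensor [batch, sum_of_dims]; multiple TBE
# outputs can simply be joined along the embedding dimension.
tbe_out_a = torch.randn(6, 4)
tbe_out_b = torch.randn(6, 8)
non_vbe = torch.cat([tbe_out_a, tbe_out_b], dim=1)  # [6, 12]

# VBE: each TBE returns a flat 1D tensor laid out rank by rank, e.g.
# [rank0 segment | rank1 segment | ...]. A naive cat of two such tensors
# places output B's rank0 segment after output A's rank1 segment, so the
# result is no longer ordered by rank.
D = 4  # assumed embedding dim for illustration
vbe_out_a = torch.randn(3 * D + 2 * D)  # rank0: 3 rows, rank1: 2 rows
vbe_out_b = torch.randn(1 * D + 4 * D)  # rank0: 1 row,  rank1: 4 rows
naive = torch.cat([vbe_out_a, vbe_out_b])  # wrong ordering for downstream output dist
```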

This grouping only applies to specific UVM caching setups that use the prefetch pipeline, where each sharding type can require multiple TBEs to handle both the HBM and UVM caching setups. In most cases the TBEs could be fused for each sharding type, so we grouped accordingly.

Each sharding module handles its own input dist, lookup, and output dist, so creating a sharding module per TBE in EMO setups would cause a regression: the additional input dists and output dists increase comms.

This diff implements concatenation of the VBE TBE outputs in the lookup module, which removes the need for the specialized sharding grouping logic in these EMO cases.
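
For reference, a hedged sketch of the kind of output concatenation described here, assuming the per-rank segment lengths of each TBE output are known (the helper name and length bookkeeping are illustrative, not the actual TorchRec implementation): each 1D output is split into per-rank segments, which are re-concatenated grouped by rank so the merged tensor stays rank-ordered.

```python
import torch
from typing import List

def concat_vbe_outputs(
    outputs: List[torch.Tensor],        # each a 1D tensor laid out rank by rank
    lengths_per_rank: List[List[int]],  # lengths_per_rank[i][r] = numel of outputs[i]'s rank-r segment
) -> torch.Tensor:
    world_size = len(lengths_per_rank[0])
    # Split every output into its per-rank segments.
    splits = [out.split(lens) for out, lens in zip(outputs, lengths_per_rank)]
    # Re-concatenate grouped by rank so the merged result stays ordered by rank.
    return torch.cat([s[r] for r in range(world_size) for s in splits])

# Example: two VBE outputs over 2 ranks, embedding dim 4 (assumed values).
D = 4
out_a = torch.arange(5 * D, dtype=torch.float32)  # rank0: 3 rows, rank1: 2 rows
out_b = torch.arange(5 * D, dtype=torch.float32)  # rank0: 1 row,  rank1: 4 rows
merged = concat_vbe_outputs([out_a, out_b], [[3 * D, 2 * D], [1 * D, 4 * D]])
assert merged.numel() == out_a.numel() + out_b.numel()
```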

Differential Revision: D58894728

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D58894728


joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jul 10, 2024

Reviewed By: dstaay-fb, levythu