
Commit abbe5c1

dstaay-fb authored and facebook-github-bot committed
RegroupAsDict module (#2007)
Summary: Pull Request resolved: #2007

Currently, we have `KT.regroup` as a functional call. The issue with this is two-fold: (1) we don't cache values that are effectively known after the first batch, leading to marginally higher CPU computation; (2) these values look like unbacked SymInts in the PT2 IR and most graph captures, when in reality they are known. The speedup for inference (forward only) is dramatic: 2-4x faster in some cases. Since this is a user-facing change, we add a new module that leverages the insights above.

Benchmark (fwd only) — device: cuda; Runtime and Memory are P90; runtimes in ms.

| B | F | `_regroup_keyed_tenors` (fallback) | `KeyedTensor.regroup` (prod) | `KTRegroupAsDict` (prod) | Mem (fallback) | Mem (prod) |
|------|------|------|------|------|--------|--------|
| 512  | 320  | 2.1  | 0.8  | 0.2  | 96.0   | 144.0  |
| 512  | 640  | 5.6  | 1.2  | 0.3  | 192.0  | 288.0  |
| 512  | 1280 | 15.4 | 1.7  | 0.6  | 384.0  | 576.0  |
| 1024 | 80   | 0.5  | 0.5  | 0.1  | 48.0   | 72.0   |
| 1024 | 160  | 1.2  | 0.8  | 0.2  | 96.0   | 144.0  |
| 1024 | 320  | 2.6  | 1.0  | 0.3  | 192.0  | 288.0  |
| 1024 | 640  | 6.3  | 1.1  | 0.5  | 384.0  | 576.0  |
| 1024 | 1280 | 15.4 | 1.7  | 1.1  | 768.0  | 1152.0 |
| 2048 | 80   | 0.5  | 0.5  | 0.1  | 96.0   | 144.0  |
| 2048 | 160  | 0.9  | 0.6  | 0.3  | 192.0  | 288.0  |
| 2048 | 320  | 2.0  | 0.8  | 0.5  | 384.0  | 576.0  |
| 2048 | 640  | 5.2  | 1.1  | 1.1  | 768.0  | 1152.0 |
| 2048 | 1280 | 15.8 | 2.2  | 2.2  | 1536.0 | 2304.0 |
| 4096 | 80   | 0.5  | 0.5  | 0.3  | 192.0  | 288.0  |
| 4096 | 160  | 0.9  | 0.6  | 0.5  | 384.0  | 576.0  |
| 4096 | 320  | 2.0  | 1.1  | 1.1  | 768.0  | 1152.0 |
| 4096 | 640  | 5.2  | 2.1  | 2.1  | 1536.0 | 2304.0 |
| 4096 | 1280 | 15.8 | 4.2  | 4.3  | 3072.0 | 4608.0 |

Benchmark (fwd+backward) — device: cuda; Runtime and Memory are P90; runtimes in ms; memory is identical across the three methods.

| B | F | `_regroup_keyed_tenors` (fallback) | `KeyedTensor.regroup` (prod) | `KTRegroupAsDict` (prod) | Mem |
|------|------|------|------|------|--------|
| 512  | 80   | 3.3  | 2.8  | 2.3  | 72.0   |
| 512  | 160  | 7.7  | 4.6  | 3.9  | 144.0  |
| 512  | 320  | 10.8 | 7.5  | 9.9  | 288.0  |
| 512  | 640  | 22.7 | 13.8 | 18.6 | 576.0  |
| 512  | 1280 | 58.0 | 27.9 | 25.7 | 1152.0 |
| 1024 | 80   | 3.3  | 3.3  | 3.3  | 144.0  |
| 1024 | 160  | 6.6  | 6.4  | 4.1  | 288.0  |
| 1024 | 320  | 15.0 | 8.0  | 8.0  | 576.0  |
| 1024 | 640  | 23.6 | 19.3 | 13.6 | 1152.0 |
| 1024 | 1280 | 55.7 | 28.4 | 26.8 | 2304.0 |
| 2048 | 80   | 3.6  | 3.5  | 3.6  | 288.0  |
| 2048 | 160  | 7.0  | 6.4  | 4.6  | 576.0  |
| 2048 | 320  | 11.2 | 8.2  | 8.8  | 1152.0 |
| 2048 | 640  | 23.9 | 20.6 | 14.6 | 2304.0 |
| 2048 | 1280 | 54.5 | 28.3 | 25.7 | 4608.0 |
| 4096 | 80   | 3.3  | 2.7  | 2.3  | 576.0  |
| 4096 | 160  | 5.8  | 4.4  | 3.9  | 1152.0 |
| 4096 | 320  | 11.1 | 7.8  | 7.0  | 2304.0 |
| 4096 | 640  | 23.9 | 14.5 | 13.3 | 4608.0 |
| 4096 | 1280 | 64.0 | 26.9 | 25.1 | 9216.0 |

Reviewed By: PaulZhang12

Differential Revision: D57312926

fbshipit-source-id: fab925b273f2c22bb1189b4a6158640009779a8f
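The caching idea behind the new module can be illustrated without TorchRec or tensors. A minimal sketch (the `CachedRegrouper` class below is hypothetical, standing in for `KTRegroupAsDict`): the flattened feature order and the per-group split widths depend only on the schema, so they are computed on the first batch and reused afterwards instead of being re-derived every call.

```python
from typing import Dict, List


class CachedRegrouper:
    """Toy regrouper: caches per-group split metadata after the first batch."""

    def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
        assert len(groups) == len(keys), "Groups and keys should have same length"
        self._groups = groups
        self._keys = keys
        self._order: List[str] = []   # flattened feature order, cached
        self._splits: List[int] = []  # width of each output group, cached
        self._inited = False

    def __call__(self, features: Dict[str, List[float]]) -> Dict[str, List[float]]:
        if not self._inited:
            # Metadata depends only on the schema, not the batch values,
            # so compute it exactly once.
            self._order = [name for group in self._groups for name in group]
            self._splits = [
                sum(len(features[name]) for name in group) for group in self._groups
            ]
            self._inited = True
        # Concatenate in the cached order, then slice by the cached splits.
        flat: List[float] = []
        for name in self._order:
            flat.extend(features[name])
        out: Dict[str, List[float]] = {}
        start = 0
        for key, width in zip(self._keys, self._splits):
            out[key] = flat[start : start + width]
            start += width
        return out


regrouper = CachedRegrouper(groups=[["f1", "f2"], ["f3"]], keys=["object", "user"])
batch = {"f1": [1.0, 2.0], "f2": [3.0], "f3": [4.0, 5.0]}
print(regrouper(batch))  # {'object': [1.0, 2.0, 3.0], 'user': [4.0, 5.0]}
```

In the real module the cached metadata also becomes plain Python ints rather than unbacked SymInts, which is what makes PT2 graph capture cleaner.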
1 parent a7baf33 commit abbe5c1

File tree

4 files changed

+315
-1
lines changed


torchrec/modules/regroup.py

+158
```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch
from torchrec.sparse.jagged_tensor import (
    _all_keys_used_once,
    _desugar_keyed_tensors,
    _remap_to_groups,
    KeyedTensor,
)


@torch.fx.wrap
def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
    return torch.cat([kt.values() for kt in kts], dim=dim)


@torch.fx.wrap
def _permuted_values(
    kts: List[KeyedTensor], remap: List[Tuple[int, str]], dim: int
) -> torch.Tensor:
    embedding_dicts = [kt.to_dict() for kt in kts]
    values = [embedding_dicts[idx][key] for (idx, key) in remap]
    return torch.cat(values, dim=dim)


@torch.fx.wrap
def _build_dict(
    keys: List[str], values: torch.Tensor, splits: List[int], dim: int
) -> Dict[str, torch.Tensor]:
    return {
        key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
    }


class KTRegroupAsDict(torch.nn.Module):
    """
    KTRegroupAsDict is an nn.Module that mirrors the behavior of the static method
    KeyedTensor.regroup_as_dict().

    The advantage of using this module is that it caches the regrouping logic after
    the first batch.

    Args:
        groups (List[List[str]]): features per output group
        keys (List[str]): key of each output group

    Example::

        keys = ['object', 'user']
        groups = [['f1', 'f2'], ['f3']]
        regroup_module = KTRegroupAsDict(groups, keys)

        tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)]
        kts = [KeyedTensor.from_tensor_list(['f1', 'f2', 'f3'], tensor_list)]
        out = regroup_module(kts)
    """

    def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
        assert len(groups) == len(keys), "Groups and keys should have same length"
        self._groups = groups
        self._keys = keys
        self._is_inited = False

        # cached values populated on first forward call
        self.device: Optional[torch.device] = None
        self._dim: int = 1
        self._use_fbgemm_regroup: bool = False
        self._splits: List[int] = []
        self._idx_key_pairs: List[Tuple[int, str]] = []
        self._permute_tensor: Optional[torch.Tensor] = None
        self._inv_permute_tensor: Optional[torch.Tensor] = None
        self._offsets_tensor: Optional[torch.Tensor] = None
        self._inv_offsets_tensor: Optional[torch.Tensor] = None

    def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
        self._use_fbgemm_regroup = True
        keys, lengths, values = _desugar_keyed_tensors(kts)
        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
            keys, lengths, self._groups
        )
        # no need to pin_memory() or to(..., non_blocking=True) since this occurs only once
        self._permute_tensor = permute.to(self.device)
        self._inv_permute_tensor = inv_permute.to(self.device)
        self._offsets_tensor = offsets.to(self.device)
        self._inv_offsets_tensor = inv_offsets.to(self.device)
        self._splits = splits

    def _init_regroup(self, kts: List[KeyedTensor]) -> None:
        lengths = [kt.length_per_key() for kt in kts]
        indices = [kt._key_indices() for kt in kts]

        key_to_idx: Dict[str, int] = {}
        for i, kt in enumerate(kts):
            for key in kt.keys():
                if key in key_to_idx:
                    raise RuntimeError(
                        f"Duplicate key {key} found in KeyedTensors, undefined behavior"
                    )
                key_to_idx[key] = i

        splits: List[int] = []
        idx_key_pairs: List[Tuple[int, str]] = []
        for group in self._groups:
            group_length = 0
            for name in group:
                idx_key_pairs.append((key_to_idx[name], name))
                group_length += lengths[key_to_idx[name]][
                    indices[key_to_idx[name]][name]
                ]
            splits.append(group_length)

        self._splits = splits
        self._idx_key_pairs = idx_key_pairs

    def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
        if not self._is_inited:
            assert len(keyed_tensors) > 0, "Empty list provided"
            assert all(
                kt.device() == keyed_tensors[0].device() for kt in keyed_tensors
            ), "All inputs should be on the same device."
            self.device = keyed_tensors[0].device()
            assert all(
                kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
            ), "All inputs should have the same key_dim"
            self._dim = keyed_tensors[0].key_dim()

            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
                self._init_fbgemm_regroup(keyed_tensors)
            else:
                self._init_regroup(keyed_tensors)
            self._is_inited = True

        if self._use_fbgemm_regroup:
            values = _concat_values(keyed_tensors, self._dim)
            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
                values,
                self._offsets_tensor,
                self._permute_tensor,
                self._inv_offsets_tensor,
                self._inv_permute_tensor,
            )
        else:
            permuted_values = _permuted_values(
                keyed_tensors, self._idx_key_pairs, self._dim
            )

        return _build_dict(self._keys, permuted_values, self._splits, self._dim)
```
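The split computation in `_init_regroup` can be exercised without tensors. In the sketch below, `compute_regroup_metadata` is a hypothetical standalone helper (not part of TorchRec) that mirrors that method: each KeyedTensor contributes per-key widths, and an output group's split size is the sum of the widths of its member keys.

```python
from typing import Dict, List, Tuple


def compute_regroup_metadata(
    per_kt_lengths: List[Dict[str, int]], groups: List[List[str]]
) -> Tuple[List[int], List[Tuple[int, str]]]:
    """Mimics _init_regroup: derive (splits, idx_key_pairs) from per-key widths."""
    # Map each key to the index of the KeyedTensor that owns it,
    # rejecting keys that appear in more than one input.
    key_to_idx: Dict[str, int] = {}
    for i, lengths in enumerate(per_kt_lengths):
        for key in lengths:
            if key in key_to_idx:
                raise RuntimeError(f"Duplicate key {key} found in KeyedTensors")
            key_to_idx[key] = i

    splits: List[int] = []
    idx_key_pairs: List[Tuple[int, str]] = []
    for group in groups:
        group_length = 0
        for name in group:
            idx_key_pairs.append((key_to_idx[name], name))
            group_length += per_kt_lengths[key_to_idx[name]][name]
        splits.append(group_length)
    return splits, idx_key_pairs


# Two inputs: one carrying f1/f2 (widths 4 and 8), one carrying f3 (width 2),
# regrouped as [[f1, f2], [f3]] — matching the docstring example above.
splits, pairs = compute_regroup_metadata(
    [{"f1": 4, "f2": 8}, {"f3": 2}], [["f1", "f2"], ["f3"]]
)
print(splits)  # [12, 2]
print(pairs)   # [(0, 'f1'), (0, 'f2'), (1, 'f3')]
```

These two lists are exactly what the module caches after the first batch: `splits` drives the final `torch.split`, and `idx_key_pairs` drives the gather in `_permuted_values`.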
+134
```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
import torch.fx

from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor
from torchrec.sparse.tests.utils import build_groups, build_kts


class KTRegroupAsDictTest(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.kts = build_kts(
            dense_features=20,
            sparse_features=20,
            dim_dense=64,
            dim_sparse=128,
            batch_size=128,
            device=torch.device("cpu"),
            run_backward=True,
        )
        self.num_groups = 2
        self.keys = ["user", "object"]
        self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()

    def test_regroup_backward_skips_and_duplicates(self) -> None:
        groups = build_groups(
            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
        )
        assert _all_keys_used_once(self.kts, groups) is False

        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
        tensor_groups = regroup_module(self.kts)
        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        # clear grads so we can reuse the inputs
        self.kts[0].values().grad = None
        self.kts[1].values().grad = None

        tensor_groups = KeyedTensor.regroup_as_dict(
            keyed_tensors=self.kts, groups=groups, keys=self.keys
        )
        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        self.assertTrue(torch.allclose(pred0, pred1))
        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))

    def test_regroup_backward(self) -> None:
        groups = build_groups(
            kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
        )
        assert _all_keys_used_once(self.kts, groups) is True

        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
        tensor_groups = regroup_module(self.kts)
        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        # clear grads so we can reuse the inputs
        self.kts[0].values().grad = None
        self.kts[1].values().grad = None

        tensor_groups = KeyedTensor.regroup_as_dict(
            keyed_tensors=self.kts, groups=groups, keys=self.keys
        )
        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
            loss, [self.kts[0].values(), self.kts[1].values()]
        )

        self.assertTrue(torch.allclose(pred0, pred1))
        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))

    def test_fx_and_jit_regroup(self) -> None:
        groups = build_groups(
            kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
        )
        assert _all_keys_used_once(self.kts, groups) is True

        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
        # first pass populates the cached metadata
        regroup_module(self.kts)

        # now trace
        gm = torch.fx.symbolic_trace(regroup_module)
        jit_gm = torch.jit.script(gm)

        out = jit_gm(self.kts)
        eager_out = regroup_module(self.kts)
        for key in out.keys():
            self.assertTrue(torch.allclose(out[key], eager_out[key]))

    def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None:
        groups = build_groups(
            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
        )
        assert _all_keys_used_once(self.kts, groups) is False

        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
        # first pass populates the cached metadata
        regroup_module(self.kts)

        # now trace
        gm = torch.fx.symbolic_trace(regroup_module)
        jit_gm = torch.jit.script(gm)

        out = jit_gm(self.kts)
        eager_out = regroup_module(self.kts)
        for key in out.keys():
            self.assertTrue(torch.allclose(out[key], eager_out[key]))
```

torchrec/sparse/jagged_tensor.py

+6
```diff
@@ -650,6 +650,9 @@ def to_padded_dense_weights(
             self.weights(), [self.offsets()], [N], padding_value
         )
 
+    def device(self) -> torch.device:
+        return self._values.device
+
     def lengths(self) -> torch.Tensor:
         _lengths = _maybe_compute_lengths(self._lengths, self._offsets)
         self._lengths = _lengths
@@ -2570,6 +2573,9 @@ def values(self) -> torch.Tensor:
     def key_dim(self) -> int:
         return self._key_dim
 
+    def device(self) -> torch.device:
+        return self._values.device
+
     def offset_per_key(self) -> List[int]:
         _offset_per_key = _maybe_compute_offset_per_key_kt(
             self._length_per_key,
```

torchrec/sparse/tests/jagged_tensor_benchmark.py

+17 -1

```diff
@@ -16,6 +16,7 @@
 
 import torch
 from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import (
     _regroup_keyed_tensors,
     KeyedJaggedTensor,
@@ -53,7 +54,10 @@ def wrapped_func(
     ) -> None:
         result = fn(**fn_kwargs)
         if run_backward:
-            vectors = [tensor.sum(dim=1) for tensor in result]
+            if isinstance(result, dict):
+                vectors = [tensor.sum(dim=1) for tensor in result.values()]
+            else:
+                vectors = [tensor.sum(dim=1) for tensor in result]
             pred = vectors[0]
             for vector in vectors[1:]:
                 pred.mul(vector)
@@ -216,6 +220,18 @@ def main(
         KeyedTensor.regroup,
         {"keyed_tensors": kts, "groups": groups},
     )
+    bench(
+        "[prod] KTRegroupAsDict",
+        labels,
+        batch_size,
+        n_dense + n_sparse,
+        device_type,
+        run_backward,
+        KTRegroupAsDict(
+            groups=groups, keys=[str(i) for i in range(n_groups)]
+        ),
+        {"keyed_tensors": kts},
+    )
 
 
 if __name__ == "__main__":
```
