Commit
Expand transfer_if_needed to work with an arbitrary number of tensors.
Alex-Vasile committed Feb 26, 2025
1 parent b5d471a commit 8c54c47
Showing 2 changed files with 97 additions and 18 deletions.
56 changes: 38 additions & 18 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -58,26 +58,46 @@ def copy_w_new_shards_and_devices(tensor: ShardedTensor, new_shards: List[torch.
     else:
         raise NotImplementedError(f"copy_with_new_shards not implemented for {type(tensor)}")
 
-def transfer_if_needed(lhs: ShardedTensor, rhs: ShardedTensor) -> Tuple[ShardedTensor, ShardedTensor]:
-    if all(l_dev == r_dev for l_dev, r_dev in zip(lhs.devices, rhs.devices)):
-        return lhs, rhs  # All corresponding shards are on the same devices, no need to transfer
-
-    if lhs.devices_pinned and rhs.devices_pinned:
-        raise ValueError(f"Both tensors have their devices pinned, but they are pinned to different devices, or in different orders: {lhs.devices} vs {rhs.devices}.")
-    elif not lhs.devices_pinned and not rhs.devices_pinned:
-        return lhs, rhs  # TODO: If different devices and neither are pinned, shouldn't we need to transfer something?
-    else:
-        to_move, pinned = (lhs, rhs) if rhs.devices_pinned else (rhs, lhs)
-        shards = tuple(
-            (
-                transfer_to_logical_device(shard, pinned.devices[i])
-                if pinned.devices[i] != to_move.devices[i]
-                else barrier_on_logical_device(shard, pinned.devices[i])
-            )
-            for i, shard in enumerate(to_move.shards)
-        )
-        moved = copy_w_new_shards_and_devices(to_move, shards, pinned.devices)
-        return (moved, pinned) if rhs.devices_pinned else (pinned, moved)
+def transfer_if_needed(*tensors: Tuple[ShardedTensor, ...]) -> Tuple[ShardedTensor, ...]:
+    if len(tensors) <= 1:
+        return
+    assert all(isinstance(tensor, ShardedTensor) for tensor in tensors)
+
+    # Check if all tensors are on the same devices.
+    devices_0 = tensors[0].devices
+    for tensor in tensors[1:]:
+        if not all(devices_0[j] == tensor.devices[j] for j in range(len(devices_0))):
+            break
+    else:
+        return tensors  # All tensors already on same device
+
+    i_pinned = tuple(i for i, t in enumerate(tensors) if t.devices_pinned)
+    if len(i_pinned) == 0:
+        return tensors  # TODO: If different devices and none are pinned, shouldn't we need to transfer something?
+
+    d_pinned = tensors[i_pinned[0]].devices
+    for i in i_pinned[1:]:
+        if not all(d_pinned[j] == tensors[i].devices[j] for j in range(len(d_pinned))):
+            raise ValueError("All pinned tensors must be on the same devices.")
+
+    # Move all non-pinned tensors to the same devices as the pinned ones.
+    new_tensors = []
+    for i, tensor in enumerate(tensors):
+        if tensor.devices_pinned:
+            new_tensors.append(tensor)
+        else:
+            shards = tuple(
+                (
+                    transfer_to_logical_device(shard, d_pinned[j])
+                    if d_pinned[j] != tensor.devices[j]
+                    else barrier_on_logical_device(shard, d_pinned[j])
+                )
+                for j, shard in enumerate(tensor.shards)
+            )
+            new_tensors.append(copy_w_new_shards_and_devices(tensor, shards, d_pinned))
+
+    return tuple(new_tensors)
 
 
 def override_w_tranfer(operation, *override_args):
     def decorator(f):
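For reference, the following is a minimal usage sketch of the new variadic call pattern (not part of the commit), mirroring what the new tests in the second file exercise. It assumes the devices/devices_pinned constructor arguments that exist on this branch and the usual sharktank import paths (sharktank.ops, sharktank.types); treat the exact imports as an assumption.

import torch

from sharktank import ops
from sharktank.types import SplitPrimitiveTensor

shard_count = 4
shards = [torch.rand(3, 4, dtype=torch.float32) for _ in range(shard_count)]

# One tensor pinned to devices (0, 1, 2, 3), a second unpinned tensor placed on (4, 5, 6, 7).
pinned = SplitPrimitiveTensor(
    shard_dim=1, ts=shards, devices=tuple(range(shard_count)), devices_pinned=True
)
unpinned = SplitPrimitiveTensor(
    shard_dim=1,
    ts=shards,
    devices=tuple(shard_count + d for d in range(shard_count)),
    devices_pinned=False,
)

# The unpinned tensor's shards are transferred onto the pinned tensor's devices;
# the pinned tensor is returned unchanged.
pinned_out, unpinned_out = ops.transfer_if_needed(pinned, unpinned)
assert all(a == b for a, b in zip(unpinned_out.devices, pinned_out.devices))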
59 changes: 59 additions & 0 deletions sharktank/tests/ops/pipeline_parallelized_test.py
@@ -109,6 +109,65 @@ def testBothPinnedOnDifferentDevices(self):
         except ValueError:
             return
         assert False  # Should have thrown a ValueError since both tensors are pinned, but devices are not the same
 
+    def testMultiTensorsNoPinned(self):
+        tensor_count = 5
+        shard_count = 4
+        shard_shape = [3, 4]
+        shards = [torch.rand(shard_shape, dtype=torch.float32) for _ in range(shard_count)]
+        t_pre = [
+            SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=tuple(shard_count*i + d for d in range(shard_count)), devices_pinned=False)
+            for i in range(tensor_count)
+        ]
+        t_post = ops.transfer_if_needed(*t_pre)
+
+        for i in range(tensor_count):
+            assert all(d_pre == d_post for d_pre, d_post in zip(t_pre[i].devices, t_post[i].devices))
+
+    def testMultiTensorsOnePinned(self):
+        tensor_count = 5
+        shard_count = 4
+        shard_shape = [3, 4]
+        shards = [torch.rand(shard_shape, dtype=torch.float32) for _ in range(shard_count)]
+        t_pre = [
+            SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=tuple(shard_count*i + d for d in range(shard_count)), devices_pinned=(i==0))
+            for i in range(tensor_count)
+        ]
+        t_post = ops.transfer_if_needed(*t_pre)
+
+        for i in range(tensor_count):
+            assert all(d_pre == d_post for d_pre, d_post in zip(t_pre[0].devices, t_post[i].devices))
+
+    def testMultiTensorsMultiPinnedNoConflict(self):
+        tensor_count = 5
+        shard_count = 4
+        shard_shape = [3, 4]
+        shards = [torch.rand(shard_shape, dtype=torch.float32) for _ in range(shard_count)]
+        t_pre = [
+            SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=tuple(shard_count*i*(i % 2 != 0) + d for d in range(shard_count)), devices_pinned=(i % 2 == 0))
+            for i in range(tensor_count)
+        ]
+        t_post = ops.transfer_if_needed(*t_pre)
+
+        for i in range(tensor_count):
+            assert all(d_pre == d_post for d_pre, d_post in zip(t_pre[0].devices, t_post[i].devices))
+
+    def testMultiTensorsMultiPinnedWithConflict(self):
+        tensor_count = 5
+        shard_count = 4
+        shard_shape = [3, 4]
+        shards = [torch.rand(shard_shape, dtype=torch.float32) for _ in range(shard_count)]
+        t_pre = [
+            SplitPrimitiveTensor(shard_dim=1, ts=shards, devices=tuple(shard_count*i + d for d in range(shard_count)), devices_pinned=(i < 2))
+            for i in range(tensor_count)
+        ]
+        try:
+            ops.transfer_if_needed(*t_pre)
+        except ValueError:
+            return
+
+        assert False  # Should throw an error since the first two tensors are pinned to different devices
+
+
 class MatmulTest(unittest.TestCase):
     def testShardedParallelAxesInLhsAndRhs(self):  # matmul_split
