Commit 240414a

validate that it works on single machine with multiple GPUs

1 parent 2bb762d

1 file changed: +31 −17 lines

assert.py (+31 −17)
@@ -25,6 +25,7 @@ def start(
     num_experts,
     tokens_per_expert,
     dim,
+    use_cuda
 ):
     setup(rank, world_size)

@@ -35,18 +36,11 @@ def start(

     seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

-    # distributed
-
-    model = DDP(net)
-    out = model(seq)
-    out.mean().backward()
-
-    ddp_all_out, _ = all_gather_variable_dim(out)
+    # locally

-    # on single device
+    local_net = deepcopy(net)

     local_inputs, _ = all_gather_variable_dim(seq)
-    local_net = deepcopy(net)

     local_out = local_net(
         local_inputs,
@@ -55,18 +49,34 @@ def start(

     local_out.mean().backward()

+    # distributed
+
+    model = DDP(net)
+    ddp_inputs = seq
+
+    if use_cuda:
+        model.cuda(rank)
+        ddp_inputs = seq.cuda(rank)
+
+    out = model(ddp_inputs)
+    out.mean().backward()
+
+    ddp_all_out, _ = all_gather_variable_dim(out)
+
     if rank == 0:
-        # validate output is the same
-        # if done on 1 vs multiple machines
+        # validate output is the same for local vs distributed

-        assert torch.allclose(local_out, ddp_all_out), 'output is not the same'
+        model.cpu()
+        ddp_all_out.cpu()

-        # validate backwards and grad
+        assert torch.allclose(local_out, ddp_all_out.cpu(), atol = 1e-3), 'output is not the same'
+
+        # validate gradients of first expert is the same for local vs distributed

         get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad

         assert torch.allclose(
-            get_first_expert_grad(net),
+            get_first_expert_grad(net).cpu(),
             get_first_expert_grad(local_net),
             atol = 1e-2
         ), 'grad is not the same'
@@ -76,10 +86,13 @@ def start(
     cleanup()

 if __name__ == '__main__':
-    world_size = 13
-    num_experts = 4
+    world_size = 8
+    num_experts = 3
     batch_size = 2
     batch_size_var_len = True
+    use_cuda = False
+
+    assert not use_cuda or torch.cuda.device_count() <= world_size

     seq_len = 32
     dim = 8
@@ -92,7 +105,8 @@ def start(
             batch_size_var_len,
             num_experts,
             seq_len,
-            dim
+            dim,
+            use_cuda
         ),
         nprocs = world_size,
         join = True
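For context, the test compares a DistributedDataParallel run against a single local copy of the network: every rank all-gathers the input shards, runs `local_net` on the full batch, then runs the DDP-wrapped `net` on its own shard (optionally moved to `cuda(rank)`), and rank 0 asserts that the gathered outputs and the first expert's gradients agree. Below is a minimal, self-contained sketch of that pattern, not the repository's assert.py: it swaps in a plain `nn.Linear` for the expert network, uses `torch.distributed.all_gather` with equal-sized shards in place of the `all_gather_variable_dim` helper, and the `run` function name, port number, and tensor sizes are illustrative assumptions.

# minimal sketch of the local-vs-distributed equivalence check (hypothetical names)
import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size, use_cuda):
    # every spawned process joins the same process group (gloo works for cpu and, via DDP, for gpu modules)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('gloo', rank = rank, world_size = world_size)

    torch.manual_seed(0)            # identical weights on every rank
    net = torch.nn.Linear(8, 8)

    torch.manual_seed(rank)         # a different input shard on every rank
    seq = torch.randn(2, 8)

    # local reference: run a copy of the network on the full, gathered batch
    local_net = deepcopy(net)

    local_inputs = [torch.empty_like(seq) for _ in range(world_size)]
    dist.all_gather(local_inputs, seq)

    local_out = local_net(torch.cat(local_inputs))
    local_out.mean().backward()

    # distributed: run the DDP-wrapped network on this rank's shard only
    ddp_inputs = seq

    if use_cuda:
        net = net.cuda(rank)
        ddp_inputs = seq.cuda(rank)

    model = DDP(net, device_ids = [rank] if use_cuda else None)

    out = model(ddp_inputs)
    out.mean().backward()

    # gather the per-rank outputs on cpu so the same code covers both paths
    out = out.detach().cpu()
    all_out = [torch.empty_like(out) for _ in range(world_size)]
    dist.all_gather(all_out, out)
    ddp_all_out = torch.cat(all_out)

    if rank == 0:
        assert torch.allclose(local_out, ddp_all_out, atol = 1e-3), 'output is not the same'
        assert torch.allclose(net.weight.grad.cpu(), local_net.weight.grad, atol = 1e-2), 'grad is not the same'
        print('local and distributed runs agree')

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 4
    use_cuda = False
    mp.spawn(run, args = (world_size, use_cuda), nprocs = world_size, join = True)

With equal per-rank shard sizes, DDP's gradient averaging equals the gradient of the mean over the full gathered batch, which is why both the outputs and the gradients can be compared within a small tolerance.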
