@@ -25,6 +25,7 @@ def start(
     num_experts,
     tokens_per_expert,
     dim,
+    use_cuda
 ):
     setup(rank, world_size)
 
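This hunk calls `setup(rank, world_size)`, and a later hunk calls `cleanup()`; neither helper is shown in the diff. Below is a minimal sketch of what such helpers typically look like, following the standard PyTorch DDP recipe — the backend, address, and port values are assumptions, not taken from this commit.

```python
import os
import torch.distributed as dist

def setup(rank, world_size):
    # assumed helper: initialize the default process group for this rank
    # (placeholder rendezvous address/port; gloo keeps the sketch CPU-friendly)
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('gloo', rank = rank, world_size = world_size)

def cleanup():
    # assumed helper: tear down the process group once the test is done
    dist.destroy_process_group()
```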
@@ -35,18 +36,11 @@ def start(
 
     seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)
 
-    # distributed
-
-    model = DDP(net)
-    out = model(seq)
-    out.mean().backward()
-
-    ddp_all_out, _ = all_gather_variable_dim(out)
+    # locally
 
-    # on single device
+    local_net = deepcopy(net)
 
     local_inputs, _ = all_gather_variable_dim(seq)
-    local_net = deepcopy(net)
 
     local_out = local_net(
         local_inputs,
@@ -55,18 +49,34 @@ def start(
 
     local_out.mean().backward()
 
+    # distributed
+
+    model = DDP(net)
+    ddp_inputs = seq
+
+    if use_cuda:
+        model.cuda(rank)
+        ddp_inputs = seq.cuda(rank)
+
+    out = model(ddp_inputs)
+    out.mean().backward()
+
+    ddp_all_out, _ = all_gather_variable_dim(out)
+
     if rank == 0:
-        # validate output is the same
-        # if done on 1 vs multiple machines
+        # validate output is the same for local vs distributed
 
-        assert torch.allclose(local_out, ddp_all_out), 'output is not the same'
+        model.cpu()
+        ddp_all_out.cpu()
 
-        # validate backwards and grad
+        assert torch.allclose(local_out, ddp_all_out.cpu(), atol = 1e-3), 'output is not the same'
+
+        # validate gradients of first expert is the same for local vs distributed
 
         get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad
 
         assert torch.allclose(
-            get_first_expert_grad(net),
+            get_first_expert_grad(net).cpu(),
            get_first_expert_grad(local_net),
            atol = 1e-2
        ), 'grad is not the same'
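Both the local and the distributed paths lean on `all_gather_variable_dim`, which is defined elsewhere in the repo, not in this diff. The following is a rough, forward-only sketch of what such a helper does: gather each rank's size, pad to the maximum, all-gather, then drop the padding. The repo's real implementation also handles autograd and may differ in details; `pad_dim_to` is defined here only for the sketch.

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def pad_dim_to(t, length, dim = 0):
    # zero-pad `t` on the right along `dim` up to `length`
    pad = [0, 0] * (t.ndim - dim - 1) + [0, length - t.shape[dim]]
    return F.pad(t, pad)

def all_gather_variable_dim(t, dim = 0):
    # requires an initialized process group (e.g. via the `setup` sketch above)
    world_size = dist.get_world_size()
    device = t.device

    # 1. every rank broadcasts its size along `dim`
    size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
    sizes = [torch.empty_like(size) for _ in range(world_size)]
    dist.all_gather(sizes, size)
    sizes = torch.stack(sizes)

    # 2. pad to the largest size so the collective sees equal shapes
    max_size = int(sizes.amax().item())
    padded = pad_dim_to(t, max_size, dim = dim)
    gathered = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    gathered = torch.cat(gathered, dim = dim)

    # 3. strip the padding back out using the gathered sizes
    arange = torch.arange(max_size, device = device)
    mask = arange[None, :] < sizes[:, None]   # (world_size, max_size)
    keep = torch.arange(gathered.shape[dim], device = device)[mask.reshape(-1)]
    return gathered.index_select(dim, keep), sizes
```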
@@ -76,10 +86,13 @@ def start(
     cleanup()
 
 if __name__ == '__main__':
-    world_size = 13
-    num_experts = 4
+    world_size = 8
+    num_experts = 3
     batch_size = 2
     batch_size_var_len = True
+    use_cuda = False
+
+    assert not use_cuda or torch.cuda.device_count() <= world_size
 
     seq_len = 32
     dim = 8
@@ -92,7 +105,8 @@ def start(
             batch_size_var_len,
             num_experts,
             seq_len,
-            dim
+            dim,
+            use_cuda
         ),
         nprocs = world_size,
         join = True
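For comparison with the `use_cuda` branch added above (which wraps the module in DDP first and then calls `.cuda(rank)` on the wrapper), here is a minimal sketch of the placement pattern from the standard PyTorch DDP recipe: move the module to its device before wrapping and pass `device_ids`. The helper name is hypothetical and a process group must already be initialized.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(net: nn.Module, rank: int, use_cuda: bool) -> nn.Module:
    # hypothetical helper: place the module on its GPU first, then let DDP
    # bind its gradient buckets to that device via `device_ids`
    if use_cuda:
        return DDP(net.cuda(rank), device_ids = [rank])
    return DDP(net)
```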