-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathmixture_of_experts.py
323 lines (260 loc) · 12.6 KB
/
mixture_of_experts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import torch
from torch import nn
import torch.nn.functional as F
import math
from inspect import isfunction
# constants
MIN_EXPERT_CAPACITY = 4
# helper functions
def default(val, default_val):
    """Return `val` unless it is None, otherwise return `default_val`.

    If `default_val` is a function, it is treated as a lazy thunk and only
    invoked when actually needed. (Previously the thunk was evaluated
    unconditionally, so e.g. MoE built a throwaway `Experts` module even
    when the caller supplied their own `experts`.)
    """
    if val is not None:
        return val
    return default_val() if isfunction(default_val) else default_val
def cast_tuple(el):
    """Wrap `el` in a one-element tuple unless it already is a tuple."""
    if isinstance(el, tuple):
        return el
    return (el,)
# tensor related helper functions
def top1(t):
    """Largest value and its index along the last dimension.

    Uses `topk(k=1)` (rather than `max`) so tie-breaking matches the
    original transcription, then squeezes away the reduced dimension.
    """
    values, index = t.topk(k=1, dim=-1)
    return values.squeeze(dim=-1), index.squeeze(dim=-1)
def cumsum_exclusive(t, dim=-1):
    """Exclusive cumulative sum along `dim`: each entry becomes the sum of
    all strictly preceding entries (first entry is 0). `dim` must be given
    as a negative index.
    """
    trailing = -dim - 1                       # number of dims after `dim`
    pad_spec = (0, 0) * trailing + (1, 0)     # one leading zero on `dim`
    shifted = F.pad(t, pad_spec).cumsum(dim=dim)
    keep_tail = (slice(None),) * trailing
    # drop the final (inclusive) entry along `dim`, keeping trailing dims whole
    return shifted[(Ellipsis, slice(None, -1), *keep_tail)]
# pytorch one hot throws an error if there are out of bound indices.
# tensorflow, in contrast, does not throw an error
def safe_one_hot(indexes, max_length):
    """One-hot encode `indexes` into exactly `max_length` classes, tolerating
    out-of-bound indices.

    pytorch's F.one_hot throws on indices >= num_classes (tensorflow does
    not), so the encoding is sized to fit the largest index present and then
    truncated to `max_length` columns; out-of-range indices therefore yield
    all-zero rows. Fixes a double `+ 1` in the original that always
    allocated one extra, immediately-discarded column.
    """
    num_classes = max(int(indexes.max()) + 1, max_length)
    return F.one_hot(indexes, num_classes)[..., :max_length]
def init_(t):
    """In-place uniform init of `t` in [-1/sqrt(d), 1/sqrt(d)], where d is
    the size of the last dimension. Returns `t` for chaining."""
    bound = 1 / math.sqrt(t.shape[-1])
    return t.uniform_(-bound, bound)
# activations
class GELU_(nn.Module):
    """Tanh approximation of GELU, used as a fallback on torch versions
    that predate nn.GELU."""
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

# prefer the built-in implementation when this torch has one
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
# expert class
class Experts(nn.Module):
    """A bank of independent two-layer feedforward experts evaluated in parallel.

    Weights are stacked into single tensors of shape
    (*num_experts, dim, hidden_dim) and (*num_experts, hidden_dim, dim),
    so one einsum applies every expert to its own slice of the input.
    """
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()
        hidden_dim = default(hidden_dim, dim * 4)
        expert_shape = cast_tuple(num_experts)
        # init order (w1 then w2) preserved so RNG consumption is unchanged
        self.w1 = nn.Parameter(init_(torch.zeros(*expert_shape, dim, hidden_dim)))
        self.w2 = nn.Parameter(init_(torch.zeros(*expert_shape, hidden_dim, dim)))
        self.act = activation()
    def forward(self, x):
        # x: (*num_experts, n, dim) -> (*num_experts, n, dim)
        hidden = self.act(torch.einsum('...nd,...dh->...nh', x, self.w1))
        return torch.einsum('...nh,...hd->...nd', hidden, self.w2)
# the below code is almost all transcribed from the official tensorflow version, from which the papers are written
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py
# gating network
class Top2Gating(nn.Module):
    """Top-2 gating network, transcribed from tensor2tensor's moe.py.

    Scores every position against `num_gates` experts, keeps the two
    highest-scoring experts per position, and builds capacity-limited
    routing tensors.

    forward() returns:
        dispatch_tensor: [..., batch, group, num_gates, expert_capacity]
            0/1-valued float tensor routing each position to a slot in
            (at most) two experts
        combine_tensor: same shape, weighted by the normalized gate scores,
            used to recombine expert outputs
        loss: scalar load-balancing auxiliary loss
    """
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()
        self.eps = eps  # guards the division when renormalizing the two gate scores
        self.num_gates = num_gates
        # linear gating weights; `outer_expert_dims` prepends extra expert axes
        # (used by the hierarchical MoE for its per-group inner gate)
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))
        # policy/threshold control whether second-choice experts are kept or dropped
        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        # capacity factor scales how many positions each expert may receive
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
    def forward(self, x, importance = None):
        # x: (..., batch, group, dim). importance, when given, has one value
        # per position: 1.0 positions may take first-choice slots, >0 positions
        # may take second-choice slots (see HeirarchicalMoE, which passes
        # values in {0, 0.5, 1}).
        *_, b, group_size, dim = x.shape
        num_gates = self.num_gates
        # select the second-place policy/threshold/capacity for train vs eval
        if self.training:
            policy = self.second_policy_train
            threshold = self.second_threshold_train
            capacity_factor = self.capacity_factor_train
        else:
            policy = self.second_policy_eval
            threshold = self.second_threshold_eval
            capacity_factor = self.capacity_factor_eval
        # raw gate scores: softmax over experts for every position
        raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
        raw_gates = raw_gates.softmax(dim=-1)
        # FIND TOP 2 EXPERTS PER POSITION
        # Find the top expert for each position. shape=[batch, group]
        gate_1, index_1 = top1(raw_gates)
        mask_1 = F.one_hot(index_1, num_gates).float()
        density_1_proxy = raw_gates
        if importance is not None:
            # only positions with importance == 1 may claim a first-choice slot
            equals_one_mask = (importance == 1.).float()
            mask_1 *= equals_one_mask[..., None]
            gate_1 *= equals_one_mask
            density_1_proxy = density_1_proxy * equals_one_mask[..., None]
            del equals_one_mask
        # second choice: best expert after zeroing out the first choice
        gates_without_top_1 = raw_gates * (1. - mask_1)
        gate_2, index_2 = top1(gates_without_top_1)
        mask_2 = F.one_hot(index_2, num_gates).float()
        if importance is not None:
            # positions with importance == 0 get no second-choice expert
            greater_zero_mask = (importance > 0.).float()
            mask_2 *= greater_zero_mask[..., None]
            del greater_zero_mask
        # normalize top2 gate scores so gate_1 + gate_2 ~= 1 per position
        denom = gate_1 + gate_2 + self.eps
        gate_1 /= denom
        gate_2 /= denom
        # BALANCING LOSSES
        # shape = [batch, experts]
        # We want to equalize the fraction of the batch assigned to each expert
        density_1 = mask_1.mean(dim=-2)
        # Something continuous that is correlated with what we want to equalize.
        density_1_proxy = density_1_proxy.mean(dim=-2)
        loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)
        # Depending on the policy in the hparams, we may drop out some of the
        # second-place experts.
        if policy == "all":
            pass
        elif policy == "none":
            mask_2 = torch.zeros_like(mask_2)
        elif policy == "threshold":
            mask_2 *= (gate_2 > threshold).float()
        elif policy == "random":
            # keep the second expert stochastically, with probability
            # proportional to its gate score relative to the threshold
            probs = torch.zeros_like(gate_2).uniform_(0., 1.)
            mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
        else:
            raise ValueError(f"Unknown policy {policy}")
        # Each sequence sends (at most?) expert_capacity positions to each expert.
        # Static expert_capacity dimension is needed for expert batch sizes
        expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
        expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
        expert_capacity_f = float(expert_capacity)
        # COMPUTE ASSIGNMENT TO EXPERTS
        # [batch, group, experts]
        # This is the position within the expert's mini-batch for this sequence
        position_in_expert_1 = cumsum_exclusive(mask_1, dim=-2) * mask_1
        # Remove the elements that don't fit. [batch, group, experts]
        mask_1 *= (position_in_expert_1 < expert_capacity_f).float()
        # [batch, experts]
        # How many examples in this sequence go to this expert
        mask_1_count = mask_1.sum(dim=-2, keepdim=True)
        # [batch, group] - mostly ones, but zeros where something didn't fit
        mask_1_flat = mask_1.sum(dim=-1)
        # [batch, group]
        position_in_expert_1 = position_in_expert_1.sum(dim=-1)
        # Weight assigned to first expert. [batch, group]
        gate_1 *= mask_1_flat
        # second-choice slots start after all first-choice assignments
        position_in_expert_2 = cumsum_exclusive(mask_2, dim=-2) + mask_1_count
        position_in_expert_2 *= mask_2
        mask_2 *= (position_in_expert_2 < expert_capacity_f).float()
        mask_2_flat = mask_2.sum(dim=-1)
        position_in_expert_2 = position_in_expert_2.sum(dim=-1)
        gate_2 *= mask_2_flat
        # [batch, group, experts, expert_capacity]
        # combine_tensor carries the gate weight at the (expert, slot) each
        # position was routed to; zero everywhere else
        combine_tensor = (
            gate_1[..., None, None]
            * mask_1_flat[..., None, None]
            * F.one_hot(index_1, num_gates)[..., None]
            * safe_one_hot(position_in_expert_1.long(), expert_capacity)[..., None, :] +
            gate_2[..., None, None]
            * mask_2_flat[..., None, None]
            * F.one_hot(index_2, num_gates)[..., None]
            * safe_one_hot(position_in_expert_2.long(), expert_capacity)[..., None, :]
        )
        # dispatch is the unweighted 0/1 version of combine
        dispatch_tensor = combine_tensor.bool().to(combine_tensor)
        return dispatch_tensor, combine_tensor, loss
# plain mixture of experts
class MoE(nn.Module):
    """Sparse mixture-of-experts layer: a Top2Gating router in front of a
    bank of feedforward experts.

    forward() returns (output, aux_loss) where output has the same
    (batch, seq, dim) shape as the input and aux_loss is the gate's
    load-balancing loss scaled by `loss_coef`.
    """
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()
        self.num_experts = num_experts
        gating_kwargs = {
            'second_policy_train': second_policy_train,
            'second_policy_eval': second_policy_eval,
            'second_threshold_train': second_threshold_train,
            'second_threshold_eval': second_threshold_eval,
            'capacity_factor_train': capacity_factor_train,
            'capacity_factor_eval': capacity_factor_eval,
        }
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
        # caller may supply a custom expert module; otherwise build the default bank
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef
    def forward(self, inputs, **kwargs):
        b, n, d, e = *inputs.shape, self.num_experts
        # route tokens: dispatch/combine are (batch, seq, expert, capacity)
        dispatch_tensor, combine_tensor, balance_loss = self.gate(inputs)
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)
        # run every expert over its buffer, flattening the middle dims
        buffer_shape = expert_inputs.shape
        expert_outputs = self.experts(expert_inputs.reshape(e, -1, d)).reshape(*buffer_shape)
        # gate-weighted recombination back to (batch, seq, dim)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
        return output, balance_loss * self.loss_coef
# 2-level heirarchical mixture of experts
# 2-level hierarchical mixture of experts
class HeirarchicalMoE(nn.Module):
    """Two-level hierarchical mixture of experts.

    Tokens are first routed to one of `num_experts[0]` expert groups by an
    outer top-2 gate; within each group an inner top-2 gate routes to one of
    `num_experts[1]` experts. forward() returns (output, aux_loss) where
    aux_loss sums both gates' balancing losses scaled by `loss_coef`.
    (Class name keeps the original spelling for backward compatibility.)
    """
    def __init__(self,
        dim,
        num_experts = (4, 4),
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()
        assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner
        # both gates share the same second-choice policy / capacity settings
        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        # the inner gate carries an extra leading expert axis, one per outer group
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)
        # experts are stacked over both levels (weight shape: outer, inner, dim, hidden)
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef
    def forward(self, inputs, **kwargs):
        # inputs: (batch, seq, dim); **kwargs is accepted but unused
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        # outer routing: gather each outer group's buffer of tokens
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)
        # we construct an "importance" Tensor for the inputs to the second-level
        # gating. The importance of an input is 1.0 if it represents the
        # first-choice expert-group and 0.5 if it represents the second-choice expert
        # group. This is used by the second-level gating.
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())
        # inner routing, performed independently within each outer group
        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)
        # Now feed the expert inputs through the experts.
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)
        # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
        # first undo the inner dispatch, then the outer dispatch
        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef