-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssast.py
508 lines (422 loc) · 26.5 KB
/
ssast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
import torch.nn as nn
import torch
import sys
from timm.models.layers import trunc_normal_
import timm
import numpy as np
import os
import wget
from timm.models.layers import to_2tuple
from random import randrange
from matplotlib import pyplot as plt
import random
# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    """Patch embedding that does not enforce a fixed input size.

    A strided convolution cuts the input into patches and projects each patch
    to ``embed_dim`` channels. Unlike the timm layer it replaces, ``forward``
    performs no check of the input's spatial size against ``img_size``.

    Args:
        img_size: nominal input size (int or pair); only used to derive
            ``num_patches``.
        patch_size: patch size (int or pair); used as both kernel and stride.
        in_chans: number of input channels.
        embed_dim: number of output channels per patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # number of whole patches that fit along each of the two axes
        along_axis0 = self.img_size[0] // self.patch_size[0]
        along_axis1 = self.img_size[1] // self.patch_size[1]
        self.num_patches = along_axis1 * along_axis0
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """Return patch tokens of shape (batch, num_patches, embed_dim)."""
        feature_map = self.proj(x)
        # collapse the two spatial axes, then put the patch axis before channels
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)
def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (2 * (j // 2) / d_hid))`` for
    even ``j`` and the cosine of the same angle for odd ``j``.

    Args:
        n_position: number of positions (rows); may be 0.
        d_hid: encoding width (columns).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    positions = np.arange(n_position, dtype=np.float64).reshape(-1, 1)
    hid_index = np.arange(d_hid)
    # one divisor per column; columns 2i and 2i+1 share the same frequency
    divisors = np.power(10000.0, 2 * (hid_index // 2) / d_hid)
    # broadcasting always yields a 2-D table, so n_position == 0 is handled too
    sinusoid_table = positions / divisors
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.from_numpy(sinusoid_table).float().unsqueeze(0)
class SSASTModel(nn.Module):
    """Self-supervised audio spectrogram transformer on a timm DeiT/ViT backbone.

    The constructor builds the model in one of two modes:

    * ``pretrain_stage=True``: a randomly initialised backbone plus the heads
      used by the masked-patch objectives (:meth:`mpc`, :meth:`mpg`).
    * ``pretrain_stage=False``: the backbone is restored from an SSL
      checkpoint, then its patch projection and positional embedding are
      adapted to the fine-tuning input size and patch stride.

    ``task`` selects what :meth:`forward` computes: ``'ft_avgtok'``,
    ``'ft_cls'``, ``'pretrain_mpc'``, ``'pretrain_mpg'`` or ``'visualize_mask'``.

    Args:
        label_dim: output width of ``mlp_head`` (fine-tuning stage only).
        fshape, tshape: patch size along frequency / time.
        fstride, tstride: patch stride along frequency / time; must equal the
            patch shape during pretraining.
        input_fdim, input_tdim: spectrogram size along frequency / time.
        model_size: one of ``'tiny'``, ``'small'``, ``'base'``, ``'base_nokd'``.
        task: see above.
        pretrain_stage: build for pretraining (True) or fine-tuning (False).
        load_pretrained_mdl_path: must be None for pretraining. For
            fine-tuning: ``'Patch'`` or ``'Frame'`` to download/use the
            released checkpoints, or any other value to load it as a
            checkpoint path.
        mix_beta: Beta-distribution parameter used by :meth:`patch_mix`;
            ``None`` or a non-positive value disables mixing.

    Raises:
        ValueError: on inconsistent arguments or an unusable checkpoint.
    """

    # model_size -> (timm model name, attention heads, number of cls-like tokens)
    _BACKBONES = {
        'tiny': ('vit_deit_tiny_distilled_patch16_224', 3, 2),
        'small': ('vit_deit_small_distilled_patch16_224', 6, 2),
        'base': ('vit_deit_base_distilled_patch16_384', 12, 2),
        'base_nokd': ('vit_deit_base_patch16_384', 12, 1),
    }

    def __init__(self, label_dim=527, fshape=128, tshape=2, fstride=128, tstride=2,
                 input_fdim=128, input_tdim=1024, model_size='base', task='ft_avgtok',
                 pretrain_stage=True, load_pretrained_mdl_path=None, mix_beta=None):
        super(SSASTModel, self).__init__()
        assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        self.task = task
        # set in both stages so patch_mix never hits a missing attribute
        self.mix_beta = mix_beta

        if pretrain_stage:
            self._build_for_pretraining(fshape, tshape, fstride, tstride,
                                        input_fdim, input_tdim, model_size,
                                        load_pretrained_mdl_path)
        else:
            self._build_for_finetuning(label_dim, fshape, tshape, fstride, tstride,
                                       input_fdim, input_tdim, model_size,
                                       load_pretrained_mdl_path)

    # ------------------------------------------------------------------
    # construction helpers
    # ------------------------------------------------------------------
    def _build_for_pretraining(self, fshape, tshape, fstride, tstride,
                               input_fdim, input_tdim, model_size,
                               load_pretrained_mdl_path):
        """Create a from-scratch backbone and the SSL prediction heads."""
        if load_pretrained_mdl_path is not None:
            raise ValueError('Setting load_pretrained_mdl_path at pretraining stage is useless, pretraining is always from scratch, please change it to None.')
        if fstride != fshape or tstride != tshape:
            raise ValueError('fstride != fshape or tstride != tshape, they must be same at the pretraining stage, patch split overlapping is not supported.')
        if model_size not in self._BACKBONES:
            raise ValueError('Model size must be one of tiny, small, base, base_nokd')

        timm_name, heads, cls_token_num = self._BACKBONES[model_size]
        self.v = timm.create_model(timm_name, pretrained=False)
        self.heads, self.depth = heads, 12
        self.cls_token_num = cls_token_num

        self.original_num_patches = self.v.patch_embed.num_patches
        # attribute name (including its spelling) is kept for existing users
        self.oringal_hw = int(self.original_num_patches ** 0.5)
        self.original_embedding_dim = self.v.pos_embed.shape[2]
        embed_dim = self.original_embedding_dim

        # SSL pretraining code
        self.softmax = nn.Softmax(dim=-1)
        self.lsoftmax = nn.LogSoftmax(dim=-1)
        self.fshape, self.tshape = fshape, tshape
        self.fstride, self.tstride = fstride, tstride
        self.input_fdim, self.input_tdim = input_fdim, input_tdim
        # stored as frozen parameters so that state_dict (and therefore
        # torch.save) records the pretraining input size
        self.p_input_fdim = nn.Parameter(torch.tensor(input_fdim), requires_grad=False)
        self.p_input_tdim = nn.Parameter(torch.tensor(input_tdim), requires_grad=False)

        # masked patch classification (discriminative objective) head: maps a
        # transformer output to the 256-dim flattened-patch space
        self.cpredlayer = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 256))
        # masked patch reconstruction (generative objective) head
        self.gpredlayer = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 256))
        self.unfold = torch.nn.Unfold(kernel_size=(fshape, tshape), stride=(fstride, tstride))

        # learnable embedding substituted for masked patches
        self.mask_embed = nn.Parameter(torch.zeros([1, 1, embed_dim]))
        self.mask_embed = torch.nn.init.xavier_normal_(self.mask_embed)

        # patch-grid size produced by the projection for this input size
        self.p_f_dim, self.p_t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape)
        num_patches = self.p_f_dim * self.p_t_dim
        self.num_patches = num_patches
        self.v.patch_embed.num_patches = num_patches
        print('pretraining patch split stride: frequency={:d}, time={:d}'.format(fstride, tstride))
        print('pretraining patch shape: frequency={:d}, time={:d}'.format(fshape, tshape))
        print('pretraining patch array dimension: frequency={:d}, time={:d}'.format(self.p_f_dim, self.p_t_dim))
        print('pretraining number of patches={:d}'.format(num_patches))

        # patch projection with 1 input channel (spectrogram) instead of the
        # backbone's original projection
        self.v.patch_embed.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))

        # trainable positional embedding sized for the new patch grid
        self.v.pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + self.cls_token_num, embed_dim))
        trunc_normal_(self.v.pos_embed, std=.02)

    @staticmethod
    def _load_pretrained_state_dict(load_pretrained_mdl_path, device):
        """Return the checkpoint state dict, downloading released models if needed.

        ``'Patch'`` and ``'Frame'`` select the released checkpoints (cached in
        ``./pretrained_models/``); any other value is loaded as a file path.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        out_dir = './pretrained_models/'
        os.makedirs(out_dir, exist_ok=True)

        if load_pretrained_mdl_path == 'Patch':
            ckpt_path = os.path.join(out_dir, 'SSAST-Base-Patch-400.pth')
            if not os.path.exists(ckpt_path):
                # this model performs 59.9 on the Avg Audio Performance
                # more details are below: https://github.com/YuanGongND/ssast
                audioset_mdl_url = 'https://www.dropbox.com/s/ewrzpco95n9jdz6/SSAST-Base-Patch-400.pth?dl=1'
                print('Download SSAST-Base-Patch-400.pth \n')
                wget.download(audioset_mdl_url, out=ckpt_path)
            sd = torch.load(ckpt_path, map_location=device)
            print('Loaded SSAST-Base-Patch-400.pth successfully.')
        elif load_pretrained_mdl_path == 'Frame':
            ckpt_path = os.path.join(out_dir, 'SSAST-Base-Frame-400.pth')
            if not os.path.exists(ckpt_path):
                # this model performs 57.6 on the Avg Audio Performance
                # more details are below: https://github.com/YuanGongND/ssast
                audioset_mdl_url = 'https://www.dropbox.com/s/nx6nl4d4bl71sm8/SSAST-Base-Frame-400.pth?dl=1'
                print('Download SSAST-Base-Frame-400.pth \n')
                wget.download(audioset_mdl_url, out=ckpt_path)
            sd = torch.load(ckpt_path, map_location=device)
            print('\n Loaded SSAST-Base-Frame-400.pth successfully.')
        else:
            # anything else is a path to a checkpoint on disk
            sd = torch.load(load_pretrained_mdl_path, map_location=device)
        return sd

    @staticmethod
    def _resize_pos_embed(grid, p_f_dim, p_t_dim, f_dim, t_dim):
        """Crop or interpolate a positional-embedding grid to a new patch grid.

        Args:
            grid: tensor of shape ``(1, dim, p_f_dim, p_t_dim)``.
            p_f_dim, p_t_dim: patch-grid size the embedding was trained with.
            f_dim, t_dim: patch-grid size required now.

        Returns:
            Tensor of shape ``(1, dim, f_dim, t_dim)``. An axis that shrinks is
            centre-cropped; otherwise it is bilinearly interpolated.
        """
        if t_dim < p_t_dim:
            t_start = int(p_t_dim / 2) - int(t_dim / 2)
            grid = grid[:, :, :, t_start: t_start + t_dim]
        else:
            # keep the pretraining frequency size here; the frequency axis is
            # handled by the next step
            grid = torch.nn.functional.interpolate(grid, size=(p_f_dim, t_dim), mode='bilinear')
        if f_dim < p_f_dim:
            f_start = int(p_f_dim / 2) - int(f_dim / 2)
            grid = grid[:, :, f_start: f_start + f_dim, :]
        else:
            grid = torch.nn.functional.interpolate(grid, size=(f_dim, t_dim), mode='bilinear')
        return grid

    def _build_for_finetuning(self, label_dim, fshape, tshape, fstride, tstride,
                              input_fdim, input_tdim, model_size,
                              load_pretrained_mdl_path):
        """Restore an SSL-pretrained backbone and adapt it to the new input."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if load_pretrained_mdl_path is None:
            raise ValueError('Please set load_pretrained_mdl_path to load a pretrained models.')
        sd = self._load_pretrained_state_dict(load_pretrained_mdl_path, device)

        # recover patch shape and input size used during pretraining; the
        # 'module.' prefix is what a DataParallel-wrapped model saves
        try:
            proj_weight = sd['module.v.patch_embed.proj.weight']
            p_fshape, p_tshape = proj_weight.shape[2], proj_weight.shape[3]
            p_input_fdim, p_input_tdim = sd['module.p_input_fdim'].item(), sd['module.p_input_tdim'].item()
        except (KeyError, TypeError, AttributeError, IndexError) as err:
            raise ValueError('The model loaded is not from a torch.nn.Dataparallel object. Wrap it with torch.nn.Dataparallel and try again.') from err
        print('now load a SSL pretrained models from ' + load_pretrained_mdl_path)

        # rebuild the pretraining-time model (stride == shape there) so the
        # checkpoint tensors fit, then take over its backbone
        audio_model = SSASTModel(fstride=p_fshape, tstride=p_tshape, fshape=p_fshape, tshape=p_tshape,
                                 input_fdim=p_input_fdim, input_tdim=p_input_tdim, pretrain_stage=True, model_size=model_size)
        audio_model = torch.nn.DataParallel(audio_model)
        audio_model.load_state_dict(sd, strict=False)
        self.v = audio_model.module.v
        self.final_feat_dim = self.original_embedding_dim = self.v.pos_embed.shape[2]
        self.cls_token_num = audio_model.module.cls_token_num
        embed_dim = self.original_embedding_dim

        # mlp head for fine-tuning
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, label_dim))

        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape)
        # patch array dimension during pretraining
        p_f_dim, p_t_dim = audio_model.module.p_f_dim, audio_model.module.p_t_dim
        num_patches = f_dim * t_dim
        p_num_patches = p_f_dim * p_t_dim
        self.v.patch_embed.num_patches = num_patches
        print('fine-tuning patch split stride: frequncey={:d}, time={:d}'.format(fstride, tstride))
        print('fine-tuning number of patches={:d}'.format(num_patches))

        # patch shape should be same for pretraining and fine-tuning
        if fshape != p_fshape or tshape != p_tshape:
            raise ValueError('The patch shape of pretraining and fine-tuning is not consistant, pretraining: f={:d}, t={:d}, finetuning: f={:d}, t={:d}'.format(p_fshape, p_tshape, fshape, tshape))

        # a different stride (overlapping patches) needs a new projection layer
        # that reuses the pretrained weights
        if fstride != p_fshape or tstride != p_tshape:
            new_proj = torch.nn.Conv2d(1, embed_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
            new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
            new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

        # lay the patch part of the positional embedding out as a 2-D grid,
        # resize it, and flatten it back
        patch_pos = self.v.pos_embed[:, self.cls_token_num:, :].detach()
        grid = patch_pos.reshape(1, p_num_patches, embed_dim).transpose(1, 2).reshape(1, embed_dim, p_f_dim, p_t_dim)
        grid = self._resize_pos_embed(grid, p_f_dim, p_t_dim, f_dim, t_dim)
        new_pos_embed = grid.reshape(1, embed_dim, num_patches).transpose(1, 2)
        token_pos = self.v.pos_embed[:, :self.cls_token_num, :].detach()
        self.v.pos_embed = nn.Parameter(torch.cat([token_pos, new_pos_embed], dim=1))

    # get the shape of intermediate representation.
    def get_shape(self, fstride, tstride, input_fdim, input_tdim, fshape, tshape):
        """Return ``(f_dim, t_dim)``, the patch-grid size for the given geometry.

        The size is measured by running a throw-away convolution with the same
        kernel and stride over a dummy ``(1, 1, input_fdim, input_tdim)`` input.
        """
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(fshape, tshape), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        return test_out.shape[2], test_out.shape[3]

    # ------------------------------------------------------------------
    # mask sampling
    # ------------------------------------------------------------------
    def gen_maskid_patch(self, sequence_len=512, mask_size=100, cluster=3):
        """Sample ``mask_size`` distinct patch indices in square clusters.

        Random start positions are drawn and a ``k x k`` neighbourhood on the
        patch grid (row length ``self.p_t_dim``) is added for each, with ``k``
        drawn once from ``[3, 3 + cluster)``.

        Returns:
            1-D tensor with ``mask_size`` indices.

        Raises:
            ValueError: if ``mask_size`` cannot be reached for this
                ``sequence_len`` (previously this looped forever).
        """
        # index 0 is never selected (strict '> 0' below) and the loop needs
        # strictly more than mask_size distinct indices before it stops
        if mask_size > sequence_len - 2:
            raise ValueError('mask_size={} is too large for sequence_len={}'.format(mask_size, sequence_len))
        # randomize clustering factor
        cur_clus = randrange(cluster) + 3
        mask_id = []
        distinct = set()
        while len(distinct) <= mask_size:
            start_id = randrange(sequence_len)
            cur_mask = []
            for i in range(cur_clus):
                for j in range(cur_clus):
                    mask_cand = start_id + self.p_t_dim * i + j
                    # NOTE(review): '> 0' excludes patch 0; kept so the
                    # sampling matches existing pretrained models
                    if 0 < mask_cand < sequence_len:
                        cur_mask.append(mask_cand)
            mask_id.extend(cur_mask)
            distinct.update(cur_mask)
        # deduplicate from the full history, as before, then truncate
        mask_id = list(set(mask_id))[:mask_size]
        return torch.tensor(mask_id)

    # using cluster for frame masking hurts the performance, so just use the naive random sampling
    def gen_maskid_frame(self, sequence_len=512, mask_size=100):
        """Sample ``mask_size`` distinct indices uniformly from ``range(sequence_len)``."""
        mask_id = random.sample(range(0, sequence_len), mask_size)
        return torch.tensor(mask_id)

    def _apply_mask(self, x, mask_patch, cluster):
        """Replace ``mask_patch`` randomly chosen patch embeddings per sample.

        Args:
            x: patch embeddings, ``(batch, num_patches, dim)``.
            mask_patch: number of patches to mask in each sample.
            cluster: use clustered sampling (True) or uniform sampling (False).

        Returns:
            ``(masked_x, mask_index, batch_idx)`` where ``mask_index`` is a
            ``(batch, mask_patch)`` long tensor and ``batch_idx`` is a
            ``(batch, 1)`` index that pairs with it for gathering.
        """
        B = x.shape[0]
        sampler = self.gen_maskid_patch if cluster else self.gen_maskid_frame
        mask_index = torch.empty((B, mask_patch), device=x.device, requires_grad=False).long()
        for i in range(B):
            # mask indexes are unique within a sample
            mask_index[i] = sampler(self.num_patches, mask_patch)
        batch_idx = torch.arange(B, device=x.device).unsqueeze(1)

        # 1 where the patch is kept, 0 where it is masked
        mask_dense = torch.ones([x.shape[0], x.shape[1], x.shape[2]], device=x.device)
        mask_dense[batch_idx, mask_index] = 0
        mask_tokens = self.mask_embed.expand(B, x.shape[1], -1)
        x = x * mask_dense + (1 - mask_dense) * mask_tokens
        return x, mask_index, batch_idx

    # ------------------------------------------------------------------
    # shared transformer pass
    # ------------------------------------------------------------------
    def _encode(self, x):
        """Prepend the cls (and dist) token, add positions, run the backbone.

        Args:
            x: patch embeddings, ``(batch, num_patches, dim)``.

        Returns:
            Normalised transformer output with ``self.cls_token_num`` extra
            leading tokens.
        """
        B = x.shape[0]
        pieces = [self.v.cls_token.expand(B, -1, -1)]
        # distilled backbones carry a second (distillation) token
        if self.cls_token_num == 2:
            pieces.append(self.v.dist_token.expand(B, -1, -1))
        pieces.append(x)
        x = torch.cat(pieces, dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        return self.v.norm(x)

    # ------------------------------------------------------------------
    # fine-tuning tasks
    # ------------------------------------------------------------------
    def finetuningavgtok(self, x, y=None, patch_mix=False):
        """Return the mean of all patch-token outputs as the clip feature.

        ``mlp_head`` is not applied here. When ``patch_mix`` is True the patch
        embeddings are mixed via :meth:`patch_mix` first and the result is
        ``(features, y_a, y_b, lam, index)``.
        """
        x = self.v.patch_embed(x)
        if patch_mix:
            x, y_a, y_b, lam, index = self.patch_mix(x, y)
        x = self._encode(x)
        # average output of all tokens except cls token(s)
        x = torch.mean(x[:, self.cls_token_num:, :], dim=1)
        if patch_mix:
            return x, y_a, y_b, lam, index
        return x

    def finetuningcls(self, x):
        """Return the cls-token output as the clip feature (``mlp_head`` not applied).

        With two class-like tokens (distilled backbones) their outputs are
        averaged.
        """
        x = self._encode(self.v.patch_embed(x))
        if self.cls_token_num == 2:
            return (x[:, 0] + x[:, 1]) / 2
        return x[:, 0]

    # ------------------------------------------------------------------
    # pretraining tasks
    # ------------------------------------------------------------------
    # masked patch pretraining with discriminative objective
    def mpc(self, x, mask_patch, cluster, show_mask=False):
        """Masked patch classification.

        Args:
            x: input of shape ``(batch, 1, freq, time)``.
            mask_patch: number of patches masked per sample.
            cluster: clustered (True) or uniform (False) mask sampling.
            show_mask: probing mode; only a batch of one is supported.

        Returns:
            ``(acc, nce)`` tensors, or ``(pred, masked)`` folded back to the
            input layout when ``show_mask`` is True.
        """
        # flattened raw patches, (batch, num_patches, fshape * tshape)
        patches = self.unfold(x).transpose(1, 2)
        B = x.shape[0]
        x = self.v.patch_embed(x)
        x, mask_index, batch_idx = self._apply_mask(x, mask_patch, cluster)
        # true values of the masked patches; no gradient flows through them
        encode_samples = patches[batch_idx, mask_index].detach().float()

        x = self._encode(x)
        # offset the indexes to skip the cls (and dist) token
        pred = self.cpredlayer(x[batch_idx, mask_index + self.cls_token_num]).float()

        # calculate the NCE loss; negatives come from the same sample
        # 8/12/2022: differs from equation (1) in the ssast paper, see
        # https://github.com/YuanGongND/ssast/issues/13
        nce = torch.tensor(0.0).to(x.device)
        correct = torch.tensor(0.0).to(x.device)
        expected = torch.arange(0, mask_patch, device=x.device)
        for i in range(B):
            total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1))  # mask_patch x mask_patch
            correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), expected))
            nce += torch.sum(torch.diag(self.lsoftmax(total)))
        acc = 1. * correct / (B * mask_patch)
        nce = nce / (-1. * B * mask_patch)

        if not show_mask:
            return acc, nce

        # visualize the masked area, for probing test only
        if B > 1:
            raise Exception('Currently only support single spectrogram probing test.')
        self.mask_correct = torch.nn.Parameter(torch.arange(0, mask_patch, device=x.device), requires_grad=False)
        pred = patches.clone()
        masked = patches.clone()
        for i in range(B):
            # 99 marks a correctly identified masked patch, 0 a miss
            hits = torch.eq(torch.argmax(self.softmax(total), dim=0), self.mask_correct).to(pred.dtype) * 99
            pred[i, mask_index[i], :] = hits.reshape(mask_patch, 1).expand(mask_patch, pred.shape[2])
            masked[i, mask_index[i], :] = 99.0
        fold = torch.nn.Fold(output_size=([self.input_fdim, self.input_tdim]), kernel_size=(self.fshape, self.tshape), stride=(self.fstride, self.tstride))
        pred = fold(pred.transpose(1, 2))
        masked = fold(masked.transpose(1, 2))
        return pred, masked

    # masked patch pretraining with generative objective
    def mpg(self, input, mask_patch, cluster):
        """Masked patch reconstruction; returns the mean squared error.

        Args:
            input: input of shape ``(batch, 1, freq, time)``.
            mask_patch: number of patches masked per sample.
            cluster: clustered (True) or uniform (False) mask sampling.
        """
        x = self.v.patch_embed(input)
        patches = self.unfold(input).transpose(1, 2)
        x, mask_index, batch_idx = self._apply_mask(x, mask_patch, cluster)
        x = self._encode(x)
        # offset the indexes to skip the cls (and dist) token
        pred = self.gpredlayer(x[batch_idx, mask_index + self.cls_token_num]).float()
        target = patches[batch_idx, mask_index].float()
        if pred.shape != target.shape:
            # guard against silent broadcasting in the loss below
            raise ValueError('prediction shape {} does not match patch shape {}'.format(tuple(pred.shape), tuple(target.shape)))
        # calculate the MSE loss
        return torch.mean((pred - target) ** 2)

    def patch_mix(self, image, target):
        """Swap a random subset of patches between samples of the batch.

        ``image`` (``(batch, num_patch, dim)``) is modified in place.

        Returns:
            ``(image, y_a, y_b, lam, index)``: the mixed patches, the original
            and permuted targets, the fraction of patches left untouched, and
            the batch permutation.
        """
        if self.mix_beta is not None and self.mix_beta > 0:
            lam = np.random.beta(self.mix_beta, self.mix_beta)
        else:
            # mixing disabled: no patch is replaced
            lam = 1
        batch_size, num_patch, _ = image.size()
        device = image.device
        index = torch.randperm(batch_size).to(device)
        num_mask = int(num_patch * (1. - lam))
        mask = torch.randperm(num_patch)[:num_mask].to(device)
        image[:, mask, :] = image[index][:, mask, :]
        # report the ratio that was actually applied after rounding
        lam = 1 - (num_mask / num_patch)
        y_a, y_b = target, target[index]
        return image, y_a, y_b, lam, index

    def forward(self, x, y=None, patch_mix=False, cluster=True, mask_patch=400):
        """Run the computation selected by ``self.task``.

        Args:
            x: 4-D input whose last two axes are swapped before use; the
                inline notes suggest ``(batch, 1, time, freq)`` - TODO confirm
                against the data loader.
            y: targets, only used when ``patch_mix`` is True.
            patch_mix: enable patch mixing for ``'ft_avgtok'``.
            cluster: mask-sampling mode for the pretraining tasks.
            mask_patch: number of masked patches for the pretraining tasks.

        Raises:
            ValueError: if ``self.task`` is not recognised.
        """
        x = x.transpose(2, 3)

        # finetuning (ft), use the mean of all token (patch) output as clip-level representation.
        if self.task == 'ft_avgtok':
            return self.finetuningavgtok(x, y, patch_mix)
        # alternatively, use the [cls] token output as clip-level representation.
        if self.task == 'ft_cls':
            return self.finetuningcls(x)
        # pretraining, masked patch classification (discriminative objective)
        if self.task == 'pretrain_mpc':
            return self.mpc(x, mask_patch=mask_patch, cluster=cluster)
        # pretraining, masked patch reconstruction (generative objective)
        if self.task == 'pretrain_mpg':
            return self.mpg(x, mask_patch=mask_patch, cluster=cluster)
        if self.task == 'visualize_mask':
            return self.mpc(x, mask_patch=mask_patch, cluster=cluster, show_mask=True)
        raise ValueError('Task unrecognized.')