-
Notifications
You must be signed in to change notification settings - Fork 153
/
Copy pathswin_transformer_3d_encoder.py
538 lines (481 loc) · 19 KB
/
swin_transformer_3d_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py
from functools import lru_cache, partial
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.models.vision_transformer import MLPBlock
from torchvision.ops.stochastic_depth import StochasticDepth
def _compute_pad_size_3d(
    size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]
) -> Tuple[int, int, int]:
    """Return the trailing padding needed on each of the (D, H, W) axes.

    For every axis the result is the smallest non-negative amount that makes
    ``size_dhw[axis] + pad`` an exact multiple of ``patch_size[axis]``.
    Only the first three entries of each argument are consulted.

    Args:
        size_dhw: Current extent along D, H and W.
        patch_size: Patch (or window) extent along D, H and W.

    Returns:
        A ``(pad_d, pad_h, pad_w)`` tuple.
    """
    # (-n) % p is the distance from n up to the next multiple of p,
    # i.e. the same value as (p - n % p) % p.
    pad_d, pad_h, pad_w = (
        (-size_dhw[axis]) % patch_size[axis] for axis in range(3)
    )
    return (pad_d, pad_h, pad_w)
# Cache the attention mask for performance. The cache key is built only from hashable
# geometry plus device/dtype -- never from an activation tensor -- so the cache can hit
# across forward passes and never keeps activations (or their autograd graphs) alive.
@lru_cache(maxsize=32)
def _build_attention_mask_3d(
    size_dhw: Tuple[int, int, int],
    window_size: Tuple[int, int, int],
    shift_size: Tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    """Build (and memoize) the shifted-window attention mask for one geometry.

    The returned tensor is shared between callers through the cache and must be
    treated as read-only.
    """
    # Label every voxel of the (padded) volume with the id of the shifted region it
    # belongs to; tokens from different regions must not attend to each other.
    attn_mask = torch.zeros(size_dhw, device=device, dtype=dtype)
    num_windows = (
        (size_dhw[0] // window_size[0])
        * (size_dhw[1] // window_size[1])
        * (size_dhw[2] // window_size[2])
    )
    slices = [
        (
            (0, -window_size[i]),
            (-window_size[i], -shift_size[i]),
            (-shift_size[i], None),
        )
        for i in range(3)
    ]
    count = 0
    for d in slices[0]:
        for h in slices[1]:
            for w in slices[2]:
                attn_mask[d[0] : d[1], h[0] : h[1], w[0] : w[1]] = count
                count += 1
    # Partition window on attn_mask -> [num_windows, Wd*Wh*Ww]
    attn_mask = attn_mask.view(
        size_dhw[0] // window_size[0],
        window_size[0],
        size_dhw[1] // window_size[1],
        window_size[1],
        size_dhw[2] // window_size[2],
        window_size[2],
    )
    attn_mask = attn_mask.permute(0, 2, 4, 1, 3, 5).reshape(
        num_windows, window_size[0] * window_size[1] * window_size[2]
    )
    # Pairwise label difference: 0 where both tokens share a region, nonzero otherwise.
    attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )
    return attn_mask


def _compute_attention_mask_3d(
    x: Tensor,
    size_dhw: Tuple[int, int, int],
    window_size: Tuple[int, int, int],
    shift_size: Tuple[int, int, int],
) -> Tensor:
    """Return the additive attention mask for shifted-window attention.

    Args:
        x: Any tensor on the target device; only its ``device`` and ``dtype`` are used.
        size_dhw: Padded volume size (D, H, W); each must be divisible by the window size.
        window_size: Window size (Wd, Wh, Ww).
        shift_size: Cyclic shift (D, H, W).

    Returns:
        Tensor of shape ``[num_windows, N, N]`` (``N = Wd*Wh*Ww``) holding 0.0 for token
        pairs in the same shifted region and -100.0 otherwise. The tensor is cached and
        shared between calls, so callers must not modify it in place.
    """
    return _build_attention_mask_3d(
        tuple(size_dhw), tuple(window_size), tuple(shift_size), x.device, x.dtype
    )
class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Halves the two trailing spatial axes (H, W) by gathering each 2x2 neighbourhood
    into the channel axis (4*C), normalizing, and projecting down to 2*C channels.
    Odd H or W are zero-padded to even first.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        self.dim = dim
        merged_dim = 4 * dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        height, width, _ = x.shape[-3:]
        # F.pad lists pads from the last axis backwards: (C, W, H).
        x = F.pad(x, (0, 0, 0, width % 2, 0, height % 2))
        # Concatenation order: (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1)
        quadrants = [
            x[..., row::2, col::2, :]
            for col in (0, 1)
            for row in (0, 1)
        ]
        merged = torch.cat(quadrants, dim=-1)  # ... H/2 W/2 4*C
        return self.reduction(self.norm(merged))  # ... H/2 W/2 2*C
def shifted_window_attention_3d(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    The input is zero-padded to a multiple of ``window_size``, optionally cyclically
    shifted, split into non-overlapping windows, attended within each window, and then
    reassembled, un-shifted and cropped back to the input size.

    Args:
        input (Tensor[B, D, H, W, C]): The input tensor, 5-dimensions.
        qkv_weight (Tensor[3*C, C]): Weight of the fused query/key/value projection
            (``F.linear`` layout, i.e. [out_features, in_features]).
        proj_weight (Tensor[C, C]): Weight of the output projection.
        relative_position_bias (Tensor): Learned bias added to the attention logits; it
            must broadcast against ``[B*num_windows, num_heads, N, N]`` with ``N = Wd*Wh*Ww``.
        window_size (List[int]): 3-dimensions window size, D, H, W.
        num_heads (int): Number of attention heads; ``C`` must be divisible by it.
        shift_size (List[int]): Shift size for shifted window attention (D, H, W).
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[3*C], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[C], optional): The bias tensor of projection. Default: None.
        training (bool): Whether dropout is active. ``F.dropout`` only disables dropout
            when this is False, so module callers should pass ``self.training``.
            Default: True (keeps the previous always-on behaviour).

    Returns:
        Tensor[B, D, H, W, C]: The output tensor after shifted window attention.
    """
    B, D, H, W, C = input.shape
    # pad feature maps to multiples of window size
    pad_size = _compute_pad_size_3d(
        (D, H, W), (window_size[0], window_size[1], window_size[2])
    )
    # F.pad lists pads from the last axis backwards: (C, W, H, D)
    x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))
    _, Dp, Hp, Wp, _ = x.shape
    padded_size = (Dp, Hp, Wp)

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(
            x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)
        )

    # partition windows
    num_windows = (
        (padded_size[0] // window_size[0])
        * (padded_size[1] // window_size[1])
        * (padded_size[2] // window_size[2])
    )
    x = x.view(
        B,
        padded_size[0] // window_size[0],
        window_size[0],
        padded_size[1] // window_size[1],
        window_size[1],
        padded_size[2] // window_size[2],
        window_size[2],
        C,
    )
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
        B * num_windows, window_size[0] * window_size[1] * window_size[2], C
    )  # B*nW, Wd*Wh*Ww, C

    # multi-head attention
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(
        2, 0, 3, 1, 4
    )
    q, k, v = qkv[0], qkv[1], qkv[2]
    q = q * (C // num_heads) ** -0.5
    attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask to handle shifted windows with varying size
        attn_mask = _compute_attention_mask_3d(
            x,
            (padded_size[0], padded_size[1], padded_size[2]),
            (window_size[0], window_size[1], window_size[2]),
            (shift_size[0], shift_size[1], shift_size[2]),
        )
        attn = attn.view(
            x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)
        )
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    # F.dropout defaults to training=True; honour the caller's mode explicitly.
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(
        B,
        padded_size[0] // window_size[0],
        padded_size[1] // window_size[1],
        padded_size[2] // window_size[2],
        window_size[0],
        window_size[1],
        window_size[2],
        C,
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, Dp, Hp, Wp, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(
            x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)
        )

    # unpad features
    x = x[:, :D, :H, :W, :].contiguous()
    return x
class ShiftedWindowAttention3d(nn.Module):
    """
    See :func:`shifted_window_attention_3d`.

    Owns the fused qkv projection, the output projection and a learned
    relative-position bias table with one row per (dd, dh, dw) offset inside a window
    and one column per head.

    Args:
        dim (int): Number of input (and output) channels.
        window_size (List[int]): Window size (Wd, Wh, Ww).
        shift_size (List[int]): Shift size (D, H, W) for shifted-window attention.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the qkv projection has a bias. Default: True.
        proj_bias (bool): Whether the output projection has a bias. Default: True.
        attention_dropout (float): Dropout on attention weights, active only in
            training mode. Default: 0.0.
        dropout (float): Dropout on the output, active only in training mode. Default: 0.0.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.window_size = window_size  # Wd, Wh, Ww
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * window_size[0] - 1)
                * (2 * window_size[1] - 1)
                * (2 * window_size[2] - 1),
                num_heads,
            )
        )  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)]
        coords = torch.stack(
            torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij")
        )  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        # mixed-radix encoding of (dd, dh, dw) into a single row index of the table
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (
            2 * self.window_size[2] - 1
        )
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor):
        """Apply (shifted) window attention to ``x`` with layout [B, D, H, W, C].

        Along any axis where the input is not larger than the window, the window is
        shrunk to the input size and the shift along that axis is disabled.
        """
        _, D, H, W, _ = x.shape
        size_dhw = (D, H, W)
        # list(...) instead of .copy() so tuple-valued window/shift sizes also work
        window_size, shift_size = list(self.window_size), list(self.shift_size)
        # Handle case where window_size is larger than the input tensor
        for i in range(3):
            if size_dhw[i] <= window_size[i]:
                # In this case, window_size will adapt to the input size, and no need to shift
                window_size[i] = size_dhw[i]
                shift_size[i] = 0

        N = window_size[0] * window_size[1] * window_size[2]
        # NOTE(review): the top-left N x N slice of the full-window index only yields the
        # exact relative offsets when the shrunk axes are the leading ones (raster order);
        # kept as-is so outputs of existing weights are unchanged.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)]  # type: ignore[index]
        relative_position_bias = relative_position_bias.view(N, N, -1)
        relative_position_bias = (
            relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        )  # 1, nH, N, N

        # F.dropout defaults to training=True, so dropout must be switched off
        # explicitly when the module is in eval mode.
        attention_dropout = self.attention_dropout if self.training else 0.0
        dropout = self.dropout if self.training else 0.0

        return shifted_window_attention_3d(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            window_size,
            self.num_heads,
            shift_size=shift_size,
            attention_dropout=attention_dropout,
            dropout=dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
        )
class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block.

    Pre-norm residual block: ``x + SD(attn(norm1(x)))`` followed by
    ``x + SD(mlp(norm2(x)))``, where SD is row-wise stochastic depth shared by
    both branches.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention3d
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention3d,
    ):
        super().__init__()
        # Submodules are created in a fixed order so parameter init and state_dict
        # key order stay stable.
        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
            num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLPBlock(dim, hidden_dim, dropout)

    def forward(self, x: Tensor):
        attn_branch = self.stochastic_depth(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.stochastic_depth(self.mlp(self.norm2(x)))
        return x + mlp_branch
# Modified from https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L416
class PatchEmbed3d(nn.Module):
    """Video to Patch Embedding.

    Zero-pads a ``[B, C, D, H, W]`` clip up to a multiple of the patch size, embeds
    non-overlapping patches with a strided ``Conv3d``, and returns channels-last tokens
    ``[B, D', H', W', embed_dim]`` (optionally normalized).

    Args:
        patch_size (List[int]): Patch token size.
        in_channels (int): Number of input channels. Default: 3
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size: List[int],
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        patch_d, patch_h, patch_w = patch_size[0], patch_size[1], patch_size[2]
        self.tuple_patch_size = (patch_d, patch_h, patch_w)
        # kernel == stride gives non-overlapping patches
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.tuple_patch_size,
            stride=self.tuple_patch_size,
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Forward function."""
        _, _, depth, height, width = x.size()
        pad_d, pad_h, pad_w = _compute_pad_size_3d(
            (depth, height, width), self.tuple_patch_size
        )
        # F.pad lists pads from the last axis backwards: (W, H, D)
        x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
        tokens = self.proj(x).permute(0, 2, 3, 4, 1)  # B D Wh Ww C
        if self.norm is None:
            return tokens
        return self.norm(tokens)
class SwinTransformer3d(nn.Module):
    """
    Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Maximum stochastic depth rate; the per-block rate
            grows linearly from 0 for the first block to this value for the last one.
            Default: 0.0.
        num_classes (int, optional): Number of classes for classification head,
            if None it will have no head. Default: 400.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
            (``nn.LayerNorm`` with ``eps=1e-5``).
        patch_embed (nn.Module, optional): Patch Embedding layer. Default: None.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        num_classes: Optional[int] = 400,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        self.num_classes = num_classes

        if block is None:
            block = SwinTransformerBlock
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)
        if patch_embed is None:
            patch_embed = PatchEmbed3d

        # split image into non-overlapping patches
        self.patch_embed = patch_embed(
            patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer
        )
        self.pos_drop = nn.Dropout(p=dropout)

        layers: List[nn.Module] = []
        total_stage_blocks = sum(depths)
        # Guard the linear stochastic-depth schedule against a single-block model,
        # which previously divided by zero; the lone block then gets rate 0.
        sd_denominator = max(total_stage_blocks - 1, 1)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            dim = embed_dim * 2 ** i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / sd_denominator
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        # alternate regular / shifted windows (shift by half a window)
                        shift_size=[
                            0 if i_layer % 2 == 0 else w // 2 for w in window_size
                        ],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                        attn_layer=ShiftedWindowAttention3d,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(PatchMerging(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        self.num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        if num_classes is not None:
            self.head = nn.Linear(self.num_features, num_classes)
        else:
            self.head = None

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # x: B C D H W
        x = self.patch_embed(x)  # B _D _H _W C
        x = self.pos_drop(x)
        x = self.features(x)  # B _D _H _W C
        x = self.norm(x)
        x = x.permute(0, 4, 1, 2, 3)  # B, C, _D, _H, _W
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        if self.num_classes is not None:
            x = self.head(x)
        return x