Refactor Swin Transformer so components can later be reused for the 3D version #6088

Merged

Changes from all commits (19 commits)
fb1a749
Use List[int] instead of int for window_size and shift_size
YosuaMichael May 25, 2022
c3902ae
Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d c…
YosuaMichael May 25, 2022
13d4d85
Separate patch embedding from SwinTransformer and enable to get model…
YosuaMichael May 25, 2022
a3d1192
Don't use if before padding so it is FX-friendly
YosuaMichael May 25, 2022
6a57c15
Merge branch 'models/refactor-swin-transformer' into models/refactor-…
YosuaMichael May 25, 2022
ba5f8f9
Put the handling on window_size edge cases on separate function and w…
YosuaMichael May 25, 2022
ed14bd7
Update the weight url to the converted weight with new structure
YosuaMichael May 25, 2022
548109a
Update the accuracy of swin_transformer
YosuaMichael May 25, 2022
5d09c35
Merge pull request #2 from YosuaMichael/models/refactor-swin-transfor…
YosuaMichael May 25, 2022
2d767f8
Change assert to Exception and nit
YosuaMichael May 25, 2022
65c8439
Make num_classes optional
YosuaMichael May 25, 2022
2c0dead
Merge branch 'main' into models/refactor-swin-transformer
YosuaMichael May 25, 2022
f0f872f
Add typing output for _fix_window_and_shift_size function
YosuaMichael May 26, 2022
e2317f4
Init head to None to make it JIT scriptable
YosuaMichael May 26, 2022
77ce2b9
Revert the change to make num_classes optional
YosuaMichael May 26, 2022
f04801f
Revert unnecessary changes that might be risky
YosuaMichael May 26, 2022
480f762
Merge branch 'main' into models/refactor-swin-transformer
YosuaMichael May 26, 2022
8cf5e39
Remove self.head declaration
YosuaMichael May 26, 2022
c260edb
Merge branch 'models/refactor-swin-transformer' of github.com:YosuaMi…
YosuaMichael May 26, 2022
146 changes: 83 additions & 63 deletions torchvision/models/swin_transformer.py
@@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm
self.norm = norm_layer(4 * dim)

def forward(self, x: Tensor):
B, H, W, C = x.shape

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
"""
Args:
x (Tensor): input tensor with expected layout of [..., H, W, C]
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C

x = self.norm(x)
x = self.reduction(x)
x = x.view(B, H // 2, W // 2, 2 * C)
x = self.reduction(x) # ... H/2 W/2 2*C
return x
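
As a sanity check of the new [..., H, W, C] handling, a minimal sketch (assuming this branch of torchvision is installed; tensor sizes are arbitrary):

import torch
from torchvision.models.swin_transformer import PatchMerging

# Odd H/W on purpose, to exercise the new F.pad path before merging.
merge = PatchMerging(dim=96)
x = torch.randn(2, 7, 7, 96)          # [B, H, W, C]
print(merge(x).shape)                 # torch.Size([2, 4, 4, 192]) -> [..., H/2, W/2, 2*C]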


@@ -59,9 +64,9 @@ def shifted_window_attention(
qkv_weight: Tensor,
proj_weight: Tensor,
relative_position_bias: Tensor,
window_size: int,
window_size: List[int],
num_heads: int,
shift_size: int = 0,
shift_size: List[int],
attention_dropout: float = 0.0,
dropout: float = 0.0,
qkv_bias: Optional[Tensor] = None,
@@ -75,9 +80,9 @@ def shifted_window_attention(
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (int): Window size.
window_size (List[int]): Window size.
num_heads (int): Number of attention heads.
shift_size (int): Shift size for shifted window attention. Default: 0.
shift_size (List[int]): Shift size for shifted window attention.
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
@@ -87,23 +92,25 @@ def shifted_window_attention(
"""
B, H, W, C = input.shape
# pad feature maps to multiples of window size
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape

# If window size is larger than feature size, there is no need to shift window.
if window_size == min(pad_H, pad_W):
shift_size = 0
# If window size is larger than feature size, there is no need to shift window
if window_size[0] >= pad_H:
shift_size[0] = 0
if window_size[1] >= pad_W:
shift_size[1] = 0

# cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

# partition windows
num_windows = (pad_H // window_size) * (pad_W // window_size)
x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C

# multi-head attention
qkv = F.linear(x, qkv_weight, qkv_bias)
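
A quick numeric check of the per-axis padding and window-partition arithmetic above (plain Python, values picked arbitrarily):

H, W = 13, 10
window_size = [7, 7]
pad_b = (window_size[0] - H % window_size[0]) % window_size[0]       # 1
pad_r = (window_size[1] - W % window_size[1]) % window_size[1]       # 4
pad_H, pad_W = H + pad_b, W + pad_r                                  # 14, 14
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
print(pad_b, pad_r, num_windows)                                     # 1 4 4
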
@@ -114,17 +121,18 @@ def shifted_window_attention(
# add relative position bias
attn = attn + relative_position_bias

if shift_size > 0:
if sum(shift_size) > 0:
# generate attention mask
attn_mask = x.new_zeros((pad_H, pad_W))
slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None))
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0
for h in slices:
for w in slices:
for h in h_slices:
for w in w_slices:
attn_mask[h[0] : h[1], w[0] : w[1]] = count
count += 1
attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size)
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size)
attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
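
The mask construction now slices each spatial axis with its own window and shift; a standalone sketch (torch only) for a toy 4x4 padded map with window_size=[2, 2] and shift_size=[1, 1]:

import torch

pad_H, pad_W = 4, 4
window_size, shift_size = [2, 2], [1, 1]
num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])

attn_mask = torch.zeros((pad_H, pad_W))
h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
count = 0
for h in h_slices:
    for w in w_slices:
        attn_mask[h[0] : h[1], w[0] : w[1]] = count
        count += 1
attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
print((attn_mask != 0).float().mean())   # fraction of (query, key) pairs masked out per window
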
@@ -139,12 +147,12 @@ def shifted_window_attention(
x = F.dropout(x, p=dropout)

# reverse windows
x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C)
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

# reverse cyclic shift
if shift_size > 0:
x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
if sum(shift_size) > 0:
x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

# unpad features
x = x[:, :H, :W, :].contiguous()
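
The cyclic shift and its reverse are plain torch.roll calls over the two spatial dims; a toy round-trip check:

import torch

x = torch.arange(16.0).reshape(1, 4, 4, 1)        # [B, H, W, C]
shift_size = [1, 2]
shifted = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
restored = torch.roll(shifted, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
print(torch.equal(x, restored))                   # True
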
@@ -162,15 +170,17 @@ class ShiftedWindowAttention(nn.Module):
def __init__(
self,
dim: int,
window_size: int,
shift_size: int,
window_size: List[int],
shift_size: List[int],
num_heads: int,
qkv_bias: bool = True,
proj_bias: bool = True,
attention_dropout: float = 0.0,
dropout: float = 0.0,
):
super().__init__()
if len(window_size) != 2 or len(shift_size) != 2:
raise ValueError("window_size and shift_size must be of length 2")
self.window_size = window_size
self.shift_size = shift_size
self.num_heads = num_heads
@@ -182,29 +192,35 @@ def __init__(

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size)
coords_w = torch.arange(self.window_size)
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

def forward(self, x: Tensor):
"""
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
Returns:
Tensor with same layout as input, i.e. [B, H, W, C]
"""

N = self.window_size[0] * self.window_size[1]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
relative_position_bias = relative_position_bias.view(
self.window_size * self.window_size, self.window_size * self.window_size, -1
)
relative_position_bias = relative_position_bias.view(N, N, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)

return shifted_window_attention(
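
A quick check of the bias-table bookkeeping for a [7, 7] window, mirroring the indexing code above (self-contained, torch only):

import torch

window_size = [7, 7]
coords = torch.stack(torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1).view(-1)
# Every index must fit in the (2*Wh - 1) * (2*Ww - 1) bias table.
print(int(relative_position_index.max()) + 1)   # 169 == (2*7 - 1) * (2*7 - 1)
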
@@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module):
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size. Default: 7.
shift_size (int): Shift size for shifted window attention. Default: 0.
window_size (List[int]): Window size.
shift_size (List[int]): Shift size for shifted window attention.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
"""

def __init__(
self,
dim: int,
num_heads: int,
window_size: int = 7,
shift_size: int = 0,
window_size: List[int],
shift_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
stochastic_depth_prob: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
):
super().__init__()

self.norm1 = norm_layer(dim)
self.attn = ShiftedWindowAttention(
self.attn = attn_layer(
dim,
window_size,
shift_size,
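
With this change window_size and shift_size are required per-axis lists; a minimal construction sketch (assuming this branch is installed):

import torch
from torchvision.models.swin_transformer import SwinTransformerBlock

block = SwinTransformerBlock(dim=96, num_heads=3, window_size=[7, 7], shift_size=[3, 3])
x = torch.randn(1, 56, 56, 96)     # [B, H, W, C]
print(block(x).shape)              # torch.Size([1, 56, 56, 96])
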
@@ -281,11 +299,11 @@ class SwinTransformer(nn.Module):
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
Args:
patch_size (int): Patch size.
patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension.
depths (List(int)): Depth of each Swin Transformer layer.
num_heads (List(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7.
window_size (List[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
@@ -297,11 +315,11 @@ class SwinTransformer(nn.Module):

def __init__(
self,
patch_size: int,
patch_size: List[int],
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int = 7,
window_size: List[int],
mlp_ratio: float = 4.0,
dropout: float = 0.0,
attention_dropout: float = 0.0,
@@ -324,7 +342,9 @@ def __init__(
# split image into non-overlapping patches
layers.append(
nn.Sequential(
nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size),
nn.Conv2d(
3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
),
Permute([0, 2, 3, 1]),
norm_layer(embed_dim),
)
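
For reference, the stem above is just a strided Conv2d followed by a channels-last permute; a shape sketch with the Swin-T defaults:

import torch
from torch import nn

patch_size, embed_dim = [4, 4], 96
stem = nn.Conv2d(3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]))
x = stem(torch.randn(1, 3, 224, 224))   # [1, 96, 56, 56]
x = x.permute(0, 2, 3, 1)               # [1, 56, 56, 96], the [B, H, W, C] layout used by the blocks
print(x.shape)
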
@@ -344,7 +364,7 @@ def __init__(
dim,
num_heads[i_stage],
window_size=window_size,
shift_size=0 if i_layer % 2 == 0 else window_size // 2,
shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
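
The alternating shift pattern is now computed per axis; e.g. with window_size=[7, 7], the blocks in a stage get shift_size [0, 0], [3, 3], [0, 0], [3, 3], ... A one-line check:

window_size = [7, 7]
print([[0 if i_layer % 2 == 0 else w // 2 for w in window_size] for i_layer in range(4)])
# [[0, 0], [3, 3], [0, 0], [3, 3]]
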
@@ -381,11 +401,11 @@ def forward(self, x):


def _swin_transformer(
patch_size: int,
patch_size: List[int],
embed_dim: int,
depths: List[int],
num_heads: List[int],
window_size: int,
window_size: List[int],
stochastic_depth_prob: float,
weights: Optional[WeightsEnum],
progress: bool,
@@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
weights = Swin_T_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.2,
weights=weights,
progress=progress,
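
The public builder signature is unchanged; only the internal defaults become lists. A usage sketch (the Swin_T_Weights.IMAGENET1K_V1 enum name is assumed from current torchvision, not introduced by this PR):

import torch
from torchvision.models import swin_t, Swin_T_Weights

model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)   # torch.Size([1, 1000])
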
@@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
weights = Swin_S_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.3,
weights=weights,
progress=progress,
@@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
weights = Swin_B_Weights.verify(weights)

return _swin_transformer(
patch_size=4,
patch_size=[4, 4],
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=7,
window_size=[7, 7],
stochastic_depth_prob=0.5,
weights=weights,
progress=progress,