diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 25e8900db56..c56093ed4bf 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm self.norm = norm_layer(4 * dim) def forward(self, x: Tensor): - B, H, W, C = x.shape - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + H, W, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C x = self.norm(x) - x = self.reduction(x) - x = x.view(B, H // 2, W // 2, 2 * C) + x = self.reduction(x) # ... H/2 W/2 2*C return x @@ -59,9 +64,9 @@ def shifted_window_attention( qkv_weight: Tensor, proj_weight: Tensor, relative_position_bias: Tensor, - window_size: int, + window_size: List[int], num_heads: int, - shift_size: int = 0, + shift_size: List[int], attention_dropout: float = 0.0, dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, @@ -75,9 +80,9 @@ def shifted_window_attention( qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. relative_position_bias (Tensor): The learned relative position bias added to attention. - window_size (int): Window size. + window_size (List[int]): Window size. num_heads (int): Number of attention heads. - shift_size (int): Shift size for shifted window attention. Default: 0. + shift_size (List[int]): Shift size for shifted window attention. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. @@ -87,23 +92,25 @@ def shifted_window_attention( """ B, H, W, C = input.shape # pad feature maps to multiples of window size - pad_r = (window_size - W % window_size) % window_size - pad_b = (window_size - H % window_size) % window_size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape - # If window size is larger than feature size, there is no need to shift window. - if window_size == min(pad_H, pad_W): - shift_size = 0 + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 # cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) # partition windows - num_windows = (pad_H // window_size) * (pad_W // window_size) - x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention qkv = F.linear(x, qkv_weight, qkv_bias) @@ -114,17 +121,18 @@ def shifted_window_attention( # add relative position bias attn = attn + relative_position_bias - if shift_size > 0: + if sum(shift_size) > 0: # generate attention mask attn_mask = x.new_zeros((pad_H, pad_W)) - slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) + h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) count = 0 - for h in slices: - for w in slices: + for h in h_slices: + for w in w_slices: attn_mask[h[0] : h[1], w[0] : w[1]] = count count += 1 - attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) - attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) + attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) @@ -139,12 +147,12 @@ def shifted_window_attention( x = F.dropout(x, p=dropout) # reverse windows - x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C) + x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) # reverse cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) # unpad features x = x[:, :H, :W, :].contiguous() @@ -162,8 +170,8 @@ class ShiftedWindowAttention(nn.Module): def __init__( self, dim: int, - window_size: int, - shift_size: int, + window_size: List[int], + shift_size: List[int], num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, @@ -171,6 +179,8 @@ def __init__( dropout: float = 0.0, ): super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads @@ -182,29 +192,35 @@ def __init__( # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size) - coords_w = torch.arange(self.window_size) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size - 1 - relative_coords[:, :, 0] *= 2 * self.window_size - 1 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + + N = self.window_size[0] * self.window_size[1] relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] - relative_position_bias = relative_position_bias.view( - self.window_size * self.window_size, self.window_size * self.window_size, -1 - ) + relative_position_bias = relative_position_bias.view(N, N, -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) return shifted_window_attention( @@ -228,31 +244,33 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. - window_size (int): Window size. Default: 7. - shift_size (int): Shift size for shifted window attention. Default: 0. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention """ def __init__( self, dim: int, num_heads: int, - window_size: int = 7, - shift_size: int = 0, + window_size: List[int], + shift_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, ): super().__init__() self.norm1 = norm_layer(dim) - self.attn = ShiftedWindowAttention( + self.attn = attn_layer( dim, window_size, shift_size, @@ -281,11 +299,11 @@ class SwinTransformer(nn.Module): Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. Args: - patch_size (int): Patch size. + patch_size (List[int]): Patch size. embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7. + window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. @@ -297,11 +315,11 @@ class SwinTransformer(nn.Module): def __init__( self, - patch_size: int, + patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], - window_size: int = 7, + window_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, @@ -324,7 +342,9 @@ def __init__( # split image into non-overlapping patches layers.append( nn.Sequential( - nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), Permute([0, 2, 3, 1]), norm_layer(embed_dim), ) @@ -344,7 +364,7 @@ def __init__( dim, num_heads[i_stage], window_size=window_size, - shift_size=0 if i_layer % 2 == 0 else window_size // 2, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], mlp_ratio=mlp_ratio, dropout=dropout, attention_dropout=attention_dropout, @@ -381,11 +401,11 @@ def forward(self, x): def _swin_transformer( - patch_size: int, + patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], - window_size: int, + window_size: List[int], stochastic_depth_prob: float, weights: Optional[WeightsEnum], progress: bool, @@ -508,11 +528,11 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * weights = Swin_T_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.2, weights=weights, progress=progress, @@ -544,11 +564,11 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * weights = Swin_S_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.3, weights=weights, progress=progress, @@ -580,11 +600,11 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * weights = Swin_B_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.5, weights=weights, progress=progress,