refine timesformer backbone code
congee524 committed May 21, 2021
1 parent f2fdf71 commit bb3c600
Showing 4 changed files with 75 additions and 50 deletions.
@@ -5,13 +5,13 @@
type='Recognizer3D',
backbone=dict(
type='TimeSformer',
pretrained='work_dirs/vit_imagenet.pth',
pretrained='work_dirs/vit_base_patch16_224.pth',
num_frames=8,
img_size=224,
patch_size=16,
embed_dims=768,
in_channels=3,
drop_rate=0.,
dropout_ratio=0.,
transformer_layers=None,
attention_type='divided_space_time',
norm_cfg=dict(type='LN', eps=1e-6)),
@@ -31,14 +31,9 @@
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_bgr=False)

mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')

train_pipeline = [
dict(type='SampleFrames', clip_len=8, frame_interval=32, num_clips=1),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='RandomRescale', scale_range=(256, 320)),
dict(type='RandomCrop', size=224),
dict(type='Flip', flip_ratio=0.5),
@@ -54,7 +49,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
@@ -69,7 +64,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
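Taken together, the config-side changes are small: the backbone's drop_rate argument becomes dropout_ratio, the pretrained checkpoint is referenced by its full name vit_base_patch16_224.pth, and the cluster-specific memcached settings (mc_cfg) disappear from the data pipelines. Consolidating the new values from the hunks above, the refactored backbone section reads roughly as follows (a sketch assembled from the changed lines only; the rest of the model config is unchanged and omitted):

    backbone = dict(
        type='TimeSformer',
        pretrained='work_dirs/vit_base_patch16_224.pth',  # path as given in the diff
        num_frames=8,
        img_size=224,
        patch_size=16,
        embed_dims=768,
        in_channels=3,
        dropout_ratio=0.,  # renamed from drop_rate in this commit
        transformer_layers=None,
        attention_type='divided_space_time',
        norm_cfg=dict(type='LN', eps=1e-6))

The next two configs receive the same dropout_ratio and pipeline changes for their 'joint_space_time' and 'space_only' variants.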
@@ -11,7 +11,7 @@
patch_size=16,
embed_dims=768,
in_channels=3,
drop_rate=0.,
dropout_ratio=0.,
transformer_layers=None,
attention_type='joint_space_time',
norm_cfg=dict(type='LN', eps=1e-6)),
@@ -31,14 +31,9 @@
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_bgr=False)

mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')

train_pipeline = [
dict(type='SampleFrames', clip_len=8, frame_interval=32, num_clips=1),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='RandomRescale', scale_range=(256, 320)),
dict(type='RandomCrop', size=224),
dict(type='Flip', flip_ratio=0.5),
@@ -54,7 +49,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
@@ -69,7 +64,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
@@ -11,7 +11,7 @@
patch_size=16,
embed_dims=768,
in_channels=3,
drop_rate=0.,
dropout_ratio=0.,
transformer_layers=None,
attention_type='space_only',
norm_cfg=dict(type='LN', eps=1e-6)),
@@ -31,14 +31,9 @@
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_bgr=False)

mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')

train_pipeline = [
dict(type='SampleFrames', clip_len=8, frame_interval=32, num_clips=1),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='RandomRescale', scale_range=(256, 320)),
dict(type='RandomCrop', size=224),
dict(type='Flip', flip_ratio=0.5),
@@ -54,7 +49,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
@@ -69,7 +64,7 @@
frame_interval=32,
num_clips=1,
test_mode=True),
dict(type='RawFrameDecode', io_backend='memcached', **mc_cfg),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 224)),
dict(type='ThreeCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
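On the data side, all three configs drop the mc_cfg block and the io_backend='memcached' argument. RawFrameDecode with no io_backend argument falls back to the library's default backend (presumably plain disk reads), so nothing cluster-specific remains in the pipelines. A sketch of the resulting train pipeline, assembled from the changed lines above (the steps after Flip are not shown in the diff and are elided here as well):

    train_pipeline = [
        dict(type='SampleFrames', clip_len=8, frame_interval=32, num_clips=1),
        dict(type='RawFrameDecode'),  # no io_backend/mc_cfg: default backend
        dict(type='RandomRescale', scale_range=(256, 320)),
        dict(type='RandomCrop', size=224),
        dict(type='Flip', flip_ratio=0.5),
        # ... the remaining normalization/formatting steps are unchanged
    ]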
84 changes: 62 additions & 22 deletions mmaction/models/backbones/timesformer.py
@@ -7,7 +7,7 @@
from mmcv.runner import _load_checkpoint, load_state_dict
from torch.nn.modules.utils import _pair

from mmaction.utils import trunc_normal_
from mmaction.utils import trunc_normal_ # TODO: use trunc_normal_ in mmcv
from ...utils import get_root_logger
from ..registry import BACKBONES

@@ -16,17 +16,17 @@ class PatchEmbed(nn.Module):
"""Image to Patch Embedding.
Args:
img_size (int | tuple): The size of input image.
patch_size (int): The size of one patch
in_channels (int): The num of input channels.
embed_dims (int): The dimensions of embedding.
conv_cfg (dict | None): The config dict for conv layers.
Default: None.
img_size (int | tuple): Size of input image.
patch_size (int): Size of one patch.
in_channels (int): Channel num of input features. Defaults to 3.
embed_dims (int): Dimensions of embedding. Defaults to 768.
conv_cfg (dict | None): Config dict for convolution layer. Defaults to
dict(type='Conv2d').
"""

def __init__(self,
img_size=224,
patch_size=16,
img_size,
patch_size,
in_channels=3,
embed_dims=768,
conv_cfg=dict(type='Conv2d')):
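Note that PatchEmbed loses its default img_size=224 and patch_size=16, so callers must now pass both explicitly (TimeSformer already does, as shown further down). A rough standalone sketch, assuming the module follows the usual ViT convention of emitting one embedding per patch (the forward body is not part of this hunk):

    import torch

    from mmaction.models.backbones.timesformer import PatchEmbed

    # img_size and patch_size are now required arguments
    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_channels=3, embed_dims=768)
    frames = torch.randn(8, 3, 224, 224)  # 8 frames handled as a batch of images
    tokens = patch_embed(frames)
    print(tokens.shape)  # expected: (8, 196, 768), since (224 / 16) ** 2 = 196 patches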
@@ -63,6 +63,29 @@ def forward(self, x):

@BACKBONES.register_module()
class TimeSformer(nn.Module):
"""TimeSformer A PyTorch impl of `Is Space-Time Attention All You Need for
Video Understanding? <https://arxiv.org/abs/2102.05095>`_
Args:
        num_frames (int): Total number of frames in the video.
img_size (int | tuple): Size of input image.
patch_size (int): Size of one patch.
pretrained (str | None): Name of pretrained model. Default: None.
embed_dims (int): Dimensions of embedding. Defaults to 768.
num_transformer_layers (int): Number of transformer layers. Defaults to
12.
in_channels (int): Channel num of input features. Defaults to 3.
dropout_ratio (float): Probability of dropout layer. Defaults to 0..
        transformer_layers (list[`mmcv.ConfigDict`] | `mmcv.ConfigDict`):
            Config of the transformer layers in the TransformerLayerSequence.
            If it is a single `mmcv.ConfigDict`, it will be repeated
            `num_transformer_layers` times into a list[`mmcv.ConfigDict`].
            Defaults to None.
        attention_type (str): Type of attention used in the transformer
            layers. Choices are 'divided_space_time', 'space_only' and
            'joint_space_time'. Defaults to 'divided_space_time'.
norm_cfg (dict | None): Config for norm layers. Defaults to
dict(type='LN', eps=1e-6).
"""
supported_attention_type = [
'divided_space_time', 'space_only', 'joint_space_time'
]
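As the docstring above indicates, the constructor now takes num_transformer_layers and dropout_ratio (renamed from num_layers and drop_rate in the hunks below). Building the backbone directly would look roughly like this sketch of the new signature, assembled from the changed lines and using only the defaults visible in this diff:

    from mmaction.models.backbones.timesformer import TimeSformer

    backbone = TimeSformer(
        num_frames=8,
        img_size=224,
        patch_size=16,
        pretrained=None,                # or a path to a converted ViT-B/16 checkpoint
        embed_dims=768,
        num_transformer_layers=12,      # renamed from num_layers
        in_channels=3,
        dropout_ratio=0.,               # renamed from drop_rate
        transformer_layers=None,        # the 12-layer stack is then built internally
        attention_type='divided_space_time',
        norm_cfg=dict(type='LN', eps=1e-6))
    backbone.init_weights()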
@@ -73,21 +96,23 @@ def __init__(self,
patch_size,
pretrained=None,
embed_dims=768,
num_layers=12,
num_transformer_layers=12,
in_channels=3,
drop_rate=0.,
dropout_ratio=0.,
transformer_layers=None,
attention_type='divided_space_time',
norm_cfg=dict(type='LN', eps=1e-6),
**kwargs):
super().__init__(**kwargs)
assert attention_type in self.supported_attention_type, (
            f'Unsupported Attention Type {attention_type}!')

self.num_frames = num_frames
self.pretrained = pretrained
self.embed_dims = embed_dims
self.num_layers = num_layers
self.num_transformer_layers = num_transformer_layers
self.attention_type = attention_type

self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
@@ -98,17 +123,20 @@ def __init__(self,
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.drop_after_pos = nn.Dropout(p=dropout_ratio)
if self.attention_type != 'space_only':
self.time_embed = nn.Parameter(
torch.zeros(1, num_frames, embed_dims))
self.drop_after_time = nn.Dropout(p=drop_rate)
self.drop_after_time = nn.Dropout(p=dropout_ratio)

self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

if transformer_layers is None:
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, 0.1, num_layers)]
dpr = [
x.item()
for x in torch.linspace(0, 0.1, num_transformer_layers)
]

if self.attention_type == 'divided_space_time':
_transformerlayers_cfg = [
@@ -142,7 +170,7 @@ def __init__(self,
type='DropPath', drop_prob=dpr[i]),
norm_cfg=dict(type='LN', eps=1e-6)),
operation_order=('self_attn', 'self_attn', 'ffn'))
for i in range(num_layers)
for i in range(num_transformer_layers)
]
else:
            # Space Only & Joint Space Time
@@ -168,19 +196,21 @@ def __init__(self,
type='DropPath', drop_prob=dpr[i])),
operation_order=('norm', 'self_attn', 'norm', 'ffn'),
norm_cfg=dict(type='LN', eps=1e-6))
for i in range(num_layers)
for i in range(num_transformer_layers)
]

transformer_layers = ConfigDict(
dict(
type='TransformerLayerSequence',
transformerlayers=_transformerlayers_cfg,
num_layers=num_layers))
num_layers=num_transformer_layers))

self.transformer_layers = build_transformer_layer_sequence(
transformer_layers)

def init_weights(self, pretrained=None):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)

@@ -190,22 +220,31 @@ def init_weights(self, pretrained=None):
logger = get_root_logger()
logger.info(f'load model from: {self.pretrained}')

state_dict = _load_checkpoint(self.pretrained, map_location='cpu')
state_dict = _load_checkpoint(self.pretrained)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']

if self.attention_type == 'divided_space_time':
# copy the parameters of space attention to time attention
old_state_dict_keys = list(state_dict.keys())
for old_key in old_state_dict_keys:
if 'attentions.1' in old_key:
new_key = old_key.replace('attentions.1',
'attentions.0')
if 'norms' in old_key:
new_key = old_key.replace('norms.0',
'attentions.0.norm')
new_key = new_key.replace('norms.1', 'ffns.0.norm')
state_dict[new_key] = state_dict.pop(old_key)

old_state_dict_keys = list(state_dict.keys())
for old_key in old_state_dict_keys:
if 'attentions.0' in old_key:
new_key = old_key.replace('attentions.0',
'attentions.1')
state_dict[new_key] = state_dict[old_key].clone()

load_state_dict(self, state_dict, strict=False, logger=logger)

def forward(self, x):
"""Defines the computation performed at every call."""
# x [batch_size * num_frames, num_patches, embed_dims]
B = x.shape[0]
x = self.patch_embed(x)
@@ -232,6 +271,7 @@ def forward(self, x):
# x [batch_size, num_patches + 1, embed_dims]
x = x.view(-1, self.num_frames, *x.size()[-2:])
x = torch.mean(x, 1)

x = self.norm(x)

return x[:, 0]
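Reading the forward method together with the renamed constructor arguments: the backbone consumes frames flattened into the batch dimension and returns one class-token feature per clip (for 'space_only' the per-frame features are averaged over time first). A rough end-to-end sketch, using mmaction's backbone builder, with shape expectations inferred from the in-code comments rather than verified against the full file:

    import torch

    from mmaction.models import build_backbone

    model = build_backbone(
        dict(
            type='TimeSformer',
            num_frames=8,
            img_size=224,
            patch_size=16,
            attention_type='divided_space_time'))
    model.init_weights()  # random init; pass a checkpoint path to load ViT weights

    clips = torch.randn(2 * 8, 3, 224, 224)  # 2 clips x 8 frames flattened into the batch
    feats = model(clips)
    print(feats.shape)  # expected: torch.Size([2, 768]), one feature vector per clip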
