Commit a71b3b4: add comments

congee524 committed May 24, 2021 (1 parent: e7ff123)
Showing 6 changed files with 104 additions and 28 deletions.
configs/recognition/timesformer/README.md (1 addition, 1 deletion)

@@ -32,7 +32,7 @@ Notes:
e.g., lr=0.005 for 8 GPUs x 8 videos/gpu and lr=0.00375 for 8 GPUs x 6 videos/gpu.
2. We keep the test setting consistent with the [original repo](https://github.com/facebookresearch/TimeSformer) (three crops x 1 clip).
3. The pretrained model `vit_base_patch16_224.pth` used by TimeSformer was converted from [vision_transformer](https://github.com/google-research/vision_transformer).
- 4. The model `timesformer_divST_8x32x1` from the [original repo](https://github.com/facebookresearch/TimeSformer) gets 73.38 top-1 accuracy on our Kinetics dataset (short-side 256).
+ 4. The model `timesformer_divST_8x32x1` from the [original repo](https://github.com/facebookresearch/TimeSformer) gets 77.38 top-1 accuracy on our Kinetics dataset (short-side 256).

For more details on data preparation, you can refer to Kinetics400 in [Data Preparation](/docs/data_preparation.md).
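Note 1 above is the standard linear scaling rule: the learning rate scales in proportion to the total batch size (GPUs x videos per GPU). A quick sketch of the arithmetic (illustrative only, not part of the original README):

```python
def scaled_lr(gpus, videos_per_gpu, base_lr=0.005, base_batch=8 * 8):
    """Scale the learning rate linearly with the total batch size."""
    return base_lr * (gpus * videos_per_gpu) / base_batch

print(scaled_lr(8, 8))  # 0.005   -> the reference setting
print(scaled_lr(8, 6))  # 0.00375 -> matches note 1
```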

configs/recognition/timesformer/README_zh-CN.md (1 addition, 1 deletion)

@@ -32,7 +32,7 @@
e.g., lr=0.005 corresponds to 8 GPUs x 8 videos/gpu, and lr=0.00375 corresponds to 8 GPUs x 6 videos/gpu.
2. MMAction2 keeps the test setting consistent with the [original repo](https://github.com/facebookresearch/TimeSformer) (three crops x 1 clip).
3. The pretrained model `vit_base_patch16_224.pth` used by TimeSformer was converted from [vision_transformer](https://github.com/google-research/vision_transformer).
- 4. The `timesformer_divST_8x32x1` model provided by the original repo reaches 73.38 top-1 accuracy on MMAction2's Kinetics-400 dataset (short-side 256).
+ 4. The `timesformer_divST_8x32x1` model provided by the original repo reaches 77.38 top-1 accuracy on MMAction2's Kinetics-400 dataset (short-side 256).

For details on data preparation, users can refer to the Kinetics400 section in [Data Preparation](/docs_zh_CN/data_preparation.md).

mmaction/models/backbones/timesformer.py (18 additions, 14 deletions)

@@ -21,7 +21,7 @@ class PatchEmbed(nn.Module):
in_channels (int): Channel num of input features. Defaults to 3.
embed_dims (int): Dimensions of embedding. Defaults to 768.
conv_cfg (dict | None): Config dict for convolution layer. Defaults to
- dict(type='Conv2d').
+ `dict(type='Conv2d')`.
"""

def __init__(self,
@@ -63,28 +63,31 @@ def forward(self, x):

@BACKBONES.register_module()
class TimeSformer(nn.Module):
"""TimeSformer A PyTorch impl of `Is Space-Time Attention All You Need for
"""TimeSformer. A PyTorch impl of `Is Space-Time Attention All You Need for
Video Understanding? <https://arxiv.org/abs/2102.05095>`_
Args:
- num_frames (int): Total number of frame in the video.
+ num_frames (int): Number of frames in the video.
img_size (int | tuple): Size of input image.
patch_size (int): Size of one patch.
pretrained (str | None): Name of pretrained model. Default: None.
embed_dims (int): Dimensions of embedding. Defaults to 768.
num_heads (int): Number of parallel attention heads in
TransformerCoder. Defaults to 12.
num_transformer_layers (int): Number of transformer layers. Defaults to
12.
in_channels (int): Channel num of input features. Defaults to 3.
dropout_ratio (float): Probability of dropout layer. Defaults to 0..
- transformer_layers (list[`mmcv.ConfigDict`] | `mmcv.ConfigDict`):
-     Config of transformerlayer in TransformerCoder. If it is
-     `mmcv.ConfigDict`, it would be repeated `num_transformer_layers`
-     times to a list[`mmcv.ConfigDict`]. Defaults to None.
+ transformer_layers (list[obj:`mmcv.ConfigDict`] |
+     obj:`mmcv.ConfigDict`): Config of the transformer layers in
+     TransformerCoder. If it is an obj:`mmcv.ConfigDict`, it will be
+     repeated `num_transformer_layers` times into a
+     list[obj:`mmcv.ConfigDict`]. Defaults to None.
attention_type (str): Type of attentions in TransformerCoder. Choices
are 'divided_space_time', 'space_only' and 'joint_space_time'.
Defaults to 'divided_space_time'.
- norm_cfg (dict | None): Config for norm layers. Defaults to
-     dict(type='LN', eps=1e-6).
+ norm_cfg (dict): Config for norm layers. Defaults to
+     `dict(type='LN', eps=1e-6)`.
"""
supported_attention_type = [
'divided_space_time', 'space_only', 'joint_space_time'
@@ -96,6 +99,7 @@ def __init__(self,
patch_size,
pretrained=None,
embed_dims=768,
+ num_heads=12,
num_transformer_layers=12,
in_channels=3,
dropout_ratio=0.,
@@ -146,15 +150,15 @@
dict(
type='DividedTemporalAttentionWithNorm',
embed_dims=embed_dims,
- num_heads=12,
+ num_heads=num_heads,
num_frames=num_frames,
dropout_layer=dict(
type='DropPath', drop_prob=dpr[i]),
norm_cfg=dict(type='LN', eps=1e-6)),
dict(
type='DividedSpatialAttentionWithNorm',
embed_dims=embed_dims,
- num_heads=12,
+ num_heads=num_heads,
num_frames=num_frames,
dropout_layer=dict(
type='DropPath', drop_prob=dpr[i]),
@@ -163,7 +167,7 @@
ffn_cfgs=dict(
type='FFNWithNorm',
embed_dims=embed_dims,
- feedforward_channels=3072,
+ feedforward_channels=embed_dims * 4,
num_fcs=2,
act_cfg=dict(type='GELU'),
dropout_layer=dict(
@@ -181,15 +185,15 @@
dict(
type='MultiheadAttention',
embed_dims=embed_dims,
- num_heads=12,
+ num_heads=num_heads,
batch_first=True,
dropout_layer=dict(
type='DropPath', drop_prob=dpr[i]))
],
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
- feedforward_channels=3072,
+ feedforward_channels=embed_dims * 4,
num_fcs=2,
act_cfg=dict(type='GELU'),
dropout_layer=dict(
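Taken together, these changes stop hard-coding the attention width: `num_heads` becomes a constructor argument, and `feedforward_channels=embed_dims * 4` keeps the conventional 4x MLP ratio (3072 for the 768-dim ViT-Base default). A minimal sketch of building this backbone from a config dict; it assumes the standard MMAction2 registry helper `build_backbone`, and the argument values are only illustrative:

```python
from mmaction.models import build_backbone

# Sketch only: settings mirror the docstring defaults above.
backbone = build_backbone(
    dict(
        type='TimeSformer',
        num_frames=8,
        img_size=224,
        patch_size=16,
        embed_dims=768,
        num_heads=12,               # configurable after this commit
        num_transformer_layers=12,
        attention_type='divided_space_time',
        norm_cfg=dict(type='LN', eps=1e-6)))
```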
mmaction/models/common/transformer.py (82 additions, 6 deletions)

@@ -9,12 +9,31 @@

@ATTENTION.register_module()
class DividedTemporalAttentionWithNorm(BaseModule):
"""Temporal Attention in Divided Space Time Attention.
Args:
embed_dims (int): Dimensions of embedding.
num_heads (int): Number of parallel attention heads in
TransformerCoder.
num_frames (int): Number of frames in the video.
attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
0..
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Defaults to 0..
dropout_layer (dict): The dropout_layer used when adding the shortcut.
Defaults to `dict(type='DropPath', drop_prob=0.1)`.
norm_cfg (dict): Config dict for normalization layer. Defaults to
`dict(type='LN')`.
init_cfg (dict | None): The Config for initialization. Defaults to
None.
"""

def __init__(self,
embed_dims,
num_heads,
num_frames,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='DropPath', drop_prob=0.1),
norm_cfg=dict(type='LN'),
init_cfg=None,
@@ -23,11 +42,12 @@ def __init__(self,
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_frames = num_frames
+ self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
**kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
- self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
self.temporal_fc = nn.Linear(self.embed_dims, self.embed_dims)

self.init_weights()
@@ -37,7 +57,7 @@ def init_weights(self):

def forward(self, query, key=None, value=None, residual=None, **kwargs):
assert residual is None, (
- 'Cannot apply pre-norm with DividedTemporalAttentionWithNorm')
+ 'Always adding the shortcut in the forward function')

init_cls_token = query[:, 0, :].unsqueeze(1)
identity = query_t = query[:, 1:, :]
Expand All @@ -49,7 +69,8 @@ def forward(self, query, key=None, value=None, residual=None, **kwargs):
# res_temporal [batch_size * num_patches, num_frames, embed_dims]
query_t = self.norm(query_t.reshape(b * p, t, m)).permute(1, 0, 2)
res_temporal = self.attn(query_t, query_t, query_t)[0].permute(1, 0, 2)
- res_temporal = self.dropout_layer(res_temporal.contiguous())
+ res_temporal = self.dropout_layer(
+     self.proj_drop(res_temporal.contiguous()))
res_temporal = self.temporal_fc(res_temporal)

# res_temporal [batch_size, num_patches * num_frames, embed_dims]
@@ -63,12 +84,31 @@

@ATTENTION.register_module()
class DividedSpatialAttentionWithNorm(BaseModule):
"""Spatial Attention in Divided Space Time Attention.
Args:
embed_dims (int): Dimensions of embedding.
num_heads (int): Number of parallel attention heads in
TransformerCoder.
num_frames (int): Number of frames in the video.
attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
0..
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Defaults to 0..
dropout_layer (dict): The dropout_layer used when adding the shortcut.
Defaults to `dict(type='DropPath', drop_prob=0.1)`.
norm_cfg (dict): Config dict for normalization layer. Defaults to
`dict(type='LN')`.
init_cfg (dict | None): The Config for initialization. Defaults to
None.
"""

def __init__(self,
embed_dims,
num_heads,
num_frames,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='DropPath', drop_prob=0.1),
norm_cfg=dict(type='LN'),
init_cfg=None,
@@ -77,15 +117,22 @@ def __init__(self,
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_frames = num_frames
+ self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
**kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
- self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]

self.init_weights()

def init_weights(self):
# init DividedSpatialAttentionWithNorm by default
pass

def forward(self, query, key=None, value=None, residual=None, **kwargs):
assert residual is None, (
- 'Cannot apply pre-norm with DividedTemporalAttentionWithNorm')
+ 'Always adding the shortcut in the forward function')

identity = query
init_cls_token = query[:, 0, :].unsqueeze(1)
# res_spatial [batch_size * num_frames, num_patches + 1, embed_dims]
query_s = self.norm(query_s).permute(1, 0, 2)
res_spatial = self.attn(query_s, query_s, query_s)[0].permute(1, 0, 2)
- res_spatial = self.dropout_layer(res_spatial.contiguous())
+ res_spatial = self.dropout_layer(
+     self.proj_drop(res_spatial.contiguous()))

# cls_token [batch_size, 1, embed_dims]
cls_token = res_spatial[:, 0, :].reshape(b, t, m)
@@ -123,6 +171,34 @@ def forward(self, query, key=None, value=None, residual=None, **kwargs):

@FEEDFORWARD_NETWORK.register_module()
class FFNWithNorm(FFN):
"""FFN with pre normalization layer.
FFNWithNorm is implemented to be compatible with `BaseTransformerLayer`
when using `DividedTemporalAttentionWithNorm` and
`DividedSpatialAttentionWithNorm`.
FFNWithNorm has one main difference with FFN:
- It apply one normalization layer before forwarding the input data to
feed-forward networks.
Args:
embed_dims (int): Dimensions of embedding. Defaults to 256.
feedforward_channels (int): Hidden dimension of FFNs. Defaults to 1024.
num_fcs (int, optional): Number of fully-connected layers in FFNs.
Defaults to 2.
act_cfg (dict): Config for activate layers.
Defaults to `dict(type='ReLU')`
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Defaults to 0..
add_residual (bool, optional): Whether to add the
residual connection. Defaults to `True`.
dropout_layer (dict | None): The dropout_layer used when adding the
shortcut. Defaults to None.
init_cfg (dict): The Config for initialization. Defaults to None.
norm_cfg (dict): Config dict for normalization layer. Defaults to
`dict(type='LN')`.
"""

def __init__(self, *args, norm_cfg=dict(type='LN'), **kwargs):
super().__init__(*args, **kwargs)
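The common thread in this file is pre-normalization with the shortcut added inside `forward` (hence the asserts on `residual is None`), and the commit additionally routes attention output through `proj_drop` before the DropPath shortcut. The other key piece is the shape bookkeeping that makes the attention "divided": temporal attention runs over frames per spatial patch, spatial attention over patches per frame. A distilled sketch of that bookkeeping (cls-token handling omitted for brevity; names are illustrative):

```python
import torch

b, p, t, m = 2, 196, 8, 768      # batch, patches, frames, embed_dims
x = torch.randn(b, p * t, m)     # patch tokens, laid out patch-major

# Temporal attention: each spatial patch attends over its t frames.
x_t = x.reshape(b * p, t, m)     # [392, 8, 768]

# Spatial attention: each frame attends over its p patches.
x_s = x.reshape(b, p, t, m).permute(0, 2, 1, 3).reshape(b * t, p, m)
print(x_t.shape, x_s.shape)      # [392, 8, 768] and [16, 196, 768]
```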
mmaction/models/heads/timesformer_head.py (1 addition, 4 deletions)

@@ -13,14 +13,11 @@ class TimeSformerHead(BaseHead):
num_classes (int): Number of classes to be classified.
in_channels (int): Number of channels in input feature.
loss_cls (dict): Config for building loss.
- Defaults to dict(type='CrossEntropyLoss')
+ Defaults to `dict(type='CrossEntropyLoss')`.
init_std (float): Std value for initialization. Defaults to 0.02.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
"""
- supported_attention_type = [
-     'divided_space_time', 'space_only', 'joint_space_time'
- ]

def __init__(self,
num_classes,
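For context, the arguments above suggest the head is essentially a single linear classifier over the backbone's output feature. A rough stand-in under that assumption (the real `TimeSformerHead` also builds its loss from `loss_cls` and inherits from `BaseHead`):

```python
import torch
import torch.nn as nn

class SimpleClsHead(nn.Module):
    """Illustrative stand-in for TimeSformerHead's classification path."""

    def __init__(self, num_classes, in_channels, init_std=0.02):
        super().__init__()
        self.fc_cls = nn.Linear(in_channels, num_classes)
        nn.init.normal_(self.fc_cls.weight, std=init_std)  # init_std above
        nn.init.zeros_(self.fc_cls.bias)

    def forward(self, x):
        # x: [batch_size, in_channels] feature from the backbone
        return self.fc_cls(x)

scores = SimpleClsHead(num_classes=400, in_channels=768)(torch.randn(2, 768))
print(scores.shape)  # torch.Size([2, 400])
```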
mmaction/utils/__init__.py (1 addition, 2 deletions)

@@ -5,10 +5,9 @@
from .misc import get_random_string, get_shm_dir, get_thread_id
from .module_hooks import register_module_hooks
from .precise_bn import PreciseBNHook
- from .trunc_normal import trunc_normal_

__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'import_module_error_class',
- 'import_module_error_func', 'register_module_hooks', 'trunc_normal_'
+ 'import_module_error_func', 'register_module_hooks'
]
