Skip to content

Commit

Permalink
More docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
xusu committed Nov 27, 2020
1 parent a038b62 commit 72a545b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
15 changes: 7 additions & 8 deletions mmaction/datasets/audio_visual_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
@DATASETS.register_module
class AudioVisualDataset(RawframeDataset):
"""Dataset that read both audio and visual, supporting both rawframes and
videos. Annotation file can be that of the rawframe dataset,
or:
videos. Annotation file can be that of the rawframe dataset, such as:
.. code-block:: txt
Expand All @@ -30,6 +29,7 @@ class AudioVisualDataset(RawframeDataset):
def __init__(self, ann_file, pipeline, audio_prefix, **kwargs):
self.audio_prefix = audio_prefix
self.video_prefix = kwargs.pop('video_prefix', None)
self.data_prefix = kwargs.get('data_prefix', None)
super().__init__(ann_file, pipeline, **kwargs)

def load_annotations(self):
Expand All @@ -43,16 +43,15 @@ def load_annotations(self):
frame_dir = line_split[idx]
if self.audio_prefix is not None:
audio_path = osp.join(self.audio_prefix,
frame_dir) + '.npy'
frame_dir + '.npy')
video_info['audio_path'] = audio_path
if self.video_prefix:
video_path = osp.join(self.video_prefix,
frame_dir) + '.mp4'

frame_dir + '.mp4')
video_info['filename'] = video_path
if self.data_prefix is not None:
frame_dir = osp.join(self.data_prefix, frame_dir)
video_info['frame_dir'] = frame_dir
video_info['audio_path'] = audio_path
video_info['filename'] = video_path
video_info['frame_dir'] = frame_dir
idx += 1
if self.with_offset:
# idx for offset and total_frames
Expand Down
30 changes: 26 additions & 4 deletions mmaction/models/backbones/resnet_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def _inner_forward(x):

@BACKBONES.register_module
class ResNetAudio(nn.Module):
"""ResNet 2d audio backbone.
"""ResNet 2d audio backbone. Reference:
<https://arxiv.org/abs/2001.08740>`_.
Args:
depth (int): Depth of resnet, from {50, 101, 152}.
Expand All @@ -124,9 +126,10 @@ class ResNetAudio(nn.Module):
Default: (1, 1, 1, 1).
conv1_kernel (int): Kernel size of the first conv layer. Default: 9.
conv1_stride (int | tuple[int]): Stride of the first conv layer.
Default: 1.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
factorize (Sequence[int]): factorize Dims of each block.
factorize (Sequence[int]): factorize Dims of each block for audio.
Default: (1, 1, 0, 0).
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var). Default: False.
Expand Down Expand Up @@ -166,8 +169,8 @@ def __init__(self,
norm_eval=False,
with_cp=False,
conv_cfg=dict(type='Conv'),
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_cfg=dict(type='BN2d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
zero_init_residual=True):
super().__init__()
if depth not in self.arch_settings:
Expand Down Expand Up @@ -243,6 +246,9 @@ def make_res_layer(self,
dilation (int): Spacing between kernel elements. Default: 1.
factorize (int | Sequence[int]): Determine whether to factorize
for each block. Default: 1.
norm_cfg (dict):
Config for norm layers. required keys are `type` and
`requires_grad`. Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed.
Default: False.
Expand Down Expand Up @@ -290,6 +296,8 @@ def make_res_layer(self,
return nn.Sequential(*layers)

def _make_stem_layer(self):
"""Construct the stem layers consists of a conv+norm+act module and a
pooling layer."""
self.conv1 = ConvModule(
self.in_channels,
self.base_channels,
Expand All @@ -301,6 +309,8 @@ def _make_stem_layer(self):
act_cfg=self.act_cfg)

def _freeze_stages(self):
"""Prevent all the parameters from being optimized before
``self.frozen_stages``."""
if self.frozen_stages >= 0:
self.conv1.bn.eval()
for m in [self.conv1.conv, self.conv1.bn]:
Expand All @@ -314,6 +324,8 @@ def _freeze_stages(self):
param.requires_grad = False

def init_weights(self):
"""Initiate the parameters either from existing checkpoint or from
scratch."""
if isinstance(self.pretrained, str):
logger = get_root_logger()
logger.info(f'load model from: {self.pretrained}')
Expand All @@ -336,13 +348,23 @@ def init_weights(self):
raise TypeError('pretrained must be a str or None')

def forward(self, x):
"""Defines the computation performed at every call.
Args:
x (torch.Tensor): The input data.
Returns:
torch.Tensor: The feature of the input samples extracted
by the backbone.
"""
x = self.conv1(x)
for layer_name in self.res_layers:
res_layer = getattr(self, layer_name)
x = res_layer(x)
return x

def train(self, mode=True):
"""Set the optimization status when training."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
Expand Down
2 changes: 2 additions & 0 deletions mmaction/models/common/conv_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
class ConvAudio(nn.Module):
"""Conv2d module for AudioResNet backbone.
<https://arxiv.org/abs/2001.08740>`_.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ def test_c3d_backbone():

def test_resnet_audio_backbone():
"""Test ResNetAudio backbone."""
input_shape = (1, 1, 32, 32)
input_shape = (1, 1, 16, 16)
spec = _demo_inputs(input_shape)
# inference
audioonly = ResNetAudio(50, None)
Expand Down

0 comments on commit 72a545b

Please sign in to comment.