Skip to content

Commit

Permalink
Add unittest.
Browse files Browse the repository at this point in the history
Some more unittests.
  • Loading branch information
xusu committed Nov 25, 2020
1 parent 584fa93 commit 9ed8287
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 99 deletions.
12 changes: 6 additions & 6 deletions configs/recognition_audio/audioonly/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

|config | n_fft | gpus | backbone |pretrain| top1 acc/delta| top5 acc/delta | inference_time(video/s) | gpu_mem(M)| ckpt | log| json|
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|[audioonly_r50_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/avslowfast/audioonly_r50_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet50 | None |20.37|37.37|x||[ckpt]()|[log]()|[json]()|
|[audioonly_r50_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/avslowfast/audioonly_r50_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet50 | None |20.37|37.37|x|6154|[ckpt]()|[log]()|[json]()|

Notes:

Expand All @@ -36,10 +36,10 @@ You can use the following command to train a model.
python tools/train.py ${CONFIG_FILE} [optional arguments]
```

Example: train ResNet model on Kinetics-400 audio dataset in a deterministic option with periodic validation.
Example: train an AudioOnly model on the Kinetics-400 audio dataset in deterministic mode with periodic validation.
```shell
python tools/train.py configs/audio_recognition/tsn_r50_64x1x1_100e_kinetics400_audio_feature.py \
--work-dir work_dirs/tsn_r50_64x1x1_100e_kinetics400_audio_feature \
python tools/train.py configs/recognition_audio/audioonly/audioonly_r50_64x1x1_100e_kinetics400_audio_feature.py \
--work-dir work_dirs/audioonly_r50_64x1x1_100e_kinetics400_audio_feature \
--validate --seed 0 --deterministic
```

Expand All @@ -52,9 +52,9 @@ You can use the following command to test a model.
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
```

Example: test ResNet model on Kinetics-400 audio dataset and dump the result to a json file.
Example: test an AudioOnly model on the Kinetics-400 audio dataset and dump the result to a JSON file.
```shell
python tools/test.py configs/audio_recognition/tsn_r50_64x1x1_100e_kinetics400_audio_feature.py \
python tools/test.py configs/recognition_audio/audioonly/audioonly_r50_64x1x1_100e_kinetics400_audio_feature.py \
checkpoints/SOME_CHECKPOINT.pth --eval top_k_accuracy mean_class_accuracy \
--out result.json
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/' +
'tsn_resnet_audio_r50_64x1x1_100e_kinetics400_audio_feature/')
'audioonly_r50_64x1x1_100e_kinetics400_audio_feature/')
load_from = None
resume_from = None
workflow = [('train', 1)]
62 changes: 0 additions & 62 deletions configs/recognition_audio/avslowfast/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion configs/recognition_audio/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
|config | n_fft | gpus | backbone |pretrain| top1 acc/delta| top5 acc/delta | inference_time(video/s) | gpu_mem(M)| ckpt | log| json|
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|[tsn_resnet_r18_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_resnet_r18_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet18 | None |19.7|35.75|x|1897|[ckpt](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/tsn_r18_64x1x1_100e_kinetics400_audio_feature_20201012-bf34df6c.pth)|[log](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/20201010_144630.log)|[json](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/20201010_144630.log.json)|
|[tsn_resnet_r50_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_resnet_r50_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet50 | None |17.58|32.54|x||[ckpt]()|[log]()|[json]()|
|[tsn_resnet_r50_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_resnet_r50_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet50 | None |17.58|32.54|x|5811|[ckpt]()|[log]()|[json]()|
|[tsn_r18_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_r18_64x1x1_100e_kinetics400_audio_feature.py) + [tsn_r50_video_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/tsn_r50_video_320p_1x1x3_100e_kinetics400_rgb.py)|1024|8| ResNet(18+50) | None |71.50(+0.39)|90.18(+0.14)|x|x|x|x|x|

Notes:
Expand Down
19 changes: 10 additions & 9 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from .backbones import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTIN,
ResNetTSM)
from .backbones import (C3D, X3D, AVResNet3dSlowFast, ResNet, ResNet2Plus1d,
ResNet3d, ResNet3dCSN, ResNet3dSlowFast,
ResNet3dSlowOnly, ResNetAudio, ResNetTIN, ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_loss,
build_model, build_neck, build_recognizer)
from .common import Conv2plus1d
from .heads import (AudioTSNHead, BaseHead, I3DHead, SlowFastHead, TPNHead,
TSMHead, TSNHead, X3DHead)
from .common import Conv2plus1d, ConvAudio
from .heads import (AudioTSNHead, AVSlowFastHead, BaseHead, I3DHead,
SlowFastHead, TPNHead, TSMHead, TSNHead, X3DHead)
from .localizers import BMN, PEM, TEM
from .losses import (BCELossWithLogits, BinaryLogisticRegressionLoss, BMNLoss,
CrossEntropyLoss, HVULoss, NLLLoss, OHEMHingeLoss,
SSNLoss)
from .necks import TPN
from .recognizers import (AudioRecognizer, BaseRecognizer, recognizer2d,
recognizer3d)
from .recognizers import (AudioRecognizer, AVRecognizer, BaseRecognizer,
recognizer2d, recognizer3d)
from .registry import BACKBONES, HEADS, LOCALIZERS, LOSSES, RECOGNIZERS

__all__ = [
Expand All @@ -25,5 +25,6 @@
'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN',
'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer',
'AudioTSNHead', 'X3D', 'X3DHead'
'AudioTSNHead', 'X3D', 'X3DHead', 'ResNetAudio', 'AVResNet3dSlowFast',
'AVRecognizer', 'ConvAudio', 'AVSlowFastHead'
]
22 changes: 8 additions & 14 deletions mmaction/models/backbones/resnet3d_avslowfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsnooper
from mmcv.cnn import ConvModule, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils import print_log
Expand Down Expand Up @@ -263,7 +262,6 @@ def init_weights(self):
else:
raise TypeError('pretrained must be a str or None')

@torchsnooper.snoop()
def forward(self, x, a):
"""Defines the computation performed at every call.
Expand All @@ -275,7 +273,10 @@ def forward(self, x, a):
tuple[torch.Tensor]: The feature of the input samples extracted
by the backbone.
"""
use_audio = random.random() > self.drop_out_ratio
if self.training:
use_audio = random.random() > self.drop_out_ratio
else:
use_audio = True
# stem
x_slow = nn.functional.interpolate(
x,
Expand All @@ -298,15 +299,7 @@ def forward(self, x, a):
if self.slow_path.lateral:
x_fast_lateral = self.slow_path.conv1_lateral(x_fast)
x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)

if use_audio and self.audio_path.lateral:
x_audio_lateral = self.audio_path.conv1_lateral(x_audio)
x_audio_lateral = x_audio_lateral.unsqueeze(4)
# use t-pool rather than t-conv
x_audio_lateral_pooled = F.adaptive_avg_pool3d(
x_audio_lateral, [x_slow.size(2), 1, 1])
x_slow = x_slow + x_audio_lateral_pooled

# no audio fusion in early stage
# res-stages
for i, layer_name in enumerate(self.slow_path.res_layers):
res_layer = getattr(self.slow_path, layer_name)
Expand All @@ -318,12 +311,13 @@ def forward(self, x, a):

if (i != len(self.slow_path.res_layers) - 1
and self.slow_path.lateral and self.audio_path.lateral):
# No fusion needed in the final stage
# No fusion in the final stage
lateral_name = self.slow_path.lateral_connections[i]
conv_lateral = getattr(self.slow_path, lateral_name)
x_fast_lateral = conv_lateral(x_fast)
x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)
if use_audio:
if use_audio and i > 0:
# lateral connection in res3,4 and pool5
lateral_name = self.audio_path.lateral_connections[i]
conv_lateral = getattr(self.audio_path, lateral_name)
x_audio_lateral = conv_lateral(x_audio)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmaction
known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,torch,torchsnooper
known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
34 changes: 32 additions & 2 deletions tests/test_models/test_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import torch.nn as nn
from mmcv.utils import _BatchNorm

from mmaction.models import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d,
ResNet3dCSN, ResNet3dSlowFast, ResNet3dSlowOnly,
from mmaction.models import (C3D, X3D, AVResNet3dSlowFast, ResNet,
ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
ResNetTIN, ResNetTSM)
from mmaction.models.backbones.resnet_tsm import NL3DWrapper

Expand Down Expand Up @@ -958,6 +959,35 @@ def test_c3d_backbone():
assert feat.shape == torch.Size([1, 4096])


def test_resnet_audio_backbone():
    """Test the forward pass of the ResNetAudio backbone.

    Builds the backbone, initialises its weights, switches it to train mode
    and checks the shape of the feature returned for a demo input.
    """
    input_shape = (1, 1, 128, 80)
    spec = _demo_inputs(input_shape)
    # forward pass in train mode (the module is switched with ``.train()``)
    audioonly = ResNetAudio(50, None)
    audioonly.init_weights()
    audioonly.train()
    feat = audioonly(spec)
    # ``torch.Size`` is the shape type; ``torch.size`` is not an attribute of
    # torch and would raise AttributeError before the comparison runs.
    assert feat.shape == torch.Size([1, 2048])


def test_avslowfast_backbone():
    """Test the forward pass of the AVResNet3dSlowFast backbone.

    Feeds a demo video clip and a demo audio input through the backbone in
    train mode and checks that a tuple of three features with the expected
    shapes is returned.
    """
    audio_shape = (1, 1, 128, 80)
    image_shape = (1, 3, 32, 16, 16)
    imgs = _demo_inputs(image_shape)
    spec = _demo_inputs(audio_shape)
    # forward pass in train mode (the module is switched with ``.train()``)
    avsf = AVResNet3dSlowFast(None)
    avsf.init_weights()
    avsf.train()
    feat = avsf(imgs, spec)
    assert isinstance(feat, tuple)
    # ``torch.Size`` is the shape type; ``torch.size`` is not an attribute of
    # torch and would raise AttributeError before the comparison runs.
    assert feat[0].shape == torch.Size([1, 2048])
    assert feat[1].shape == torch.Size([1, 256])
    assert feat[2].shape == torch.Size([1, 1024])


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_resnet_tin_backbone():
Expand Down
15 changes: 14 additions & 1 deletion tests/test_models/test_common_modules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from mmaction.models import Conv2plus1d
from mmaction.models import Conv2plus1d, ConvAudio


def test_conv2plus1d():
Expand All @@ -20,3 +20,16 @@ def test_conv2plus1d():
x = torch.rand(1, 3, 8, 256, 256)
output = conv_2plus1d(x)
assert output.shape == torch.Size([1, 8, 7, 255, 255])


def test_conv_audio():
    """Test argument validation, weight init and forward of ConvAudio."""
    with pytest.raises(AssertionError):
        # Length of kernel size, stride and padding must be the same
        ConvAudio(3, 8, (2, 2))

    conv_audio = ConvAudio(3, 8, 2)
    conv_audio.init_weights()

    x = torch.rand(1, 3, 8, 8)
    # Call the constructed module. ``ConvAudio(x)`` would pass the tensor to
    # the class constructor instead of running a forward pass.
    output = conv_audio(x)
    assert output.shape == torch.Size([1, 8, 8, 8])
39 changes: 37 additions & 2 deletions tests/test_models/test_head.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn

from mmaction.models import (AudioTSNHead, BaseHead, I3DHead, SlowFastHead,
TPNHead, TSMHead, TSNHead, X3DHead)
from mmaction.models import (AudioTSNHead, AVSlowFastHead, BaseHead, I3DHead,
SlowFastHead, TPNHead, TSMHead, TSNHead, X3DHead)


class ExampleHead(BaseHead):
Expand Down Expand Up @@ -214,6 +214,41 @@ def test_tsn_head_audio():
assert cls_scores.shape == torch.Size([8, 4])


def test_avslowfast_head():
    """Test layer construction, attributes and forward function in avslowfast
    head."""
    avsf_head = AVSlowFastHead(num_classes=4, in_channels=5)
    avsf_head.init_weights()

    assert avsf_head.num_classes == 4
    assert avsf_head.dropout_ratio == 0.5
    assert avsf_head.in_channels == 5
    assert avsf_head.init_std == 0.01
    assert avsf_head.spatial_type == 'avg'

    assert isinstance(avsf_head.dropout, nn.Dropout)
    assert avsf_head.dropout.p == avsf_head.dropout_ratio

    assert isinstance(avsf_head.fc_cls, nn.Linear)
    assert avsf_head.fc_cls.in_features == avsf_head.in_channels
    assert avsf_head.fc_cls.out_features == avsf_head.num_classes

    assert isinstance(avsf_head.avg_pool, nn.AdaptiveAvgPool3d)
    assert avsf_head.avg_pool.output_size == (1, 1, 1)

    slow_shape = (8, 1, 4, 7, 7)
    fast_shape = (8, 2, 4, 7, 7)
    audio_shape = (8, 1, 7, 7)

    # Build the (slow, fast, audio) feature tuple with a tuple literal:
    # ``tuple(a, b, c)`` raises TypeError because the ``tuple`` constructor
    # accepts at most one (iterable) argument.
    feat = (torch.rand(slow_shape), torch.rand(fast_shape),
            torch.rand(audio_shape))

    # avslowfast head inference
    cls_scores = avsf_head(feat)
    assert cls_scores.shape == torch.Size([8, 4])


def test_tsm_head():
"""Test loss method, layer construction, attributes and forward function in
tsm head."""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_models/test_recognizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,34 @@ def test_audio_recognizer():
recognizer(one_spectro, None, return_loss=False)


def test_av_recognizer():
    """Test the training and testing forward passes of the AV recognizer.

    The recognizer is built from the avslowfast config with the backbone's
    ``pretrained`` entry set to None, fed demo audio and visual inputs with
    labels, and then run sample by sample with ``return_loss=False``.
    """
    cfg_path = (
        'avslowfast/avslowfast_r50_32x2x1_239e_kinetics400_audio_feature.py')
    model, train_cfg, test_cfg = _get_audio_recognizer_cfg(cfg_path)
    model['backbone']['pretrained'] = None

    recognizer = build_recognizer(
        model, train_cfg=train_cfg, test_cfg=test_cfg)

    # Audio inputs are generated first, then visual ones; on a key clash the
    # visual entry overrides the audio entry, as in the merge order below.
    audio_inputs = generate_demo_inputs(
        (1, 3, 1, 128, 80), model_type='audio')
    visual_inputs = generate_demo_inputs((1, 3, 32, 16, 16))
    demo_inputs = dict(audio_inputs)
    demo_inputs.update(visual_inputs)

    imgs = demo_inputs['imgs']
    audios = demo_inputs['audios']
    gt_labels = demo_inputs['gt_labels']

    # Test forward train
    losses = recognizer(imgs, audios, gt_labels)
    assert isinstance(losses, dict)

    # Test forward test: feed one sample at a time, re-adding the leading
    # dimension dropped by iterating over the batch
    with torch.no_grad():
        for img, audio in zip(imgs, audios):
            recognizer(img[None, :], audio[None, :], None, return_loss=False)


def test_c3d():
model, train_cfg, test_cfg = _get_recognizer_cfg(
'c3d/c3d_sports1m_16x1x1_45e_ucf101_rgb.py')
Expand Down

0 comments on commit 9ed8287

Please sign in to comment.