Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Enable torchvision backbones #720

Merged
merged 42 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c5f00a8
resolve comments
Oct 16, 2020
05575c1
update changelog
Oct 16, 2020
eb09070
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 19, 2020
81302a4
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 22, 2020
43a649a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 27, 2020
755809e
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 29, 2020
d478c9d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 2, 2020
08bbc06
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 7, 2020
ff958e6
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 8, 2020
d0e192d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
a52c536
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
81a2029
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 16, 2020
e03d2a9
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 17, 2020
2a9b57f
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 27, 2020
28001ff
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 30, 2020
46cc5dd
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 1, 2020
667818a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
34398a8
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
9928127
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Jan 5, 2021
b633443
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Jan 5, 2021
a891176
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Jan 5, 2021
611ae31
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Jan 13, 2021
96662d4
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Jan 15, 2021
ef5ab3c
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Jan 25, 2021
0b63bd0
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Jan 27, 2021
d4450fc
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Jan 27, 2021
467b8d8
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Feb 1, 2021
1e2e566
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Feb 3, 2021
f91fe5e
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Mar 3, 2021
c2067b1
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Mar 5, 2021
0b923cf
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Mar 11, 2021
c181802
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
kennymckormick Mar 12, 2021
5ca360e
support torchvision backbones
kennymckormick Mar 13, 2021
0eebe93
add ckpt, changelog and unittest
kennymckormick Mar 18, 2021
96affe1
Merge branch 'master' into tv_backbones
kennymckormick Mar 24, 2021
bae040c
fix lint
kennymckormick Mar 24, 2021
4eda6cb
Merge branch 'master' into tv_backbones
kennymckormick Mar 24, 2021
7994bf4
fix lint
kennymckormick Mar 24, 2021
6dc64d4
fix lint
kennymckormick Mar 24, 2021
7d83ab8
Update changelog.md
kennymckormick Mar 24, 2021
d60308b
Update changelog.md
kennymckormick Mar 24, 2021
1ab9add
Merge branch 'master' into tv_backbones
kennymckormick Mar 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions configs/recognition/tsn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ Here, We use [1: 1] to indicate that we combine rgb and flow score with coeffici

It's possible and convenient to use a 3rd-party backbone for TSN under the framework of MMAction2, here we provide some examples for:

- [x] Backbones from MMClassification
- [x] Backbones from [MMClassification](https://github.com/open-mmlab/mmclassification/)
- [x] Backbones from [TorchVision](https://github.com/pytorch/vision/)

| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------- | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |
| [tsn_dense161_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[TorchVision](https://github.com/pytorch/vision/)] | ImageNet | 72.78 | 90.75 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb-cbe85332.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb.json) |

### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
_base_ = [
'../../../_base_/schedules/sgd_100e.py',
'../../../_base_/default_runtime.py'
]

# model settings
model = dict(
type='Recognizer2D',
backbone=dict(type='torchvision.densenet161', pretrained=True),
cls_head=dict(
type='TSNHead',
num_classes=400,
in_channels=2208,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01),
# model training and testing settings
train_cfg=None,
test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train_320p'
data_root_val = 'data/kinetics400/rawframes_val_320p'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=3,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=12,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))

# runtime settings
work_dir = './work_dirs/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/'
optimizer = dict(
type='SGD',
lr=0.00375, # this lr is used for 8 gpus
momentum=0.9,
weight_decay=0.0001)
4 changes: 3 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Support LFB ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Support using backbones from MMCls for TSN ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Support using backbones from TorchVision for TSN ([#720]([[Feature\] Enable torchvision backbones by kennymckormick · Pull Request #720 · open-mmlab/mmaction2 (github.com)](https://github.com/open-mmlab/mmaction2/pull/720)))

**Improvements**

Expand All @@ -19,7 +20,8 @@

- Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Add slowonly_nl_embedded_gaussian_r50_4x16x1_150e_kinetics400_rgb ([#690](https://github.com/open-mmlab/mmaction2/pull/690))
- Add TSN with ResNeXt-101-32x4d backbone ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Add TSN with ResNeXt-101-32x4d backbone as an example for using MMCls backbones ([#679](https://github.com/open-mmlab/mmaction2/pull/679))
- Add TSN with Densenet161 backbone as an example for using TorchVision backbones ([#720]([[Feature\] Enable torchvision backbones by kennymckormick · Pull Request #720 · open-mmlab/mmaction2 (github.com)](https://github.com/open-mmlab/mmaction2/pull/720)))
- Add slowonly_nl_embedded_gaussian_r50_8x8x1_150e_kinetics400_rgb ([#704](https://github.com/open-mmlab/mmaction2/pull/704))

### 0.12.0 (28/02/2021)
Expand Down
37 changes: 34 additions & 3 deletions mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

Expand Down Expand Up @@ -33,14 +34,30 @@ def __init__(self,
train_cfg=None,
test_cfg=None):
super().__init__()
# The backbones in mmcls can be used by TSN
# record the source of the backbone
self.backbone_from = 'mmaction2'

if backbone['type'].startswith('mmcls.'):
try:
import mmcls.models.builder as mmcls_builder
except (ImportError, ModuleNotFoundError):
raise ImportError('Please install mmcls to use this backbone.')
backbone['type'] = backbone['type'][6:]
self.backbone = mmcls_builder.build_backbone(backbone)
self.backbone_from = 'mmcls'
elif backbone['type'].startswith('torchvision.'):
try:
import torchvision.models
except (ImportError, ModuleNotFoundError):
raise ImportError('Please install torchvision to use this '
'backbone.')
backbone_type = backbone.pop('type')[12:]
self.backbone = torchvision.models.__dict__[backbone_type](
**backbone)
# disable the classifier
self.backbone.classifier = nn.Identity()
self.backbone.fc = nn.Identity()
self.backbone_from = 'torchvision'
else:
self.backbone = builder.build_backbone(backbone)

Expand Down Expand Up @@ -69,7 +86,17 @@ def __init__(self,

def init_weights(self):
"""Initialize the model network weights."""
self.backbone.init_weights()
if self.backbone_from in ['mmcls', 'mmaction2']:
self.backbone.init_weights()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happened for torchvision

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchvision init all modules in the init function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users might call this function by hand

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for which case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean users want to init the backbones in their own way?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No specific case. I'm concerning the semantic of this function. If the function name is "init weights", then it is expected to init weights. However, current code can silently refuse to init weights based on the value of some external variables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add a warning in init_weights of recognizer2d

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this for now. If this is not correct, there will be bug reports eventually.

elif self.backbone_from == 'torchvision':
warnings.warn('We do not initialize weights for backbones in '
'torchvision, since the weights for backbones in '
'torchvision are initialized in their __init__ '
'functions. ')
else:
raise NotImplementedError('Unsupported backbone source '
f'{self.backbone_from}!')

self.cls_head.init_weights()
if hasattr(self, 'neck'):
self.neck.init_weights()
Expand All @@ -84,7 +111,11 @@ def extract_feat(self, imgs):
Returns:
torch.tensor: The extracted features.
"""
x = self.backbone(imgs)
if (hasattr(self.backbone, 'features')
and self.backbone_from == 'torchvision'):
x = self.backbone.features(imgs)
else:
x = self.backbone(imgs)
return x

def average_clip(self, cls_score, num_segs=1):
Expand Down
18 changes: 18 additions & 0 deletions mmaction/models/recognizers/recognizer2d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from torch import nn

from ..registry import RECOGNIZERS
from .base import BaseRecognizer

Expand All @@ -15,6 +17,14 @@ def forward_train(self, imgs, labels, **kwargs):
losses = dict()

x = self.extract_feat(imgs)

if self.backbone_from == 'torchvision':
if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
# apply adaptive avg pooling
x = nn.AdaptiveAvgPool2d(1)(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is needed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shapes of different torchvision models' outputs vary, it may be N x C x H x W, N x C x 1 x 1 or N x C.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the special-cased processing is expected. Why backbones from other sources are not handled?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backbones from other sources are written by us (mmcls / mmaction2), thus are more standard. In torchvision, there can be some legacy issues.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay

x = x.reshape((x.shape[0], -1))
x = x.reshape(x.shape + (1, 1))

if hasattr(self, 'neck'):
x = [
each.reshape((-1, num_segs) +
Expand Down Expand Up @@ -42,6 +52,14 @@ def _do_test(self, imgs):
num_segs = imgs.shape[0] // batches

x = self.extract_feat(imgs)

if self.backbone_from == 'torchvision':
if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
# apply adaptive avg pooling
x = nn.AdaptiveAvgPool2d(1)(x)
x = x.reshape((x.shape[0], -1))
x = x.reshape(x.shape + (1, 1))

if hasattr(self, 'neck'):
x = [
each.reshape((-1, num_segs) +
Expand Down
21 changes: 21 additions & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def test_tsn():
for one_img in img_list:
recognizer(one_img, None, return_loss=False)

tv_backbone = dict(type='torchvision.densenet161', pretrained=True)
config.model['backbone'] = tv_backbone
config.model['cls_head']['in_channels'] = 2208

recognizer = build_recognizer(config.model)

input_shape = (1, 3, 3, 32, 32)
demo_inputs = generate_recognizer_demo_inputs(input_shape)

imgs = demo_inputs['imgs']
gt_labels = demo_inputs['gt_labels']

losses = recognizer(imgs, gt_labels)
assert isinstance(losses, dict)

# Test forward test
with torch.no_grad():
img_list = [img[None, :] for img in imgs]
for one_img in img_list:
recognizer(one_img, None, return_loss=False)


def test_tsm():
config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')
Expand Down