[Feature] Enable torchvision backbones #720
@@ -0,0 +1,100 @@
_base_ = [
    '../../../_base_/schedules/sgd_100e.py',
    '../../../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(type='torchvision.densenet161', pretrained=True),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2208,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    # model training and testing settings
    train_cfg=None,
    test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train_320p'
data_root_val = 'data/kinetics400/rawframes_val_320p'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=12,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))

# runtime settings
work_dir = './work_dirs/tsn_dense161_320p_1x1x3_100e_kinetics400_rgb/'
optimizer = dict(
    type='SGD',
    lr=0.00375,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001)
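The config above selects the backbone with a 'torchvision.' prefix instead of a class registered in MMAction2. Below is a minimal sketch of how such a string can be resolved into a feature extractor, assuming only the prefix convention shown in the config; the helper name build_torchvision_backbone and the way the classifier head is stripped are illustrative assumptions, not the PR's actual builder.

import torch
import torchvision

def build_torchvision_backbone(type_str, pretrained=True):
    # Hypothetical helper: 'torchvision.densenet161' -> torchvision.models.densenet161.
    assert type_str.startswith('torchvision.')
    name = type_str[len('torchvision.'):]
    model = getattr(torchvision.models, name)(pretrained=pretrained)
    # DenseNet-style models keep their convolutional trunk in `.features`;
    # the classifier is dropped so the recognizer's cls_head can be attached instead.
    return model.features if hasattr(model, 'features') else model

backbone = build_torchvision_backbone('torchvision.densenet161', pretrained=False)
with torch.no_grad():
    feat = backbone(torch.randn(2, 3, 224, 224))
print(feat.shape)  # (2, 2208, 7, 7) for densenet161

For densenet161 the final feature map has 2208 channels, which is why the TSNHead in the config sets in_channels=2208.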
@@ -1,4 +1,5 @@
import torch
from torch import nn

from ..registry import RECOGNIZERS
from .base import BaseRecognizer

@@ -17,6 +18,14 @@ def forward_train(self, imgs, labels, **kwargs):
        losses = dict()

        x = self.extract_feat(imgs)

        if self.backbone_from == 'torchvision':
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this is needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The shapes of different torchvision models' outputs vary, it may be N x C x H x W, N x C x 1 x 1 or N x C. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if the special-cased processing is expected. Why backbones from other sources are not handled? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Backbones from other sources are written by us (mmcls / mmaction2), thus are more standard. In torchvision, there can be some legacy issues. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay |

            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))

        if hasattr(self, 'neck'):
            x = [
                each.reshape((-1, num_segs) +

@@ -43,6 +52,14 @@ def _do_test(self, imgs):
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)

        if self.backbone_from == 'torchvision':
            if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                # apply adaptive avg pooling
                x = nn.AdaptiveAvgPool2d(1)(x)
            x = x.reshape((x.shape[0], -1))
            x = x.reshape(x.shape + (1, 1))

        if hasattr(self, 'neck'):
            x = [
                each.reshape((-1, num_segs) +

@@ -110,7 +127,7 @@ def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing."""
        if self.test_cfg.get('fcn_test', False):
            # If specified, spatially fully-convolutional testing is performed
            return self._do_fcn_test(imgs).cpu().numpy()
        return self._do_test(imgs).cpu().numpy()
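The new branches in forward_train and _do_test exist because, as the inline thread above notes, torchvision backbones may return features shaped N x C x H x W, N x C x 1 x 1, or N x C. Here is a standalone restatement of that normalization (the function name normalize_torchvision_feat is mine; the body mirrors the lines added in the diff), showing that all three cases end up as N x C x 1 x 1 before the consensus head.

import torch
from torch import nn

def normalize_torchvision_feat(x):
    # Collapse any remaining spatial extent: N x C x H x W -> N x C x 1 x 1.
    if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
        x = nn.AdaptiveAvgPool2d(1)(x)
    # Flatten to N x C, then restore dummy spatial dims so downstream code
    # can always assume a 4-D tensor.
    x = x.reshape((x.shape[0], -1))
    x = x.reshape(x.shape + (1, 1))
    return x

for shape in [(4, 2208, 7, 7), (4, 2048, 1, 1), (4, 1280)]:
    out = normalize_torchvision_feat(torch.randn(*shape))
    print(shape, '->', tuple(out.shape))  # every case prints (4, C, 1, 1)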
Review thread (on the init_weights handling for torchvision backbones):
- What happens for torchvision?
- torchvision initializes all modules in the __init__ function, so they do not need to be re-initialized here.
- Users might call this function by hand.
- For which case?
- You mean users want to init the backbones in their own way?
- No specific case. I'm concerned about the semantics of this function: if it is named "init_weights", it is expected to init weights, yet the current code can silently refuse to do so based on the values of some external variables.
- I can add a warning in init_weights of recognizer2d.
- Let's keep this for now. If this is not correct, there will be bug reports eventually.
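Since torchvision models initialize their own weights in __init__, the agreed follow-up is to make the skip visible rather than silent. A rough sketch of what a guarded init_weights with the suggested warning could look like; DummyRecognizer is a minimal stand-in, and the exact wording and call structure are assumptions rather than the PR's final code.

import warnings
from torch import nn

class DummyRecognizer(nn.Module):
    """Minimal stand-in to illustrate the warning discussed above."""

    def __init__(self, backbone, backbone_from='torchvision'):
        super().__init__()
        self.backbone = backbone
        self.backbone_from = backbone_from

    def init_weights(self):
        if self.backbone_from == 'torchvision':
            # torchvision modules already initialize themselves in __init__,
            # so leave them untouched but tell the user about the skip.
            warnings.warn('Weights of torchvision backbones are initialized by '
                          'torchvision itself; init_weights() leaves them as-is.')
        else:
            # Backbones registered in mmaction2 / mmcls expose init_weights().
            self.backbone.init_weights()

rec = DummyRecognizer(nn.Identity(), backbone_from='torchvision')
rec.init_weights()  # emits the warning instead of silently doing nothing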