Skip to content

Commit

Permalink
Add TIN model (#53)
Browse files Browse the repository at this point in the history
Co-authored-by: lizz <innerlee@users.noreply.github.com>
  • Loading branch information
dreamerlin and innerlee authored Aug 27, 2020
1 parent 15575db commit c9aea9c
Show file tree
Hide file tree
Showing 8 changed files with 631 additions and 4 deletions.
131 changes: 131 additions & 0 deletions configs/recognition/tin/tin_r50_1x1x8_40e_sthv1_rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTIN',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=4),
cls_head=dict(
type='TSMHead',
num_classes=174,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.8,
init_std=0.001,
is_shift=False))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/sth-v1/rawframes_train/'
data_root_val = 'data/sth-v1/rawframes_val/'
ann_file_train = 'data/sth-v1/sth-v1_train_list.txt'
ann_file_val = 'data/sth-v1/sth-v1_val_list.txt'
ann_file_test = 'data/sth-v1/sth-v1_val_list.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=6,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
filename_tmpl='{:05}.jpg',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
filename_tmpl='{:05}.jpg',
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD',
constructor='TSMOptimizerConstructor',
paramwise_cfg=dict(fc_lr5=True),
lr=0.02,
momentum=0.9,
weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr_ratio=0.5,
warmup='linear',
warmup_ratio=0.1,
warmup_by_epoch=True,
warmup_iters=1)
total_epochs = 40
checkpoint_config = dict(interval=1)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tin_r50_1x1x8_40e_sthv1_rgb/'
load_from = None
resume_from = None
workflow = [('train', 1)]
1 change: 1 addition & 0 deletions mmaction/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .dist_utils import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .lr import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
39 changes: 39 additions & 0 deletions mmaction/core/lr/tin_lr_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from mmcv.runner import HOOKS, LrUpdaterHook
from mmcv.runner.hooks.lr_updater import annealing_cos


@HOOKS.register_module()
class TINLrUpdaterHook(LrUpdaterHook):

def __init__(self, min_lr, **kwargs):
self.min_lr = min_lr
super(TINLrUpdaterHook, self).__init__(**kwargs)

def get_warmup_lr(self, cur_iters):
if self.warmup == 'linear':
# 'linear' warmup is rewritten according to TIN repo:
# https://github.com/deepcs233/TIN/blob/master/main.py#L409-L412
k = (cur_iters / self.warmup_iters) * (
1 - self.warmup_ratio) + self.warmup_ratio
warmup_lr = [_lr * k for _lr in self.regular_lr]
elif self.warmup == 'constant':
warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
elif self.warmup == 'exp':
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
warmup_lr = [_lr * k for _lr in self.regular_lr]
return warmup_lr

def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter
max_progress = runner.max_iters

target_lr = self.min_lr
if self.warmup is not None:
progress = progress - self.warmup_iters
max_progress = max_progress - self.warmup_iters
factor = progress / max_progress
return annealing_cos(base_lr, target_lr, factor)
5 changes: 3 additions & 2 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .backbones import (ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTSM)
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTIN,
ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_model,
build_recognizer)
from .common import Conv2plus1d
Expand All @@ -18,5 +19,5 @@
'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly',
'BCELossWithLogits', 'LOCALIZERS', 'build_localizer', 'PEM', 'TEM',
'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model',
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN'
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN'
]
3 changes: 2 additions & 1 deletion mmaction/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .resnet3d_csn import ResNet3dCSN
from .resnet3d_slowfast import ResNet3dSlowFast
from .resnet3d_slowonly import ResNet3dSlowOnly
from .resnet_tin import ResNetTIN
from .resnet_tsm import ResNetTSM

__all__ = [
'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d', 'ResNet3dSlowFast',
'ResNet3dSlowOnly', 'ResNet3dCSN'
'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN'
]
18 changes: 18 additions & 0 deletions mmaction/models/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ class ResNet(nn.Module):
Default: dict(type='ReLU', inplace=True).
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var). Default: True.
partial_bn (bool): Whether to use partial bn. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
"""
Expand All @@ -342,6 +343,7 @@ def __init__(self,
norm_cfg=dict(type='BN2d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
norm_eval=True,
partial_bn=False,
with_cp=False):
super().__init__()
if depth not in self.arch_settings:
Expand All @@ -361,6 +363,7 @@ def __init__(self,
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.partial_bn = partial_bn
self.with_cp = with_cp

self.block, stage_blocks = self.arch_settings[depth]
Expand Down Expand Up @@ -551,6 +554,19 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def _partial_bn(self):
logger = get_root_logger()
logger.info('Freezing BatchNorm2D except the first one.')
count_bn = 0
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
count_bn += 1
if count_bn >= 2:
m.eval()
# shutdown update in frozen mode
m.weight.requires_grad = False
m.bias.requires_grad = False

def train(self, mode=True):
"""Set the optimization status when training."""
super().train(mode)
Expand All @@ -559,3 +575,5 @@ def train(self, mode=True):
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
if mode and self.partial_bn:
self._partial_bn()
Loading

0 comments on commit c9aea9c

Please sign in to comment.