Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CSN model #87

Merged
merged 13 commits into from
Aug 9, 2020
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet3dCSN',
pretrained2d=False,
pretrained= # noqa: E251
'https://openmmlab.oss-accelerate.aliyuncs.com/mmaction/recognition/csn/ircsn_from_scratch_r152_ig65m_20200807-771c4135.pth', # noqa: E501
depth=152,
with_pool2=False,
bottleneck_mode='ir',
norm_eval=True,
bn_frozen=True,
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=10,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=3,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.0005, momentum=0.9, weight_decay=0.0001) # 0.0005 for 32g
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[32, 48],
warmup='linear',
warmup_ratio=0.1,
warmup_by_epoch=True,
warmup_iters=16)
total_epochs = 58
checkpoint_config = dict(interval=2)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ircsn_ig65m_pretrained_bnfrozen_r152_32x2x1_58e_kinetics400_rgb' # noqa: E501
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet3dCSN',
pretrained2d=False,
pretrained= # noqa: E251
'https://openmmlab.oss-accelerate.aliyuncs.com/mmaction/recognition/csn/ircsn_from_scratch_r152_ig65m_20200807-771c4135.pth', # noqa: E501
depth=152,
with_pool2=False,
bottleneck_mode='ir',
norm_eval=False,
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=10,
test_mode=True),
dict(type='FrameSelector'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=3,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.0005, momentum=0.9, weight_decay=0.0001) # 0.0005 for 32g
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
step=[32, 48],
warmup='linear',
warmup_ratio=0.1,
warmup_by_epoch=True,
warmup_iters=16)
total_epochs = 58
checkpoint_config = dict(interval=2)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ircsn_ig65m_pretrained_r152_32x2x1_58e_kinetics400_rgb'
load_from = None
resume_from = None
workflow = [('train', 1)]
6 changes: 3 additions & 3 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .backbones import (ResNet, ResNet2Plus1d, ResNet3d, ResNet3dSlowFast,
ResNet3dSlowOnly, ResNetTSM)
from .backbones import (ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_model,
build_recognizer)
from .common import Conv2plus1d
Expand All @@ -18,5 +18,5 @@
'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly',
'BCELossWithLogits', 'LOCALIZERS', 'build_localizer', 'PEM', 'TEM',
'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model',
'OHEMHingeLoss', 'SSNLoss'
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN'
]
3 changes: 2 additions & 1 deletion mmaction/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .resnet import ResNet
from .resnet2plus1d import ResNet2Plus1d
from .resnet3d import ResNet3d
from .resnet3d_csn import ResNet3dCSN
from .resnet3d_slowfast import ResNet3dSlowFast
from .resnet3d_slowonly import ResNet3dSlowOnly
from .resnet_tsm import ResNetTSM

__all__ = [
'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d', 'ResNet3dSlowFast',
'ResNet3dSlowOnly'
'ResNet3dSlowOnly', 'ResNet3dCSN'
]
21 changes: 14 additions & 7 deletions mmaction/models/backbones/resnet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,10 @@ class ResNet3d(nn.Module):
non_local (Sequence[int]): Determine whether to apply non-local module
in the corresponding block of each stages. Default: (0, 0, 0, 0).
non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
zero_init_residual (bool): Whether to use zero initialization for
residual block, Default: True.
zero_init_residual (bool):
Whether to use zero initialization for residual block,
Default: True.
kwargs (dict, optional): Key arguments for "make_res_layer".
"""

arch_settings = {
Expand Down Expand Up @@ -405,7 +407,8 @@ def __init__(self,
with_cp=False,
non_local=(0, 0, 0, 0),
non_local_cfg=dict(),
zero_init_residual=True):
zero_init_residual=True,
**kwargs):
super().__init__()
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
Expand Down Expand Up @@ -467,7 +470,8 @@ def __init__(self,
non_local_cfg=self.non_local_cfg,
inflate=self.stage_inflations[i],
inflate_style=self.inflate_style,
with_cp=with_cp)
with_cp=with_cp,
**kwargs)
self.inplanes = planes * self.block.expansion
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
Expand All @@ -492,7 +496,8 @@ def make_res_layer(self,
norm_cfg=None,
act_cfg=None,
conv_cfg=None,
with_cp=False):
with_cp=False,
**kwargs):
"""Build residual layer for ResNet3D.

Args:
Expand Down Expand Up @@ -565,7 +570,8 @@ def make_res_layer(self,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
Expand All @@ -583,7 +589,8 @@ def make_res_layer(self,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg,
act_cfg=act_cfg,
with_cp=with_cp))
with_cp=with_cp,
**kwargs))

return nn.Sequential(*layers)

Expand Down
Loading