Skip to content

Commit

Permalink
[Improvement] Set RandAugment as Imgaug default transforms. (#585)
Browse files Browse the repository at this point in the history
* first commit for image randaugment

* fix test

* update changelog

* use percentage params for TranslateX & TranslateY

* update tsm-r50 sthv1 config and result

* delete blank line

* fix config

* add tsm-r50 flip + randaugment sthv1

* remove useless annotations

* Update README.md

* fix

Co-authored-by: Jintao Lin <528557675@qq.com>
  • Loading branch information
irvingzhang0512 and dreamerlin authored Mar 25, 2021
1 parent 70d0c7b commit 8a79e52
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 28 deletions.
2 changes: 2 additions & 0 deletions configs/recognition/tsm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|[tsm_r50_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_1x1x8_50e_sthv1_rgb.py) |height 100|8| ResNet50 | ImageNet| 45.58 / 47.70|75.02 / 76.12|[45.50 / 47.33](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[74.34 / 76.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 7077| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_sthv1_rgb/tsm_r50_1x1x8_50e_sthv1_rgb_20210203-01dce462.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_sthv1_rgb/20210203_150227.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_sthv1_rgb/20210203_150227.log.json)|
|[tsm_r50_flip_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_flip_1x1x8_50e_sthv1_rgb.py) |height 100|8| ResNet50 | ImageNet| 47.10 / 48.51|76.02 / 77.56|[45.50 / 47.33](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[74.34 / 76.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 7077| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_1x1x8_50e_sthv1_rgb/tsm_r50_flip_1x1x8_50e_sthv1_rgb_20210203-12596f16.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_1x1x8_50e_sthv1_rgb/20210203_145829.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_1x1x8_50e_sthv1_rgb/20210203_145829.log.json)|
|[tsm_r50_randaugment_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb.py) |height 100|8| ResNet50 | ImageNet| 47.16 / 48.90|76.07 / 77.92|[45.50 / 47.33](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[74.34 / 76.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 7077| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb_20210324-481268d9.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb.json)|
|[tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb.py) |height 100|8| ResNet50 | ImageNet| 47.85 / 50.31|76.78 / 78.18|[45.50 / 47.33](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[74.34 / 76.60](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)| 7077| [ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb_20210324-76937692.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb.log)| [json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb.json)|
|[tsm_r50_1x1x16_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r50_1x1x16_50e_sthv1_rgb.py)|height 100|8| ResNet50 | ImageNet|47.62 / 49.28|76.63 / 77.82|[47.05 / 48.61](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[76.40 / 77.96](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|10390|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv1_rgb/tsm_r50_1x1x16_50e_sthv1_rgb_20201010-17fa49f6.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv1_rgb/20201010_221240.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x16_50e_sthv1_rgb/20201010_221240.log.json)|
|[tsm_r101_1x1x8_50e_sthv1_rgb](/configs/recognition/tsm/tsm_r101_1x1x8_50e_sthv1_rgb.py)|height 100|8| ResNet101 | ImageNet|45.72 / 48.43|74.67 / 76.72|[46.64 / 48.13](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|[75.40 / 77.31](https://github.com/mit-han-lab/temporal-shift-module/tree/8d53d6fda40bea2f1b37a6095279c4b454d672bd#training)|9800|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv1_rgb/tsm_r101_1x1x8_50e_sthv1_rgb_20201010-43fedf2e.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv1_rgb/20201010_224055.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r101_1x1x8_50e_sthv1_rgb/20201010_224055.log.json)|

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# TSM-R50 config for sthv1 raw frames. The training pipeline applies a
# label-aware Flip followed by the Imgaug 'default' transforms.
_base_ = [
    '../../_base_/models/tsm_r50.py',
    '../../_base_/schedules/sgd_tsm_50e.py',
    '../../_base_/default_runtime.py',
]

# model settings
# Only the classification head's class count is set in this file.
model = {'cls_head': {'num_classes': 174}}

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/sthv1/rawframes'
data_root_val = 'data/sthv1/rawframes'
ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
# Testing reads the same annotation list as validation.
ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'

# Label pairs handed to Flip below; presumably classes whose meaning is
# exchanged when a frame is mirrored -- confirm against the Flip transform.
sthv1_flip_label_map = {2: 4, 4: 2, 30: 41, 41: 30, 52: 66, 66: 52}
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_bgr': False,
}

train_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {
        'type': 'MultiScaleCrop',
        'input_size': 224,
        'scales': (1, 0.875, 0.75, 0.66),
        'random_crop': False,
        'max_wh_scale_gap': 1,
        'num_fixed_crops': 13,
    },
    {'type': 'Resize', 'scale': (224, 224), 'keep_ratio': False},
    # Flip runs before Imgaug and receives the label map defined above.
    {
        'type': 'Flip',
        'flip_ratio': 0.5,
        'flip_label_map': sthv1_flip_label_map,
    },
    # 'default' asks Imgaug for its built-in set of augmenters.
    {'type': 'Imgaug', 'transforms': 'default'},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs', 'label']},
]

# Validation and test pipelines hold the same steps but are kept as two
# separate list objects.
val_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
        'test_mode': True,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {'type': 'CenterCrop', 'crop_size': 224},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs']},
]
test_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
        'test_mode': True,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {'type': 'CenterCrop', 'crop_size': 224},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs']},
]

data = {
    'videos_per_gpu': 8,
    'workers_per_gpu': 4,
    'train': {
        'type': dataset_type,
        'ann_file': ann_file_train,
        'data_prefix': data_root,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': train_pipeline,
    },
    'val': {
        'type': dataset_type,
        'ann_file': ann_file_val,
        'data_prefix': data_root_val,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': val_pipeline,
    },
    'test': {
        'type': dataset_type,
        'ann_file': ann_file_test,
        'data_prefix': data_root_val,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': test_pipeline,
    },
}
evaluation = {
    'interval': 2,
    'metrics': ['top_k_accuracy', 'mean_class_accuracy'],
}

# optimizer
# Only weight_decay is given here; the remaining optimizer fields are
# presumably inherited from the base schedule -- confirm in sgd_tsm_50e.py.
optimizer = {'weight_decay': 0.0005}

# runtime settings
work_dir = './work_dirs/tsm_r50_flip_randaugment_1x1x8_50e_sthv1_rgb/'
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
]

# model settings
# model settings# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
Expand Down
94 changes: 94 additions & 0 deletions configs/recognition/tsm/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# TSM-R50 config for sthv1 raw frames. The training pipeline applies the
# Imgaug 'default' transforms; there is no Flip step in this variant.
_base_ = [
    '../../_base_/models/tsm_r50.py',
    '../../_base_/schedules/sgd_tsm_50e.py',
    '../../_base_/default_runtime.py',
]

# model settings
# Only the classification head's class count is set in this file.
model = {'cls_head': {'num_classes': 174}}

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/sthv1/rawframes'
data_root_val = 'data/sthv1/rawframes'
ann_file_train = 'data/sthv1/sthv1_train_list_rawframes.txt'
ann_file_val = 'data/sthv1/sthv1_val_list_rawframes.txt'
# Testing reads the same annotation list as validation.
ann_file_test = 'data/sthv1/sthv1_val_list_rawframes.txt'
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_bgr': False,
}

train_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {
        'type': 'MultiScaleCrop',
        'input_size': 224,
        'scales': (1, 0.875, 0.75, 0.66),
        'random_crop': False,
        'max_wh_scale_gap': 1,
        'num_fixed_crops': 13,
    },
    {'type': 'Resize', 'scale': (224, 224), 'keep_ratio': False},
    # 'default' asks Imgaug for its built-in set of augmenters.
    {'type': 'Imgaug', 'transforms': 'default'},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs', 'label']},
]

# Validation and test pipelines hold the same steps but are kept as two
# separate list objects.
val_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
        'test_mode': True,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {'type': 'CenterCrop', 'crop_size': 224},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs']},
]
test_pipeline = [
    {
        'type': 'SampleFrames',
        'clip_len': 1,
        'frame_interval': 1,
        'num_clips': 8,
        'test_mode': True,
    },
    {'type': 'RawFrameDecode'},
    {'type': 'Resize', 'scale': (-1, 256)},
    {'type': 'CenterCrop', 'crop_size': 224},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'FormatShape', 'input_format': 'NCHW'},
    {'type': 'Collect', 'keys': ['imgs', 'label'], 'meta_keys': []},
    {'type': 'ToTensor', 'keys': ['imgs']},
]

data = {
    'videos_per_gpu': 8,
    'workers_per_gpu': 4,
    'train': {
        'type': dataset_type,
        'ann_file': ann_file_train,
        'data_prefix': data_root,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': train_pipeline,
    },
    'val': {
        'type': dataset_type,
        'ann_file': ann_file_val,
        'data_prefix': data_root_val,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': val_pipeline,
    },
    'test': {
        'type': dataset_type,
        'ann_file': ann_file_test,
        'data_prefix': data_root_val,
        'filename_tmpl': '{:05}.jpg',
        'pipeline': test_pipeline,
    },
}
evaluation = {
    'interval': 2,
    'metrics': ['top_k_accuracy', 'mean_class_accuracy'],
}

# optimizer
# Only weight_decay is given here; the remaining optimizer fields are
# presumably inherited from the base schedule -- confirm in sgd_tsm_50e.py.
optimizer = {'weight_decay': 0.0005}

# runtime settings
work_dir = './work_dirs/tsm_r50_randaugment_1x1x8_50e_sthv1_rgb/'
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
dropout_ratio=0.4,
init_std=0.01),
# model training and testing settings
# train_cfg=dict(
# blending=dict(type="CutmixBlending", num_classes=400, alpha=.2)),
train_cfg=dict(
blending=dict(type='MixupBlending', num_classes=400, alpha=.2)),
test_cfg=dict(average_clips=None))
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
**Improvements**

- Add slowfast config/json/log/ckpt for training custom classes of AVA ([#678](https://github.com/open-mmlab/mmaction2/pull/678))
- Set RandAugment as Imgaug default transforms ([#585](https://github.com/open-mmlab/mmaction2/pull/585))

**Bug and Typo Fixes**

Expand Down
72 changes: 47 additions & 25 deletions mmaction/datasets/pipelines/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,38 +134,59 @@ def __init__(self, transforms):
[self.imgaug_builder(t) for t in self.transforms])

def default_transforms(self):
"""Default transforms for imgaug."""
"""Default transforms for imgaug.
Implement RandAugment by imgaug.
Please visit `https://arxiv.org/abs/1909.13719` for more information.
Augmenters and hyper parameters are borrowed from the following repo:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py # noqa
Miss one augmenter ``SolarizeAdd`` since imgaug doesn't support this.
Returns:
dict: The constructed RandAugment transforms.
"""
# RandAugment hyper params
num_augmenters = 2
cur_magnitude, max_magnitude = 9, 10
cur_level = 1.0 * cur_magnitude / max_magnitude

return [
dict(type='Rotate', rotate=(-30, 30)),
dict(
type='SomeOf',
n=(0, 3),
n=num_augmenters,
children=[
dict(
type='OneOf',
children=[
dict(type='GaussianBlur', sigma=(0, 0.5)),
dict(type='AverageBlur', k=(2, 7)),
dict(type='MedianBlur', k=(3, 11))
]),
type='ShearX',
shear=17.19 * cur_level * random.choice([-1, 1])),
dict(
type='ShearY',
shear=17.19 * cur_level * random.choice([-1, 1])),
dict(
type='TranslateX',
percent=.2 * cur_level * random.choice([-1, 1])),
dict(
type='TranslateY',
percent=.2 * cur_level * random.choice([-1, 1])),
dict(
type='Rotate',
rotate=30 * cur_level * random.choice([-1, 1])),
dict(type='Posterize', nb_bits=max(1, int(4 * cur_level))),
dict(type='Solarize', threshold=256 * cur_level),
dict(type='EnhanceColor', factor=1.8 * cur_level + .1),
dict(type='EnhanceContrast', factor=1.8 * cur_level + .1),
dict(
type='OneOf',
children=[
dict(
type='Dropout', p=(0.01, 0.1),
per_channel=0.5),
dict(
type='CoarseDropout',
p=(0.03, 0.15),
size_percent=(0.02, 0.05),
per_channel=0.2),
]),
type='EnhanceBrightness', factor=1.8 * cur_level + .1),
dict(type='EnhanceSharpness', factor=1.8 * cur_level + .1),
dict(type='Autocontrast', cutoff=0),
dict(type='Equalize'),
dict(type='Invert', p=1.),
dict(
type='AdditiveGaussianNoise',
loc=0,
scale=(0.0, 0.05 * 255),
per_channel=0.5),
type='Cutout',
nb_iterations=1,
size=0.2 * cur_level,
squared=True),
]),
]

Expand All @@ -188,7 +209,8 @@ def imgaug_builder(self, cfg):

obj_type = args.pop('type')
if mmcv.is_str(obj_type):
obj_cls = getattr(iaa, obj_type)
obj_cls = getattr(iaa, obj_type) if hasattr(iaa, obj_type) \
else getattr(iaa.pillike, obj_type)
elif issubclass(obj_type, iaa.Augmenter):
obj_cls = obj_type
else:
Expand Down

0 comments on commit 8a79e52

Please sign in to comment.