Commit: add optional switch

dreamerlin committed Aug 3, 2020
1 parent a5bbb26 commit 5f617fd
Showing 3 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmaction/core/__init__.py
@@ -1,4 +1,5 @@
from .dist_utils import * # noqa: F401, F403
from .evaluation import * # noqa: F401, F403
from .fp16 import * # noqa: F401, F403
from .lr import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
39 changes: 39 additions & 0 deletions mmaction/core/lr/tin_lr_hook.py
@@ -0,0 +1,39 @@
from mmcv.runner import HOOKS, LrUpdaterHook
from mmcv.runner.hooks.lr_updater import annealing_cos


@HOOKS.register_module()
class TINLrUpdaterHook(LrUpdaterHook):

    def __init__(self, min_lr, **kwargs):
        self.min_lr = min_lr
        super(TINLrUpdaterHook, self).__init__(**kwargs)

    def get_warmup_lr(self, cur_iters):
        if self.warmup == 'linear':
            # 'linear' warmup is rewritten according to TIN repo:
            # https://github.com/deepcs233/TIN/blob/master/main.py#L409-L412
            k = (cur_iters / self.warmup_iters) * (
                1 - self.warmup_ratio) + self.warmup_ratio
            warmup_lr = [_lr * k for _lr in self.regular_lr]
        elif self.warmup == 'constant':
            warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
        elif self.warmup == 'exp':
            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
            warmup_lr = [_lr * k for _lr in self.regular_lr]
        return warmup_lr

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters

        target_lr = self.min_lr
        if self.warmup is not None:
            # offset progress so cosine annealing starts after warmup ends
            progress = progress - self.warmup_iters
            max_progress = max_progress - self.warmup_iters
        factor = progress / max_progress
        return annealing_cos(base_lr, target_lr, factor)
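
For reference, mmcv's runner builds the LR hook from the `policy` key of `lr_config` by appending 'LrUpdaterHook' to it, so `policy='TIN'` should resolve to the hook registered above. In the linear branch, `warmup_ratio=0.1` and `warmup_iters=500` give a scale factor k = (cur_iters / 500) * 0.9 + 0.1, rising from 0.1 at iteration 0 to 1.0 at the end of warmup. The config below is a hypothetical usage sketch, not part of this commit; all values are illustrative.

# Hypothetical config sketch (illustrative values, not from this commit).
lr_config = dict(
    policy='TIN',  # expanded by mmcv's runner to 'TINLrUpdaterHook'
    min_lr=0,      # floor passed to annealing_cos
    by_epoch=True,
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.1)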
19 changes: 19 additions & 0 deletions mmaction/models/backbones/resnet_tin.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn

from ...utils import get_root_logger
from ..registry import BACKBONES
from .resnet_tsm import ResNetTSM

@@ -309,6 +310,7 @@ class ResNetTIN(ResNetTSM):
        num_segments (int): Number of frame segments. Default: 8.
        is_tin (bool): Whether to apply temporal interlace. Default: True.
        shift_div (int): Number of division parts for shift. Default: 4.
        partial_bn (bool): Whether to use partial bn. Default: False.
        kwargs (dict, optional): Arguments for ResNet.
    """

@@ -317,11 +319,13 @@ def __init__(self,
                 num_segments=8,
                 is_tin=True,
                 shift_div=4,
                 partial_bn=False,
                 **kwargs):
        super().__init__(depth, **kwargs)
        self.num_segments = num_segments
        self.is_tin = is_tin
        self.shift_div = shift_div
        self.partial_bn = partial_bn

    def make_temporal_interlace(self):
        """Make temporal interlace for some layers."""
@@ -372,3 +376,18 @@ def init_weights(self):
            self.make_temporal_interlace()
        if len(self.non_local_cfg) != 0:
            self.make_non_local()

    def train(self, mode=True):
        super().train(mode)
        if mode and self.partial_bn:
            logger = get_root_logger()
            logger.info('Freezing BatchNorm2D except the first one.')
            count_bn = 0
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    count_bn += 1
                    if count_bn >= 2:
                        m.eval()
                        # shut down updates of frozen BN affine parameters
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
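
The new `partial_bn` switch follows the partial-BN recipe popularized by TSN-style training: on every call to `train()`, all `BatchNorm2d` layers except the first are kept in eval mode and their affine parameters stop receiving gradients, which helps when fine-tuning with small per-GPU batch sizes. Below is a minimal usage sketch, assuming mmaction exposes `build_backbone` as in other OpenMMLab codebases; the values are illustrative and not part of this commit.

# Hypothetical sketch: enable the new switch when building the backbone.
from mmaction.models import build_backbone

backbone = build_backbone(
    dict(
        type='ResNetTIN',
        depth=50,
        num_segments=8,
        is_tin=True,
        shift_div=4,
        partial_bn=True))
backbone.init_weights()
backbone.train()  # re-freezes every BatchNorm2d except the first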
