Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
congee524 committed Mar 30, 2021
1 parent cf2a15f commit 255fc3f
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn.functional as F
from mmcv.runner import auto_fp16

from ...utils import get_root_logger
from .. import builder


Expand Down Expand Up @@ -61,15 +62,22 @@ def __init__(self,
self.backbone_from = 'torchvision'

if partial_bn:
count_bn = 0
for m in self.backbone.modules():
if isinstance(m, nn.BatchNorm2d):
count_bn += 1
if count_bn >= 2:
m.eval()
# shutdown update in frozen mode
m.weight.requires_grad = False
m.bias.requires_grad = False

def my_train(self, mode=True):
super().train(mode)
count_bn = 0
logger = get_root_logger()
logger.info('Freezing BatchNorm2D except the first one.')
for m in self.backbone.modules():
if isinstance(m, nn.BatchNorm2d):
count_bn += 1
if count_bn >= 2:
m.eval()
# shutdown update in frozen mode
m.weight.requires_grad = False
m.bias.requires_grad = False

self.backbone.train = my_train

else:
self.backbone = builder.build_backbone(backbone)
Expand Down

0 comments on commit 255fc3f

Please sign in to comment.