Skip to content

Commit

Permalink
Fix typo in BaseHead default loss_factor. (#446)
Browse files Browse the repository at this point in the history
* Fix typo in default BaseHead.

* Minor fix unittest.

Fix docstring.

* Add changelog.

* Fix docstring.

Fix typo.
  • Loading branch information
su authored Dec 16, 2020
1 parent 5e0ffc1 commit 4cc48fc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


**Bug and Typo Fixes**
- Fix typo in default argument of BaseHead. ([#446](https://github.com/open-mmlab/mmaction2/pull/446))

**ModelZoo**

Expand Down
4 changes: 2 additions & 2 deletions mmaction/models/heads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
num_classes (int): Number of classes to be classified.
in_channels (int): Number of channels in input feature.
loss_cls (dict): Config for building loss.
Default: dict(type='CrossEntropyLoss').
Default: dict(type='CrossEntropyLoss', loss_weight=1.0).
multi_class (bool): Determines whether it is a multi-class
recognition task. Default: False.
label_smooth_eps (float): Epsilon used in label smooth.
Expand All @@ -46,7 +46,7 @@ class BaseHead(nn.Module, metaclass=ABCMeta):
def __init__(self,
num_classes,
in_channels,
loss_cls=dict(type='CrossEntropyLoss', loss_factor=1.0),
loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0),
multi_class=False,
label_smooth_eps=0.0):
super().__init__()
Expand Down
11 changes: 10 additions & 1 deletion tests/test_models/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ExampleHead(BaseHead):
# use a ExampleHead to success BaseHead
# use an ExampleHead to test BaseHead
def init_weights(self):
pass

Expand All @@ -24,6 +24,15 @@ def test_base_head():
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'

head = ExampleHead(3, 400, dict(type='CrossEntropyLoss', loss_weight=2.0))

cls_scores = torch.rand((3, 4))
# When truth is non-empty then cls loss should be nonzero for random inputs
gt_labels = torch.LongTensor([2] * 3).squeeze()
losses = head.loss(cls_scores, gt_labels)
assert 'loss_cls' in losses.keys()
assert losses.get('loss_cls') > 0, 'cls loss should be non-zero'


def test_i3d_head():
"""Test loss method, layer construction, attributes and forward function in
Expand Down

0 comments on commit 4cc48fc

Please sign in to comment.