-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Enable torchvision backbones #720
Conversation
@@ -69,7 +85,8 @@ def __init__(self, | |||
|
|||
def init_weights(self): | |||
"""Initialize the model network weights.""" | |||
self.backbone.init_weights() | |||
if self.backbone_from in ['mmcls', 'mmaction2']: | |||
self.backbone.init_weights() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happened for torchvision
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torchvision init all modules in the init function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Users might call this function by hand
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for which case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean users want to init the backbones in their own way?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No specific case. I'm concerning the semantic of this function. If the function name is "init weights", then it is expected to init weights. However, current code can silently refuse to init weights based on the value of some external variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add a warning in init_weights of recognizer2d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep this for now. If this is not correct, there will be bug reports eventually.
if self.backbone_from == 'torchvision': | ||
if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1): | ||
# apply adaptive avg pooling | ||
x = nn.AdaptiveAvgPool2d(1)(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this is needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shapes of different torchvision models' outputs vary, it may be N x C x H x W, N x C x 1 x 1 or N x C.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if the special-cased processing is expected. Why backbones from other sources are not handled?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Backbones from other sources are written by us (mmcls / mmaction2), thus are more standard. In torchvision, there can be some legacy issues.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay
Codecov Report
@@ Coverage Diff @@
## master #720 +/- ##
==========================================
- Coverage 85.45% 85.41% -0.04%
==========================================
Files 130 130
Lines 9371 9401 +30
Branches 1572 1580 +8
==========================================
+ Hits 8008 8030 +22
- Misses 962 967 +5
- Partials 401 404 +3
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
Kindly ping @innerlee |
Good to go after rebasing. |
@kennymckormick use |
@innerlee ready for merging |
An example:
backbone=dict(type='torchvision.resnet50')