Skip to content

Commit 111fe85

Browse files
authored
Model doc revamp - first PR (pytorch#5821)
* First PR for model doc revamp * Deactivating fail on warning, temporarily * Remove commnet * Minor changes * Typos * Added TODO in Makefile * Keep old models.rst file intact, move new docs into new models_new.rst file
1 parent c888d6d commit 111fe85

File tree

10 files changed

+369
-61
lines changed

10 files changed

+369
-61
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ docs/build
1515
docs/source/auto_examples/
1616
docs/source/gen_modules/
1717
docs/source/generated/
18+
docs/source/models/generated/
1819
# pytorch-sphinx-theme gets installed here
1920
docs/src
2021

docs/Makefile

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ ifneq ($(EXAMPLES_PATTERN),)
66
endif
77

88
# You can set these variables from the command line.
9-
SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
9+
# TODO: Once the models doc revamp is done, set back the -W option to raise
10+
# errors on warnings. See https://github.com/pytorch/vision/pull/5821#discussion_r850500693
11+
SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS)
1012
SPHINXBUILD = sphinx-build
1113
SPHINXPROJ = torchvision
1214
SOURCEDIR = source

docs/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ numpy
33
sphinx-copybutton>=0.3.1
44
sphinx-gallery>=0.9.0
55
sphinx==3.5.4
6+
tabulate
67
# This pin is only needed for sphinx<4.0.2. See https://github.com/pytorch/vision/issues/5673 for details
78
Jinja2<3.1.*
89
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme

docs/source/conf.py

+56
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@
2121
# sys.path.insert(0, os.path.abspath('.'))
2222

2323
import os
24+
import textwrap
25+
from pathlib import Path
2426

2527
import pytorch_sphinx_theme
2628
import torchvision
29+
import torchvision.models as M
30+
from tabulate import tabulate
2731

2832

2933
# -- General configuration ------------------------------------------------
@@ -292,5 +296,57 @@ def inject_minigalleries(app, what, name, obj, options, lines):
292296
lines.append("\n")
293297

294298

299+
def inject_weight_metadata(app, what, name, obj, options, lines):
300+
301+
if obj.__name__.endswith("_Weights"):
302+
lines[:] = ["The model builder above accepts the following values as the ``weights`` parameter:"]
303+
lines.append("")
304+
for field in obj:
305+
lines += [f"**{str(field)}**:", ""]
306+
307+
table = []
308+
for k, v in field.meta.items():
309+
if k == "categories":
310+
continue
311+
elif k == "recipe":
312+
v = f"`link <{v}>`__"
313+
table.append((str(k), str(v)))
314+
table = tabulate(table, tablefmt="rst")
315+
lines += [".. table::", ""]
316+
lines += textwrap.indent(table, " " * 4).split("\n")
317+
lines.append("")
318+
319+
320+
def generate_classification_table():
321+
322+
weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("_Weights")]
323+
weights = [w for weight_enum in weight_enums for w in weight_enum]
324+
325+
column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
326+
content = [
327+
(
328+
f":class:`{w} <{type(w).__name__}>`",
329+
w.meta["acc@1"],
330+
w.meta["acc@5"],
331+
f"{w.meta['num_params']/1e6:.1f}M",
332+
f"`link <{w.meta['recipe']}>`__",
333+
)
334+
for w in weights
335+
]
336+
table = tabulate(content, headers=column_names, tablefmt="rst")
337+
338+
generated_dir = Path("generated")
339+
generated_dir.mkdir(exist_ok=True)
340+
with open(generated_dir / "classification_table.rst", "w+") as table_file:
341+
table_file.write(".. table::\n")
342+
table_file.write(" :widths: 100 10 10 20 10\n\n")
343+
table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")
344+
345+
346+
generate_classification_table()
347+
348+
295349
def setup(app):
350+
296351
app.connect("autodoc-process-docstring", inject_minigalleries)
352+
app.connect("autodoc-process-docstring", inject_weight_metadata)

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ architectures, and common image transformations for computer vision.
3838
ops
3939
io
4040
feature_extraction
41+
models_new
4142

4243
.. toctree::
4344
:maxdepth: 1

docs/source/models/resnet.rst

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
ResNet
2+
======
3+
4+
.. currentmodule:: torchvision.models
5+
6+
The ResNet model is based on the `Deep Residual Learning for Image Recognition
7+
<https://arxiv.org/abs/1512.03385>`_ paper.
8+
9+
10+
Model builders
11+
--------------
12+
13+
The following model builders can be used to instanciate a ResNet model, with or
14+
without pre-trained weights. All the model builders internally rely on the
15+
``torchvision.models.resnet.ResNet`` base class. Please refer to the `source
16+
code
17+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_ for
18+
more details about this class.
19+
20+
.. autosummary::
21+
:toctree: generated/
22+
:template: function.rst
23+
24+
resnet18
25+
resnet34
26+
resnet50
27+
resnet101
28+
resnet152

docs/source/models/vgg.rst

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
VGG
2+
===
3+
4+
.. currentmodule:: torchvision.models
5+
6+
The VGG model is based on the `Very Deep Convolutional Networks for Large-Scale
7+
Image Recognition <https://arxiv.org/abs/1409.1556>`_ paper.
8+
9+
10+
Model builders
11+
--------------
12+
13+
The following model builders can be used to instanciate a VGG model, with or
14+
without pre-trained weights. All the model buidlers internally rely on the
15+
``torchvision.models.vgg.VGG`` base class. Please refer to the `source code
16+
<https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_ for
17+
more details about this class.
18+
19+
.. autosummary::
20+
:toctree: generated/
21+
:template: function.rst
22+
23+
vgg11
24+
vgg11_bn
25+
vgg13
26+
vgg13_bn
27+
vgg16
28+
vgg16_bn
29+
vgg19
30+
vgg19_bn

docs/source/models_new.rst

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
.. _models_new:
2+
3+
Models and pre-trained weights - New
4+
####################################
5+
6+
.. note::
7+
8+
These are the new models docs, documenting the new multi-weight API.
9+
TODO: Once all is done, remove the "- New" part in the title above, and
10+
rename this file as models.rst
11+
12+
13+
The ``torchvision.models`` subpackage contains definitions of models for addressing
14+
different tasks, including: image classification, pixelwise semantic
15+
segmentation, object detection, instance segmentation, person
16+
keypoint detection, video classification, and optical flow.
17+
18+
.. note ::
19+
Backward compatibility is guaranteed for loading a serialized
20+
``state_dict`` to the model created using old PyTorch version.
21+
On the contrary, loading entire saved models or serialized
22+
``ScriptModules`` (seralized using older versions of PyTorch)
23+
may not preserve the historic behaviour. Refer to the following
24+
`documentation
25+
<https://pytorch.org/docs/stable/notes/serialization.html#id6>`_
26+
27+
28+
Classification
29+
==============
30+
31+
.. currentmodule:: torchvision.models
32+
33+
The following classification models are available, with or without pre-trained
34+
weights:
35+
36+
.. toctree::
37+
:maxdepth: 1
38+
39+
models/resnet
40+
models/vgg
41+
42+
43+
Table of all available classification weights
44+
---------------------------------------------
45+
46+
Accuracies are reported on ImageNet
47+
48+
.. include:: generated/classification_table.rst
49+
50+
51+
Object Detection, Instance Segmentation and Person Keypoint Detection
52+
=====================================================================
53+
54+
TODO: Something similar to classification models: list of models + table of weights

torchvision/models/resnet.py

+75-20
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,23 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
556556

557557
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
558558
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
559-
r"""ResNet-18 model from
560-
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
559+
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
561560
562561
Args:
563-
weights (ResNet18_Weights, optional): The pretrained weights for the model
564-
progress (bool): If True, displays a progress bar of the download to stderr
562+
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
563+
pretrained weights to use. See
564+
:class:`~torchvision.models.ResNet18_Weights` below for
565+
more details, and possible values. By default, no pre-trained
566+
weights are used.
567+
progress (bool, optional): If True, displays a progress bar of the
568+
download to stderr. Default is True.
569+
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
570+
base class. Please refer to the `source code
571+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
572+
for more details about this class.
573+
574+
.. autoclass:: torchvision.models.ResNet18_Weights
575+
:members:
565576
"""
566577
weights = ResNet18_Weights.verify(weights)
567578

@@ -570,12 +581,23 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
570581

571582
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
572583
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
573-
r"""ResNet-34 model from
574-
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
584+
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
575585
576586
Args:
577-
weights (ResNet34_Weights, optional): The pretrained weights for the model
578-
progress (bool): If True, displays a progress bar of the download to stderr
587+
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
588+
pretrained weights to use. See
589+
:class:`~torchvision.models.ResNet34_Weights` below for
590+
more details, and possible values. By default, no pre-trained
591+
weights are used.
592+
progress (bool, optional): If True, displays a progress bar of the
593+
download to stderr. Default is True.
594+
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
595+
base class. Please refer to the `source code
596+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
597+
for more details about this class.
598+
599+
.. autoclass:: torchvision.models.ResNet34_Weights
600+
:members:
579601
"""
580602
weights = ResNet34_Weights.verify(weights)
581603

@@ -584,12 +606,23 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
584606

585607
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
586608
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
587-
r"""ResNet-50 model from
588-
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
609+
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
589610
590611
Args:
591-
weights (ResNet50_Weights, optional): The pretrained weights for the model
592-
progress (bool): If True, displays a progress bar of the download to stderr
612+
weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
613+
pretrained weights to use. See
614+
:class:`~torchvision.models.ResNet50_Weights` below for
615+
more details, and possible values. By default, no pre-trained
616+
weights are used.
617+
progress (bool, optional): If True, displays a progress bar of the
618+
download to stderr. Default is True.
619+
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
620+
base class. Please refer to the `source code
621+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
622+
for more details about this class.
623+
624+
.. autoclass:: torchvision.models.ResNet50_Weights
625+
:members:
593626
"""
594627
weights = ResNet50_Weights.verify(weights)
595628

@@ -598,12 +631,23 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
598631

599632
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
600633
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
601-
r"""ResNet-101 model from
602-
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
634+
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
603635
604636
Args:
605-
weights (ResNet101_Weights, optional): The pretrained weights for the model
606-
progress (bool): If True, displays a progress bar of the download to stderr
637+
weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
638+
pretrained weights to use. See
639+
:class:`~torchvision.models.ResNet101_Weights` below for
640+
more details, and possible values. By default, no pre-trained
641+
weights are used.
642+
progress (bool, optional): If True, displays a progress bar of the
643+
download to stderr. Default is True.
644+
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
645+
base class. Please refer to the `source code
646+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
647+
for more details about this class.
648+
649+
.. autoclass:: torchvision.models.ResNet101_Weights
650+
:members:
607651
"""
608652
weights = ResNet101_Weights.verify(weights)
609653

@@ -612,12 +656,23 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
612656

613657
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
614658
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
615-
r"""ResNet-152 model from
616-
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
659+
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
617660
618661
Args:
619-
weights (ResNet152_Weights, optional): The pretrained weights for the model
620-
progress (bool): If True, displays a progress bar of the download to stderr
662+
weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
663+
pretrained weights to use. See
664+
:class:`~torchvision.models.ResNet152_Weights` below for
665+
more details, and possible values. By default, no pre-trained
666+
weights are used.
667+
progress (bool, optional): If True, displays a progress bar of the
668+
download to stderr. Default is True.
669+
**kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
670+
base class. Please refer to the `source code
671+
<https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
672+
for more details about this class.
673+
674+
.. autoclass:: torchvision.models.ResNet152_Weights
675+
:members:
621676
"""
622677
weights = ResNet152_Weights.verify(weights)
623678

0 commit comments

Comments
 (0)