ONNX exporting support. #160

Merged · 8 commits · Sep 16, 2020
2 changes: 1 addition & 1 deletion README.md
@@ -96,7 +96,7 @@ Please refer to [data_preparation.md](docs/data_preparation.md) for a general kn
## Get Started

Please see [getting_started.md](docs/getting_started.md) for the basic usage of MMAction2.
There are also tutorials for [finetuning models](docs/tutorials/finetune.md), [adding new dataset](docs/tutorials/new_dataset.md), [designing data pipeline](docs/tutorials/data_pipeline.md), and [adding new modules](docs/tutorials/new_modules.md).
There are also tutorials for [finetuning models](docs/tutorials/finetune.md), [adding a new dataset](docs/tutorials/new_dataset.md), [designing a data pipeline](docs/tutorials/data_pipeline.md), [exporting a model to ONNX](docs/tutorials/export_model.md), and [adding new modules](docs/tutorials/new_modules.md).

A Colab tutorial is also provided. You may preview the notebook [here](demo/mmaction2_tutorial.ipynb) or directly [run](https://colab.research.google.com/github/open-mmlab/mmaction2/blob/master/demo/mmaction2_tutorial.ipynb) on Colab.

1 change: 1 addition & 0 deletions docs/changelog.md
@@ -6,6 +6,7 @@

**New Features**
- Support to run real-time action recognition from web camera ([#171](https://github.com/open-mmlab/mmaction2/pull/171))
- Support exporting PyTorch models to ONNX ([#160](https://github.com/open-mmlab/mmaction2/pull/160))

**ModelZoo**

41 changes: 41 additions & 0 deletions docs/tutorials/export_model.md
@@ -0,0 +1,41 @@
# Tutorial 5: Exporting a model to ONNX

Open Neural Network Exchange [(ONNX)](https://onnx.ai/) is an open ecosystem that empowers AI developers to choose the right tools as their project evolves. So far, our codebase supports exporting PyTorch models trained with MMAction2 to ONNX. The supported models are:

+ I3D
+ TSN
+ TIN
+ TSM
+ R(2+1)D
+ SlowFast
+ SlowOnly
+ BMN
+ BSN (TEM, PEM)

## Usage
For simple exporting, you can use the [script](../../tools/torch2onnx.py) provided here. Note that the `onnx` package is required for verification after exporting.

### Prerequisite
First, install `onnx`:
```shell
pip install onnx
```

### Recognizers
For recognizers, if your model is trained with a config from MMAction2 and you intend to run inference with its test pipeline, simply run:
```shell
python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH
```
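
For example, to export a TSN model (the config name is taken from the MMAction2 repository; the checkpoint path is illustrative and depends on where you saved it):
```shell
python tools/torch2onnx.py configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb.pth
```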

Otherwise, if you want to customize the input tensor shape, you can modify the `test_pipeline` in your config `$CONFIG_PATH`, or run:
```shell
python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --input-size $BATCHES $CROPS $CHANNELS $CLIP_LENGTH $HEIGHT $WIDTH
```
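
As a concrete example, a 2D recognizer taking a single batch of one crop, one 8-frame clip at 224x224 resolution could be exported with the following command (the numbers are illustrative; pick values matching your model):
```shell
python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --input-size 1 1 3 8 224 224
```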

### Localizer
For localizers, we *only* support customized input sizes, since our abstractions for localizers (e.g. SSN, BMN) are not unified. Please run:
```shell
python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --is-localizer --input-size $INPUT_SIZE
```
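
For instance, a BMN model consumes a `(batch, feature_dim, temporal_length)` tensor; assuming 400-dimensional ActivityNet-style features over 100 temporal positions (adjust both to your config), the call would look like:
```shell
python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --is-localizer --input-size 1 400 100
```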

Please file an issue if you discover any checkpoints that are not exported correctly or that lose accuracy after exporting.
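
After exporting, you can sanity-check the ONNX file with `onnxruntime`. Below is a minimal sketch, assuming `pip install onnxruntime`, an exported file named `tsn_r50.onnx`, and an input shape matching the one used at export time (all three are illustrative):
```python
import numpy as np
import onnxruntime as ort

# Load the exported model on CPU; 'input'/'output' are the tensor names
# assigned by tools/torch2onnx.py during export.
session = ort.InferenceSession('tsn_r50.onnx',
                               providers=['CPUExecutionProvider'])

# The shape must match the dummy input used at export time (illustrative here).
dummy = np.random.randn(1, 8, 3, 224, 224).astype(np.float32)
scores = session.run(['output'], {'input': dummy})[0]
print(scores.shape)  # e.g. (1, num_classes)
```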
1 change: 1 addition & 0 deletions requirements/optional.txt
@@ -1,3 +1,4 @@
av
decord
onnx
PyTurboJPEG
123 changes: 123 additions & 0 deletions tools/torch2onnx.py
@@ -0,0 +1,123 @@
import argparse
import os.path as osp
import sys
import warnings

import mmcv
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmaction.models import build_model

try:
import onnx
except ImportError:
warnings.warn('Please install onnx to support onnx exporting.')


class RecognizerWrapper(nn.Module):
"""Wrapper that only inferences the part in computation graph."""

def __init__(self, recognizer):
super().__init__()
self.recognizer = recognizer

def forward(self, x):
return self.recognizer.forward_dummy(x)


class LocalizerWrapper(nn.Module):
"""Wrapper that only inferences the part in computation graph."""

def __init__(self, localizer):
super().__init__()
self.localizer = localizer

def forward(self, x):
return self.localizer._forward(x)


def _get_cfg(config_path):
"""Grab configs necessary to create a model."""
if not osp.exists(config_path):
raise FileNotFoundError('Cannot find config path')
config = mmcv.Config.fromfile(config_path)
return config.model, config.data.test.pipeline, config.test_cfg


def torch2onnx(input_tensor, model, checkpoint_path):
    """Export the model to ONNX, then reload and verify the result."""
    exported_name = osp.basename(checkpoint_path).replace('.pth', '.onnx')
    input_names = ['input']
    output_names = ['output']
    torch.onnx.export(
        model,
        input_tensor,
        exported_name,
        verbose=False,
        # Use a recent ONNX opset for wider operator coverage
        opset_version=11,
        input_names=input_names,
        output_names=output_names)
    # Reload the exported model and check that it is well formed
    onnx_model = onnx.load(exported_name)
    onnx.checker.check_model(onnx_model)


def parse_args():
parser = argparse.ArgumentParser(description='Export a model to onnx')
parser.add_argument('config', help='Train config file path')
parser.add_argument('checkpoint', help='Checkpoint file path')
parser.add_argument(
'--is-localizer',
action='store_true',
default=False,
help='Determine whether the model is a localizer')
parser.add_argument(
'--input-size',
type=int,
nargs='+',
default=None,
help='Input dimension, mandatory for localizers')
args = parser.parse_args()
args.input_size = tuple(args.input_size) if args.input_size else None
return args


if __name__ == '__main__':
args = parse_args()
config_path = args.config
checkpoint_path = args.checkpoint

model_cfg, test_pipeline, test_cfg = _get_cfg(config_path)

model = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg).cuda()
if not args.is_localizer:
        try:
            # torch.randn(None) raises TypeError when --input-size is absent
            dummy_input = torch.randn(args.input_size).cuda()
        except TypeError:
            # Fall back to inferring the input shape from the test pipeline
            for trans in test_pipeline:
if trans['type'] == 'SampleFrames':
t = trans['clip_len']
n = trans['num_clips']
elif trans['type'] == 'Resize':
if isinstance(trans['scale'], int):
s = trans['scale']
elif isinstance(trans['scale'], tuple):
s = max(trans['scale'])
# #crop x (#batch * #clip) x #channel x clip_len x height x width
dummy_input = torch.randn(1, 1 * n, 3, t, s, s).cuda()
            # squeeze(3) removes the clip_len dimension only when it is 1,
            # which is the case for 2D recognizers
            dummy_input = dummy_input.squeeze(3)
wrapped_model = RecognizerWrapper(model)
else:
try:
# #batch x #channel x length
dummy_input = torch.randn(args.input_size).cuda()
        except TypeError as e:
            print(f'{e}\nPlease specify the input size for the localizer.')
            sys.exit(1)
wrapped_model = LocalizerWrapper(model)
    load_checkpoint(
        getattr(wrapped_model,
                'localizer' if args.is_localizer else 'recognizer'),
        checkpoint_path)
    torch2onnx(dummy_input, wrapped_model, checkpoint_path)