From 5db8b7637e747af51699700450f304b931794e09 Mon Sep 17 00:00:00 2001 From: xusu Date: Fri, 4 Sep 2020 16:31:29 +0800 Subject: [PATCH 1/8] First commit. --- setup.cfg | 2 +- tools/torch2onnx.py | 57 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tools/torch2onnx.py diff --git a/setup.cfg b/setup.cfg index e720d52107..67cac41857 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,6 +19,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmaction -known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,torch +known_third_party = cv2,joblib,matplotlib,mmcv,numpy,onnx,pandas,pytest,scipy,seaborn,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py new file mode 100644 index 0000000000..03d2bd138d --- /dev/null +++ b/tools/torch2onnx.py @@ -0,0 +1,57 @@ +import os.path as osp +import sys + +import mmcv +import onnx +import torch + +from mmaction.models import build_model + + +def _get_recognizer_cfg(config_path): + """Grab configs necessary to create a recognizer.""" + if not osp.exists(config_path): + raise Exception('Cannot find config path') + config = mmcv.Config.fromfile(config_path) + return config.model, config.data.test.pipeline, config.test_cfg + + +def torch2onnx(input, model): + input_names = ['input'] + output_names = ['output'] + torch.onnx.export( + model, + input, + 'exported_model.onnx', + verbose=True, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + 'input': [0], + 'output': [0] + }) + model = onnx.load('exported_model.onnx') + onnx.checker.check_model(model) + print(onnx.helper.printable_graph(model.graph)) + + +if __name__ == '__main__': + try: + config_path = sys.argv[1] + except BaseException: + print('Please indicate the config file path.') + model_cfg, test_pipeline, test_cfg = _get_recognizer_cfg(config_path) + t = None + s = None + for trans in test_pipeline: + if trans['type'] == 'SampleFrames': + t = trans['clip_len'] + elif trans['type'] == 'Resize': + if isinstance(trans['scale'], int): + s = trans['scale'] + elif isinstance(trans['scale'], tuple): + s = max(trans['scale']) + + dummy_input = torch.randn(8, 3, t, s, s).cuda() + model = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg).cuda() + torch2onnx(dummy_input, model) From be65b1de17250a9e0841a0a6d564dc1f2e72f9f5 Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 8 Sep 2020 14:01:07 +0800 Subject: [PATCH 2/8] Revise and add nart-tool to caffemodel. --- tools/torch2onnx.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index 03d2bd138d..dff45a94d4 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -4,6 +4,7 @@ import mmcv import onnx import torch +from mmcv.runner import load_checkpoint from mmaction.models import build_model @@ -11,7 +12,7 @@ def _get_recognizer_cfg(config_path): """Grab configs necessary to create a recognizer.""" if not osp.exists(config_path): - raise Exception('Cannot find config path') + raise FileNotFoundError('Cannot find config path') config = mmcv.Config.fromfile(config_path) return config.model, config.data.test.pipeline, config.test_cfg @@ -30,16 +31,31 @@ def torch2onnx(input, model): 'input': [0], 'output': [0] }) - model = onnx.load('exported_model.onnx') + model = onnx.load('exported_onnx_model.onnx') onnx.checker.check_model(model) print(onnx.helper.printable_graph(model.graph)) +def torch2caffe(input, model): + try: + import spring.nart.tools.pytorch as pytorch + except ImportError as e: + print(f'Cannot import nart tool: {e}') + return + with pytorch.convert_mode(): + pytorch.convert( + model, [input], + 'exported_caffe_model', + input_names=['input'], + output_names=['output']) + + if __name__ == '__main__': try: config_path = sys.argv[1] - except BaseException: - print('Please indicate the config file path.') + checkpoint_path = sys.argv[2] + except BaseException as e: + print(f'{e}:\nPlease indicate the config file and checkpoint path.') model_cfg, test_pipeline, test_cfg = _get_recognizer_cfg(config_path) t = None s = None @@ -52,6 +68,8 @@ def torch2onnx(input, model): elif isinstance(trans['scale'], tuple): s = max(trans['scale']) - dummy_input = torch.randn(8, 3, t, s, s).cuda() + dummy_input = torch.randn(1, 3, t, s, s).cuda() model = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg).cuda() + load_checkpoint(model, checkpoint_path) torch2onnx(dummy_input, model) + torch2caffe(dummy_input, model) From f60c0454c38b96ba70df05e0902c6d065d6883fe Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 15 Sep 2020 18:11:18 +0800 Subject: [PATCH 3/8] Add localizer onnx support and better args. --- tools/torch2onnx.py | 134 ++++++++++++++++++++++++++++++-------------- 1 file changed, 91 insertions(+), 43 deletions(-) diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index dff45a94d4..a895d5f8e7 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -1,16 +1,42 @@ +import argparse import os.path as osp import sys import mmcv import onnx import torch +import torch.nn as nn from mmcv.runner import load_checkpoint from mmaction.models import build_model +sys.path.append('../') -def _get_recognizer_cfg(config_path): - """Grab configs necessary to create a recognizer.""" + +class RecognizerWarpper(nn.Module): + """Warpper that only inferences the part in computation graph.""" + + def __init__(self, recognizer): + super().__init__() + self.recognizer = recognizer + + def forward(self, x): + return self.recognizer.forward_dummy(x) + + +class LocalizerWarpper(nn.Module): + """Warpper that only inferences the part in computation graph.""" + + def __init__(self, localizer): + super().__init__() + self.localizer = localizer + + def forward(self, x): + return self.localizer._forward(x) + + +def _get_cfg(config_path): + """Grab configs necessary to create a model.""" if not osp.exists(config_path): raise FileNotFoundError('Cannot find config path') config = mmcv.Config.fromfile(config_path) @@ -18,58 +44,80 @@ def _get_recognizer_cfg(config_path): def torch2onnx(input, model): + exported_name = osp.basename(args.checkpoint).replace('.pth', '.onnx') input_names = ['input'] output_names = ['output'] torch.onnx.export( model, input, - 'exported_model.onnx', - verbose=True, + exported_name, + verbose=False, + # Using a higher version of onnx opset + opset_version=11, input_names=input_names, output_names=output_names, - dynamic_axes={ - 'input': [0], - 'output': [0] - }) - model = onnx.load('exported_onnx_model.onnx') + ) + model = onnx.load(exported_name) onnx.checker.check_model(model) - print(onnx.helper.printable_graph(model.graph)) -def torch2caffe(input, model): - try: - import spring.nart.tools.pytorch as pytorch - except ImportError as e: - print(f'Cannot import nart tool: {e}') - return - with pytorch.convert_mode(): - pytorch.convert( - model, [input], - 'exported_caffe_model', - input_names=['input'], - output_names=['output']) +def parse_args(): + parser = argparse.ArgumentParser(description='Export a model to onnx') + parser.add_argument('config', help='Train config file path') + parser.add_argument('checkpoint', help='Checkpoint file path') + parser.add_argument( + '--is-localizer', + action='store_true', + default=False, + help='Determine whether the model is a localizer') + parser.add_argument( + '--input-size', + type=int, + nargs='+', + default=None, + help='Input dimension, mandatory for localizers') + args = parser.parse_args() + args.input_size = tuple(args.input_size) if args.input_size else None + return args if __name__ == '__main__': - try: - config_path = sys.argv[1] - checkpoint_path = sys.argv[2] - except BaseException as e: - print(f'{e}:\nPlease indicate the config file and checkpoint path.') - model_cfg, test_pipeline, test_cfg = _get_recognizer_cfg(config_path) - t = None - s = None - for trans in test_pipeline: - if trans['type'] == 'SampleFrames': - t = trans['clip_len'] - elif trans['type'] == 'Resize': - if isinstance(trans['scale'], int): - s = trans['scale'] - elif isinstance(trans['scale'], tuple): - s = max(trans['scale']) - - dummy_input = torch.randn(1, 3, t, s, s).cuda() + args = parse_args() + config_path = args.config + checkpoint_path = args.checkpoint + + model_cfg, test_pipeline, test_cfg = _get_cfg(config_path) + + # hyperparams for recognizers model = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg).cuda() - load_checkpoint(model, checkpoint_path) - torch2onnx(dummy_input, model) - torch2caffe(dummy_input, model) + if not args.is_localizer: + try: + dummy_input = torch.randn(args.input_size).cuda() + except TypeError: + for trans in test_pipeline: + if trans['type'] == 'SampleFrames': + t = trans['clip_len'] + n = trans['num_clips'] + elif trans['type'] == 'Resize': + if isinstance(trans['scale'], int): + s = trans['scale'] + elif isinstance(trans['scale'], tuple): + s = max(trans['scale']) + # #crop x (#batch * #clips) x #channel x clip_len x height x width + dummy_input = torch.randn(1, 1 * n, 3, t, s, s).cuda() + # squeeze the t-dimension for 2d model + dummy_input = dummy_input.squeeze(3) + warpped_model = RecognizerWarpper(model) + else: + try: + # #batch x #channel x length + dummy_input = torch.randn(args.input_size).cuda() + except TypeError as e: + print(f'{e}\nplease specify the input size for localizer.') + exit() + warpped_model = LocalizerWarpper(model) + load_checkpoint( + getattr(warpped_model, + 'recognizer' if not args.is_localizer else 'localizer'), + checkpoint_path) + torch2onnx(dummy_input, warpped_model) From 7650d646cf4291a90607691c738ccb34da68e514 Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 15 Sep 2020 19:03:48 +0800 Subject: [PATCH 4/8] Minor fix typos. --- tools/torch2onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index a895d5f8e7..5d79e6823e 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -88,7 +88,6 @@ def parse_args(): model_cfg, test_pipeline, test_cfg = _get_cfg(config_path) - # hyperparams for recognizers model = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg).cuda() if not args.is_localizer: try: @@ -103,7 +102,7 @@ def parse_args(): s = trans['scale'] elif isinstance(trans['scale'], tuple): s = max(trans['scale']) - # #crop x (#batch * #clips) x #channel x clip_len x height x width + # #crop x (#batch * #clip) x #channel x clip_len x height x width dummy_input = torch.randn(1, 1 * n, 3, t, s, s).cuda() # squeeze the t-dimension for 2d model dummy_input = dummy_input.squeeze(3) From ec54bc049a236ef3bc89c931e2d38afee2c2c923 Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 15 Sep 2020 20:07:36 +0800 Subject: [PATCH 5/8] Warp import onnx in try catch. Minor fix typo. --- requirements/optional.txt | 1 + setup.cfg | 2 +- tools/torch2onnx.py | 23 ++++++++++++++--------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/requirements/optional.txt b/requirements/optional.txt index 9db216aa39..09fa110836 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,4 @@ av decord +onnx PyTurboJPEG diff --git a/setup.cfg b/setup.cfg index 67cac41857..e720d52107 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,6 +19,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmaction -known_third_party = cv2,joblib,matplotlib,mmcv,numpy,onnx,pandas,pytest,scipy,seaborn,torch +known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index 5d79e6823e..4684e5029a 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -1,20 +1,25 @@ import argparse import os.path as osp import sys +import warnings import mmcv -import onnx import torch import torch.nn as nn from mmcv.runner import load_checkpoint from mmaction.models import build_model +try: + import onnx +except ImportError: + warnings.warn('Please install onnx to support onnx exporting.') + sys.path.append('../') -class RecognizerWarpper(nn.Module): - """Warpper that only inferences the part in computation graph.""" +class RecognizerWrapper(nn.Module): + """Wrapper that only inferences the part in computation graph.""" def __init__(self, recognizer): super().__init__() @@ -24,8 +29,8 @@ def forward(self, x): return self.recognizer.forward_dummy(x) -class LocalizerWarpper(nn.Module): - """Warpper that only inferences the part in computation graph.""" +class LocalizerWrapper(nn.Module): + """Wrapper that only inferences the part in computation graph.""" def __init__(self, localizer): super().__init__() @@ -106,7 +111,7 @@ def parse_args(): dummy_input = torch.randn(1, 1 * n, 3, t, s, s).cuda() # squeeze the t-dimension for 2d model dummy_input = dummy_input.squeeze(3) - warpped_model = RecognizerWarpper(model) + wrapped_model = RecognizerWrapper(model) else: try: # #batch x #channel x length @@ -114,9 +119,9 @@ def parse_args(): except TypeError as e: print(f'{e}\nplease specify the input size for localizer.') exit() - warpped_model = LocalizerWarpper(model) + wrapped_model = LocalizerWrapper(model) load_checkpoint( - getattr(warpped_model, + getattr(wrapped_model, 'recognizer' if not args.is_localizer else 'localizer'), checkpoint_path) - torch2onnx(dummy_input, warpped_model) + torch2onnx(dummy_input, wrapped_model) From cfaacc05b47a790fa04c2b8247a0a4bc95449823 Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 15 Sep 2020 21:10:00 +0800 Subject: [PATCH 6/8] Fix comma. --- tools/torch2onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index 4684e5029a..e394743ca3 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -60,8 +60,7 @@ def torch2onnx(input, model): # Using a higher version of onnx opset opset_version=11, input_names=input_names, - output_names=output_names, - ) + output_names=output_names) model = onnx.load(exported_name) onnx.checker.check_model(model) From 3b7ebac103c062d9f62fc07f4eb2e09b4d8e27be Mon Sep 17 00:00:00 2001 From: xusu Date: Tue, 15 Sep 2020 23:51:47 +0800 Subject: [PATCH 7/8] Remove sys.path apending. --- tools/torch2onnx.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index e394743ca3..a847dc4bd8 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -1,6 +1,5 @@ import argparse import os.path as osp -import sys import warnings import mmcv @@ -15,8 +14,6 @@ except ImportError: warnings.warn('Please install onnx to support onnx exporting.') -sys.path.append('../') - class RecognizerWrapper(nn.Module): """Wrapper that only inferences the part in computation graph.""" From e83513f103d739ac3a5ddde79281091106e736e7 Mon Sep 17 00:00:00 2001 From: xusu Date: Wed, 16 Sep 2020 11:42:13 +0800 Subject: [PATCH 8/8] Add doc and changelog. Minor. Minor. --- README.md | 2 +- docs/changelog.md | 1 + docs/tutorials/export_model.md | 41 ++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 docs/tutorials/export_model.md diff --git a/README.md b/README.md index 9195673654..ebba9bdccb 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Please refer to [data_preparation.md](docs/data_preparation.md) for a general kn ## Get Started Please see [getting_started.md](docs/getting_started.md) for the basic usage of MMAction2. -There are also tutorials for [finetuning models](docs/tutorials/finetune.md), [adding new dataset](docs/tutorials/new_dataset.md), [designing data pipeline](docs/tutorials/data_pipeline.md), and [adding new modules](docs/tutorials/new_modules.md). +There are also tutorials for [finetuning models](docs/tutorials/finetune.md), [adding new dataset](docs/tutorials/new_dataset.md), [designing data pipeline](docs/tutorials/data_pipeline.md), [exporting model to onnx](docs/tutorials/export_model.md) and [adding new modules](docs/tutorials/new_modules.md). A Colab tutorial is also provided. You may preview the notebook [here](demo/mmaction2_tutorial.ipynb) or directly [run](https://colab.research.google.com/github/open-mmlab/mmaction2/blob/master/demo/mmaction2_tutorial.ipynb) on Colab. diff --git a/docs/changelog.md b/docs/changelog.md index f52276289b..af7570d1e2 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ **New Features** - Support to run real-time action recognition from web camera ([#171](https://github.com/open-mmlab/mmaction2/pull/171)) +- Support to export the pytorch models to onnx ones. ([#160](https://github.com/open-mmlab/mmaction2/pull/160)) **ModelZoo** diff --git a/docs/tutorials/export_model.md b/docs/tutorials/export_model.md new file mode 100644 index 0000000000..a0ec1ad202 --- /dev/null +++ b/docs/tutorials/export_model.md @@ -0,0 +1,41 @@ +# Tutorial 5: Exporting a model to ONNX + +Open Neural Network Exchange [(ONNX)](https://onnx.ai/) is an open ecosystem that empowers AI developers to choose the right tools as their project evolves. So far, our codebase supports onnx exporting from pytorch models trained with mmaction2. The supported models are: + ++ I3D ++ TSN ++ TIN ++ TSM ++ R(2+1)D ++ SLOWFAST ++ SLOWONLY ++ BMN ++ BSN(tem, pem) + +## Usage +For simple exporting, you can use the [script](../../tools/torch2onnx.py) here. Note that the package `onnx` is requried for verification after exporting. + +### Prerequisite +First, install onnx. +```shell +pip install onnx +``` + +### Recognizers +For recognizers, if your model are trained with a config from mmaction2 and intend to inference it according to the test pipeline, simply run: +```shell +python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH +``` + +Otherwise, if you want to customize the input tensor shape, you can modify the `test_pipeline` in your config `$CONFIG_PATH`, or run: +```shell +python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --input-size $BATCHS $CROPS $CHANNELS $CLIP_LENGTH $HEIGHT $WIDTH +``` + +### Localizer +For localizers, we *only* support customized input size, since our abstractions for localizers(eg. SSN, BMN) are not unified. Please run: +```shell +python tools/torch2onnx.py $CONFIG_PATH $CHECKPOINT_PATH --is-localizer --input-size $INPUT_SIZE +``` + +Please fire an issue if you discover any checkpoints that are not perfectly exported or suffer some loss in accuracy.