Skip to content

Commit

Permalink
[Feature] Omnisource training (#242)
Browse files Browse the repository at this point in the history
* add  image dataset

* update image dataset

* add rawvideo dataset

* add ImageDecode and PseudoClipBuilder in __init__

* update __init__.py

* init omnisource runner

* update code

* fix bug

* override `end_of_epoch` hook

* add aliasing in OmniSource Runner

* fix log

* add two arguments to datasets

* add DistributedPowerSampler

* resolve comments

* update rawvideo_dataset

* update

* resolve comments

* update changelog

* update PseudoClipBuilder

* fix bug

* add fix

* a better fix

* + evaluate for rawframe dataset

* add load_json_annotations for rawvideo_dataset (since the key is video_dir)

* bug fix

* fix bug

* passing train_ratio

* resolve lint

* add  image dataset

* update image dataset

* add rawvideo dataset

* add ImageDecode and PseudoClipBuilder in __init__

* update __init__.py

* init omnisource runner

* update code

* fix bug

* override `end_of_epoch` hook

* add aliasing in OmniSource Runner

* fix log

* add two arguments to datasets

* add DistributedPowerSampler

* resolve comments

* update rawvideo_dataset

* update PseudoClipBuilder

* fix bug

* add fix

* a better fix

* + evaluate for rawframe dataset

* add load_json_annotations for rawvideo_dataset (since the key is video_dir)

* bug fix

* fix bug

* passing train_ratio

* resolve lint

* add missing register

* resolve comments

* update changelog

* resolve comments

* resolve comments

* resolve comments

* resolve conflicts
  • Loading branch information
kennymckormick authored Oct 30, 2020
1 parent 67151c1 commit bccc96a
Show file tree
Hide file tree
Showing 16 changed files with 628 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
**New Features**
- Support AVA dataset preparation ([#266](https://github.com/open-mmlab/mmaction2/pull/266))
- Support the training of video recognition dataset with multiple tag categories ([#235](https://github.com/open-mmlab/mmaction2/pull/235))
- Support joint training with multiple training datasets of multiple formats, including images, untrimmed videos, etc. ([#242](https://github.com/open-mmlab/mmaction2/pull/242))
- Support specifying a start epoch for conducting evaluation ([#216](https://github.com/open-mmlab/mmaction2/pull/216))
- Implement X3D models, support testing with model weights converted from SlowFast ([#288](https://github.com/open-mmlab/mmaction2/pull/288))

Expand Down
49 changes: 39 additions & 10 deletions mmaction/apis/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import copy as cp

import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmcv.runner.hooks import Fp16OptimizerHook

from ..core import DistEpochEvalHook, EpochEvalHook
from ..core import (DistEpochEvalHook, EpochEvalHook,
OmniSourceDistSamplerSeedHook, OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import get_root_logger

Expand Down Expand Up @@ -33,19 +36,37 @@ def train_model(model,

# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

dataloader_setting = dict(
videos_per_gpu=cfg.data.get('videos_per_gpu', 2),
workers_per_gpu=cfg.data.get('workers_per_gpu', 0),
# cfg.gpus will be ignored if distributed
videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed)
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('train_dataloader', {}))

data_loaders = [
build_dataloader(ds, **dataloader_setting) for ds in dataset
]
if cfg.omnisource:
# The option can override videos_per_gpu
train_ratio = cfg.data.get('train_ratio', None)
omni_videos_per_gpu = cfg.data.get('omni_videos_per_gpu', None)
if omni_videos_per_gpu is None:
dataloader_settings = [dataloader_setting] * len(dataset)
else:
dataloader_settings = []
for videos_per_gpu in omni_videos_per_gpu:
this_setting = cp.deepcopy(dataloader_setting)
this_setting['videos_per_gpu'] = videos_per_gpu
dataloader_settings.append(this_setting)
data_loaders = [
build_dataloader(ds, **setting)
for ds, setting in zip(dataset, dataloader_settings)
]

else:
data_loaders = [
build_dataloader(ds, **dataloader_setting) for ds in dataset
]

# put model on gpus
if distributed:
Expand All @@ -63,7 +84,9 @@ def train_model(model,

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = EpochBasedRunner(

Runner = OmniSourceRunner if cfg.omnisource else EpochBasedRunner
runner = Runner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
Expand All @@ -87,7 +110,10 @@ def train_model(model,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
if distributed:
runner.register_hook(DistSamplerSeedHook())
if cfg.omnisource:
runner.register_hook(OmniSourceDistSamplerSeedHook())
else:
runner.register_hook(DistSamplerSeedHook())

if validate:
eval_cfg = cfg.get('evaluation', {})
Expand All @@ -109,4 +135,7 @@ def train_model(model,
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
runner_kwargs = dict()
if cfg.omnisource:
runner_kwargs = dict(train_ratio=train_ratio)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs, **runner_kwargs)
1 change: 1 addition & 0 deletions mmaction/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .evaluation import * # noqa: F401, F403
from .lr import * # noqa: F401, F403
from .optimizer import * # noqa: F401, F403
from .runner import * # noqa: F401, F403
3 changes: 3 additions & 0 deletions mmaction/core/runner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .omnisource_runner import OmniSourceDistSamplerSeedHook, OmniSourceRunner

__all__ = ['OmniSourceRunner', 'OmniSourceDistSamplerSeedHook']
162 changes: 162 additions & 0 deletions mmaction/core/runner/omnisource_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) Open-MMLab. All rights reserved.
import time
import warnings

import mmcv
from mmcv.runner import EpochBasedRunner, Hook
from mmcv.runner.utils import get_host_info


def cycle(iterable):
    """Yield items from ``iterable`` endlessly, restarting it when exhausted.

    Unlike ``itertools.cycle``, no items are cached: every pass calls
    ``iter(iterable)`` again, so an iterable that produces a different order
    on each pass is re-evaluated rather than replayed.

    Args:
        iterable: A re-iterable object (one for which ``iter()`` returns a
            fresh iterator on every call).

    Yields:
        The items of ``iterable``, pass after pass.

    Raises:
        ValueError: If a full pass over ``iterable`` produces no items, i.e.
            it is empty or is a one-shot iterator that is already exhausted.
            Without this check the generator would spin forever without
            ever yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError(
                'cycle() cannot restart: iterable produced no items')


class OmniSourceDistSamplerSeedHook(Hook):
    """Hook that forwards the current epoch to every dataloader's sampler.

    For each loader in ``runner.data_loaders`` it calls ``set_epoch`` on the
    loader's sampler, or failing that on the sampler wrapped by the loader's
    batch sampler. Loaders exposing ``set_epoch`` on neither are left alone.
    """

    def before_epoch(self, runner):
        """Call ``set_epoch(runner.epoch)`` on each loader's sampler."""
        for loader in runner.data_loaders:
            target = loader.sampler
            if not hasattr(target, 'set_epoch'):
                # The batch sampler in PyTorch keeps the underlying sampler
                # as an attribute; fall back to it.
                target = loader.batch_sampler.sampler
                if not hasattr(target, 'set_epoch'):
                    # Neither sampler is epoch-aware; nothing to do.
                    continue
            target.set_epoch(runner.epoch)


class OmniSourceRunner(EpochBasedRunner):
    """OmniSource Epoch-based Runner.

    This runner trains models epoch by epoch. The epoch length is defined by
    ``data_loaders[0]``, which is the main dataloader; the remaining
    dataloaders are auxiliary and are cycled through endlessly while the main
    dataloader is being consumed.
    """

    def run_iter(self, data_batch, train_mode, source, **kwargs):
        """Run a single forward step and record its log variables.

        Args:
            data_batch: One batch taken from a dataloader.
            train_mode (bool): If True call ``model.train_step``, otherwise
                ``model.val_step`` (unless a ``batch_processor`` is set, in
                which case that is called instead).
            source (str): Suffix appended to every key in
                ``outputs['log_vars']`` so that values coming from different
                dataloaders can be told apart in the log buffer.
            **kwargs: Forwarded to the batch processor / model step.

        Raises:
            TypeError: If the step does not return a dict.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        # Since we have multiple sources, we add a suffix to log_var names,
        # so that we can differentiate them.
        if 'log_vars' in outputs:
            log_vars = outputs['log_vars']
            log_vars = {k + source: v for k, v in log_vars.items()}
            self.log_buffer.update(log_vars, outputs['num_samples'])

        self.outputs = outputs

    def train(self, data_loaders, **kwargs):
        """Train for one epoch over the main loader, mixing in aux loaders.

        Args:
            data_loaders (list): ``data_loaders[0]`` is the main loader and
                determines the epoch length; ``data_loaders[1:]`` are
                auxiliary loaders.
            **kwargs: May contain ``train_ratio`` (sequence | None).
                ``train_ratio[0]`` is the interval, in main iterations, at
                which auxiliary batches are used; ``train_ratio[1:]`` gives,
                per auxiliary loader, how many batches to run each time.
                When missing or ``None``, auxiliary loaders are used after
                every main iteration, one batch each.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loaders = data_loaders
        self.main_loader = self.data_loaders[0]
        # Add aliasing (presumably so code that reads `runner.data_loader`
        # keeps working with this multi-loader runner -- TODO confirm).
        self.data_loader = self.main_loader
        self.aux_loaders = self.data_loaders[1:]
        self.aux_iters = [cycle(loader) for loader in self.aux_loaders]

        auxiliary_iter_times = [1] * len(self.aux_loaders)
        use_aux_per_niter = 1
        # The caller may pass `train_ratio=None` when the config does not
        # define it; treat that the same as an absent key instead of
        # crashing on `None[0]`.
        train_ratio = kwargs.pop('train_ratio', None)
        if train_ratio is not None:
            use_aux_per_niter = train_ratio[0]
            auxiliary_iter_times = train_ratio[1:]

        self._max_iters = self._max_epochs * len(self.main_loader)

        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, data_batch in enumerate(self.main_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, source='')
            self.call_hook('after_train_iter')

            # Auxiliary batches are only run on iterations that are a
            # multiple of `use_aux_per_niter`.
            if self._iter % use_aux_per_niter != 0:
                self._iter += 1
                continue

            for idx, n_times in enumerate(auxiliary_iter_times):
                for _ in range(n_times):
                    data_batch = next(self.aux_iters[idx])
                    self.call_hook('before_train_iter')
                    self.run_iter(
                        data_batch, train_mode=True, source=f'/aux{idx}')
                    self.call_hook('after_train_iter')
            # `_iter` advances once per main batch; auxiliary steps do not
            # increment it.
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    # Validation is handled by an evaluation hook, so a separate `val` epoch
    # is intentionally not implemented for this runner.
    def val(self, data_loader, **kwargs):
        """Not supported by this runner; always raises NotImplementedError."""
        raise NotImplementedError

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training.
                `data_loaders[0]` is the main data_loader, which contains
                target datasets and determines the epoch length.
                `data_loaders[1:]` are auxiliary data loaders, which contain
                auxiliary web datasets.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2)] means running 2
                epochs for training iteratively. Note that val epoch is not
                supported for this runner for simplicity, so the workflow
                must consist of exactly one 'train' phase.
            max_epochs (int | None): The max epochs that training lasts,
                deprecated now. Default: None.
            **kwargs: Forwarded to the epoch method on every epoch (e.g.
                ``train_ratio`` for :meth:`train`).
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(workflow) == 1 and workflow[0][0] == 'train'
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        mode, epochs = workflow[0]
        self._max_iters = self._max_epochs * len(data_loaders[0])

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    f'mode in workflow must be a str, but got {mode}')

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders, **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
5 changes: 4 additions & 1 deletion mmaction/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .hvu_dataset import HVUDataset
from .image_dataset import ImageDataset
from .rawframe_dataset import RawframeDataset
from .rawvideo_dataset import RawVideoDataset
from .ssn_dataset import SSNDataset
from .video_dataset import VideoDataset

__all__ = [
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
'HVUDataset', 'AudioDataset', 'AudioFeatureDataset', 'AVADataset'
'HVUDataset', 'AudioDataset', 'AudioFeatureDataset', 'ImageDataset',
'RawVideoDataset', 'AVADataset'
]
41 changes: 38 additions & 3 deletions mmaction/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from collections import defaultdict

import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset

Expand Down Expand Up @@ -37,6 +39,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
from 0. Default: 1.
modality (str): Modality of data. Support 'RGB', 'Flow', 'Audio'.
Default: 'RGB'.
sample_by_class (bool): Sampling by class, should be set `True` when
performing inter-class data balancing. Only compatible with
`multi_class == False`. Only applies for training. Default: False.
power (float | None): We support sampling data with the probability
proportional to the power of its label frequency (freq ^ power)
when sampling data. `power == 1` indicates uniformly sampling all
data; `power == 0` indicates uniformly sampling all classes.
Default: None.
"""

def __init__(self,
Expand All @@ -47,7 +57,9 @@ def __init__(self,
multi_class=False,
num_classes=None,
start_index=1,
modality='RGB'):
modality='RGB',
sample_by_class=False,
power=None):
super().__init__()

self.ann_file = ann_file
Expand All @@ -58,8 +70,14 @@ def __init__(self,
self.num_classes = num_classes
self.start_index = start_index
self.modality = modality
self.sample_by_class = sample_by_class
self.power = power
assert not (self.multi_class and self.sample_by_class)

self.pipeline = Compose(pipeline)
self.video_infos = self.load_annotations()
if self.sample_by_class:
self.video_infos_by_class = self.parse_by_class()

@abstractmethod
def load_annotations(self):
Expand All @@ -85,6 +103,13 @@ def load_json_annotations(self):
video_infos[i]['label'] = video_infos[i]['label'][0]
return video_infos

def parse_by_class(self):
    """Group ``self.video_infos`` by each item's ``'label'`` value.

    Returns:
        defaultdict(list): Maps a label to the list of video info items
            carrying that label, in their original order.
    """
    grouped = defaultdict(list)
    for info in self.video_infos:
        grouped[info['label']].append(info)
    return grouped

@abstractmethod
def evaluate(self, results, metrics, logger):
"""Evaluation for the dataset.
Expand All @@ -105,7 +130,12 @@ def dump_results(self, results, out):

def prepare_train_frames(self, idx):
"""Prepare the frames for training given the index."""
results = copy.deepcopy(self.video_infos[idx])
if self.sample_by_class:
# Then, the idx is the class index
samples = self.video_infos_by_class[idx]
results = copy.deepcopy(np.random.choice(samples))
else:
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index

Expand All @@ -120,7 +150,12 @@ def prepare_train_frames(self, idx):

def prepare_test_frames(self, idx):
"""Prepare the frames for testing given the index."""
results = copy.deepcopy(self.video_infos[idx])
if self.sample_by_class:
# Then, the idx is the class index
samples = self.video_infos_by_class[idx]
results = copy.deepcopy(np.random.choice(samples))
else:
results = copy.deepcopy(self.video_infos[idx])
results['modality'] = self.modality
results['start_index'] = self.start_index

Expand Down
Loading

0 comments on commit bccc96a

Please sign in to comment.