def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, a seed is drawn at random and, when more than
    one process is running, the value drawn on rank 0 is broadcast so that
    every process ends up using the same seed.

    Args:
        seed (int, optional): The seed. When it is not None it is returned
            unchanged. Defaults to None.
        device (str): The device on which the tensor used for the broadcast
            is created. If a CUDA device is requested but
            ``torch.cuda.is_available()`` is False, ``'cpu'`` is used
            instead. Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    # Compare against None explicitly so that a seed of 0 is honoured.
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)

    # A single process has nobody to synchronise with.
    if world_size == 1:
        return seed

    # Creating a tensor on 'cuda' raises when no CUDA device is usable, so
    # fall back to the CPU instead of crashing before the broadcast.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'

    # Only the value held by rank 0 survives the broadcast; the other ranks
    # just provide a placeholder buffer of the same dtype to receive it.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)
    return random_num.item()