diff --git a/tools/test.py b/tools/test.py index b4ea2178e4..5166e55f32 100644 --- a/tools/test.py +++ b/tools/test.py @@ -94,6 +94,18 @@ def parse_args(): return args +def turn_off_pretrained(cfg): + # recursively find all pretrained in the model config, + # and set them None to avoid redundant pretrain steps for testing + if 'pretrained' in cfg: + cfg.pretrained = None + + # recursively turn off pretrained value + for sub_cfg in cfg.values(): + if isinstance(sub_cfg, dict): + turn_off_pretrained(sub_cfg) + + def main(): args = parse_args() @@ -174,6 +186,9 @@ def main(): **cfg.data.get('test_dataloader', {})) data_loader = build_dataloader(dataset, **dataloader_setting) + # remove redundant pretrain steps for testing + turn_off_pretrained(cfg.model) + # build the model and load checkpoint model = build_model( cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))