diff --git a/mmaction/apis/train.py b/mmaction/apis/train.py index f04ed6c712..7a3cd1351b 100644 --- a/mmaction/apis/train.py +++ b/mmaction/apis/train.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy as cp +import os import os.path as osp import numpy as np @@ -205,21 +206,22 @@ def train_model(model, if test['test_last'] or test['test_best']: best_ckpt_path = None if test['test_best']: - if hasattr(eval_hook, 'best_ckpt_path'): - best_ckpt_path = eval_hook.best_ckpt_path - - if best_ckpt_path is None or not osp.exists(best_ckpt_path): + ckpt_paths = [x for x in os.listdir(cfg.work_dir) if 'best' in x] + ckpt_paths = [x for x in ckpt_paths if x.endswith('.pth')] + if len(ckpt_paths) == 0: + runner.logger.info('Warning: test_best set, but no ckpt found') test['test_best'] = False - if best_ckpt_path is None: - runner.logger.info('Warning: test_best set as True, but ' - 'is not applicable ' - '(eval_hook.best_ckpt_path is None)') - else: - runner.logger.info('Warning: test_best set as True, but ' - 'is not applicable (best_ckpt ' - f'{best_ckpt_path} not found)') if not test['test_last']: return + elif len(ckpt_paths) > 1: + epoch_ids = [ + int(x.split('epoch_')[-1][:-4]) for x in ckpt_paths + ] + best_ckpt_path = ckpt_paths[np.argmax(epoch_ids)] + else: + best_ckpt_path = ckpt_paths[0] + if best_ckpt_path: + best_ckpt_path = osp.join(cfg.work_dir, best_ckpt_path) test_dataset = build_dataset(cfg.data.test, dict(test_mode=True)) gpu_collect = cfg.get('evaluation', {}).get('gpu_collect', False) @@ -242,7 +244,7 @@ def train_model(model, if test['test_last']: names.append('last') ckpts.append(None) - if test['test_best']: + if test['test_best'] and best_ckpt_path is not None: names.append('best') ckpts.append(best_ckpt_path) diff --git a/mmaction/core/evaluation/ava_utils.py b/mmaction/core/evaluation/ava_utils.py index 7f6571d478..e7aa10b2f6 100644 --- a/mmaction/core/evaluation/ava_utils.py +++ b/mmaction/core/evaluation/ava_utils.py @@ -1,6 +1,6 @@ # This piece of code is directly adapted from ActivityNet official repo # https://github.com/activitynet/ActivityNet/blob/master/ -# Evaluation/get_ava_performance.py. Some unused codes are removed. +# Evaluation/get_ava_performance.py. Some unused codes are removed. import csv import logging import time