Skip to content

Commit

Permalink
update unittest (#649)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin authored Feb 25, 2021
1 parent a16552c commit f3a26aa
Showing 1 changed file with 101 additions and 113 deletions.
214 changes: 101 additions & 113 deletions tests/test_runtime/test_eval_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,120 +124,108 @@ def test_eval_hook():
assert runner.meta is None or 'best_ckpt' not in runner.meta[
'hook_msgs']

# when `save_best` is set to 'auto', first metric will be used.
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, interval=1, save_best='auto')

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best acc and corresponding epoch
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, interval=1, save_best='acc')

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best score and corresponding epoch
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(
data_loader, interval=1, save_best='score', rule='greater')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_score_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best score using less compare func
# and indicate corresponding epoch
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, save_best='acc', rule='less')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3

# Test the EvalHook when resume happend
data_loader = DataLoader(EvalDataset())
# when `save_best` is set to 'auto', first metric will be used.
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, interval=1, save_best='auto')

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best acc and corresponding epoch
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, interval=1, save_best='acc')

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best score and corresponding epoch
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(
data_loader, interval=1, save_best='score', rule='greater')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_score_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7

# total_epochs = 8, return the best score using less compare func
# and indicate corresponding epoch
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, save_best='acc', rule='less')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3

# Test the EvalHook when resume happend
data_loader = DataLoader(EvalDataset())
eval_hook = EvalHook(data_loader, save_best='acc')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 2)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_2.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 4

resume_from = osp.join(tmpdir, 'latest.pth')
loader = DataLoader(ExampleDataset())
eval_hook = EvalHook(data_loader, save_best='acc')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 2)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_2.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 4

resume_from = osp.join(tmpdir, 'latest.pth')
loader = DataLoader(ExampleDataset())
eval_hook = EvalHook(data_loader, save_best='acc')
runner = EpochBasedRunner(
model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.resume(resume_from)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.resume(resume_from)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(ckpt_path)
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == 7


@patch('mmaction.apis.single_gpu_test', MagicMock)
Expand Down

0 comments on commit f3a26aa

Please sign in to comment.