This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-171] Fix a bug that was causing training accuracy to be printed as nan sometimes #10437

Merged 5 commits on Apr 9, 2018
4 changes: 3 additions & 1 deletion python/mxnet/module/base_module.py
@@ -22,6 +22,7 @@
 import time
 import logging
 import warnings
+from copy import deepcopy

 from .. import metric
 from .. import ndarray
@@ -523,8 +524,9 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                 monitor.toc_print()

                 if batch_end_callback is not None:
+                    arg_eval_metric = eval_metric if not end_of_batch else deepcopy(eval_metric)
Contributor:

This is too hacky and potentially very slow.

Contributor Author @indhub, Apr 6, 2018:
The deep copy happens at most once per epoch, which adds negligible time to an epoch.

Hacky - yes. That's because we are trying to print, outside the loop, metrics that could get cleared inside the loop. Any solution that doesn't modify the existing behavior will be a little hacky.

Contributor Author:

If you can think of a less hacky solution, please propose one. If not, I think this is better than printing accuracy as nan in the training log, which creates a bad user experience. A number of users have reported this.

                     batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
-                                                     eval_metric=eval_metric,
+                                                     eval_metric=arg_eval_metric,
                                                      locals=locals())
                     for callback in _as_list(batch_end_callback):
                         callback(batch_end_params)
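The failure mode being fixed - a callback holding a reference to a metric that gets reset before the callback reads it, so `get()` returns nan - can be sketched with a toy metric. This is a hypothetical stand-in class for illustration, not MXNet's actual `EvalMetric`:

```python
import math
from copy import deepcopy

class ToyAccuracy:
    """Hypothetical stand-in for an MXNet EvalMetric."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, correct, total):
        self.sum += correct
        self.count += total

    def get(self):
        # Reading a freshly reset metric divides 0 by 0 -> nan,
        # which is the value users saw in their training logs.
        return self.sum / self.count if self.count else float('nan')

    def reset(self):
        self.sum, self.count = 0.0, 0

metric = ToyAccuracy()
metric.update(8, 10)

# Without the fix: the callback holds a reference, the metric is
# reset at end of epoch, and the callback then prints nan.
snapshot = metric
metric.reset()
assert math.isnan(snapshot.get())

# With the fix: at end_of_batch the metric is deep-copied, so the
# callback still sees the pre-reset values after reset() runs.
metric.update(8, 10)
frozen = deepcopy(metric)
metric.reset()
assert frozen.get() == 0.8
```

This mirrors why the patch only copies when `end_of_batch` is true: that is the one point where a reset can race with the callback's read, so the copy cost is paid at most once per epoch.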