[MXNET-171] Fix a bug that was causing training accuracy to be printed as nan sometimes #10437
Conversation
Output of finetune notebook before fix:
Output of finetune notebook after fix:
python/mxnet/module/base_module.py (outdated diff)
@@ -523,8 +524,9 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                monitor.toc_print()

            if batch_end_callback is not None:
+               arg_eval_metric = eval_metric if not end_of_batch else deepcopy(eval_metric)
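For context, here is a small self-contained illustration (a hedged sketch, not code from the PR) of what the added line relies on: deep-copying the metric before a callback that may reset it preserves the accumulated values for the epoch-end log.

```python
import mxnet as mx
from copy import deepcopy

metric = mx.metric.Accuracy()
# Stand-in for the per-batch metric.update(...) calls made during fit().
metric.update([mx.nd.array([0, 1, 1])],
              [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])])

snapshot = deepcopy(metric)   # copy taken before the callback can touch the metric
metric.reset()                # e.g. Speedometer(auto_reset=True) resets after reporting
print(snapshot.get())         # ('accuracy', 1.0) -- accumulated value preserved
print(metric.get())           # ('accuracy', nan) -- what the epoch-end log used to see
```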
This is too hacky and potentially very slow
Deep copy happens at most once per epoch, which adds negligible time to an epoch.
Hacky - yes. That's because we are trying to print, outside the loop, metrics that could get cleared inside the loop. Any solution that doesn't modify the existing behavior will be a little hacky.
If you can think of a less hacky solution, please propose it. If not, I think this is better than printing the accuracy as nan in the training log, which creates a bad user experience. A number of users have reported this.
Deep copy is slow and can cause a lot of problems. We can't do this kind of hack in the main package.
…d as nan sometimes (apache#10437)
* Fix a bug that was causing training accuracy to be printed as nan sometimes.
* Avoid the additional 'arg_eval_metric' variable. There should be no overhead except for the last batch in an epoch.
* Fix lint.
* For the last batch, capture metrics before the callback and use them to print epoch metrics.
* Remove unused import.
Description
Fix a bug (issue #4253) that was causing training accuracy to be printed as nan sometimes
The "Epoch[x] Train-accuracy=xxx" line printed at the end of every epoch gives the impression that the accuracy metric is for the entire epoch. In reality, it is NOT.
For example, if an epoch consists of 101 batches and we were printing metrics every 10 batches because the callback was created that way, what gets printed as 'Train-accuracy=xxx' at the end of the epoch is actually just the accuracy from a single batch (the 101st). Printing this as 'Epoch[x] Train-accuracy=xxx' is very misleading.
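As a hedged, self-contained illustration of that scenario (the toy network and random data below are made up purely so the snippet runs; they are not from the PR), a callback configured like this triggers it, because Speedometer defaults to auto_reset=True and clears the metric after every report:

```python
import mxnet as mx
import numpy as np

# Tiny stand-in network and data: 101 batches of 32 samples, as in the example above.
data = mx.sym.Variable('data')
net = mx.sym.SoftmaxOutput(mx.sym.FullyConnected(data, num_hidden=2), name='softmax')
train_iter = mx.io.NDArrayIter(np.random.rand(101 * 32, 10).astype('float32'),
                               np.random.randint(0, 2, 101 * 32).astype('float32'),
                               batch_size=32)

mod = mx.mod.Module(symbol=net, context=mx.cpu())
mod.fit(train_iter,
        eval_metric='acc',
        # Report (and, by default, reset) the metric every 10 batches.
        batch_end_callback=mx.callback.Speedometer(batch_size=32, frequent=10),
        num_epoch=1)
```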
Ideally we should remove this misleading print statement. But I'm sure there are a lot of existing scripts out there that look for this line, and we can't remove it without breaking them. Since that would be a breaking change, let's do it in a major version change.
For now, to avoid the nan error, we can avoid resetting the metrics when processing the callback for the last batch. They will be reset at the beginning of the next epoch anyway.
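Below is a minimal, self-contained sketch of that approach (a toy loop, not the actual base_module.fit code; run_epoch and the lambda callback are hypothetical stand-ins): capture the metric's name/value pairs on the last batch before the callback runs, and use the captured values for the epoch-end line.

```python
import mxnet as mx

def run_epoch(batches, metric, batch_end_callback=None):
    """Toy training loop illustrating the fix described above."""
    eval_name_vals = []
    for nbatch, (labels, preds) in enumerate(batches):
        metric.update(labels, preds)
        end_of_batch = (nbatch == len(batches) - 1)
        if end_of_batch:
            # Capture the values before the callback, which may reset the metric.
            eval_name_vals = metric.get_name_value()
        if batch_end_callback is not None:
            batch_end_callback(metric)
    for name, val in eval_name_vals:
        print('Epoch[0] Train-%s=%f' % (name, val))   # no longer nan

# A callback that resets the metric, as Speedometer(auto_reset=True) would.
batches = [([mx.nd.array([1])], [mx.nd.array([[0.1, 0.9]])])] * 3
run_epoch(batches, mx.metric.Accuracy(), batch_end_callback=lambda m: m.reset())
```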
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.