diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8e924bf96c..49b1308b0f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -98,27 +98,28 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + parser.add_argument( - "--avg", + "--iter", type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, ) parser.add_argument( - "--avg-last-n", + "--avg", type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -453,13 +454,19 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -485,8 +492,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c955..1ef05d964a 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx( ) -def find_checkpoints(out_dir: Path) -> List[str]: +def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """Find all available checkpoints in a directory. The checkpoint filenames have the form: `checkpoint-xxx.pt` where xxx is a numerical value. + Assume you have the following checkpoints in the folder `foo`: + + - checkpoint-1.pt + - checkpoint-20.pt + - checkpoint-300.pt + - checkpoint-4000.pt + + Case 1 (Return all checkpoints):: + + find_checkpoints(out_dir='foo') + + Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e., + checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt) + + find_checkpoints(out_dir='foo', iteration=20) + + Case 3 (Return checkpoints older than checkpoint-20.pt, i.e., + checkpoint-20.pt, checkpoint-1.pt):: + + find_checkpoints(out_dir='foo', iteration=-20) + Args: out_dir: The directory where to search for checkpoints. + iteration: + If it is 0, return all available checkpoints. + If it is positive, return the checkpoints whose iteration number is + greater than or equal to `iteration`. + If it is negative, return the checkpoints whose iteration number is + less than or equal to `-iteration`. Returns: Return a list of checkpoint filenames, sorted in descending order by the numerical value in the filename. """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - idx_checkpoints = [ + iter_checkpoints = [ (int(pattern.search(c).group(1)), c) for c in checkpoints ] + # iter_checkpoints is a list of tuples. Each tuple contains + # two elements: (iteration_number, checkpoint-iteration_number.pt) + + iter_checkpoints = sorted( + iter_checkpoints, reverse=True, key=lambda x: x[0] + ) + if iteration >= 0: + ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] + else: + ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration] - idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0]) - ans = [ic[1] for ic in idx_checkpoints] return ans