diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ed6a6ea82e..769cd2a1d8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -24,7 +24,7 @@ from model import Transducer from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts +from icefall.utils import add_eos, add_sos, get_texts def fast_beam_search_one_best( @@ -46,7 +46,7 @@ def fast_beam_search_one_best( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -106,7 +106,7 @@ def fast_beam_search_nbest_LG( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -226,7 +226,7 @@ def fast_beam_search_nbest( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -311,7 +311,7 @@ def fast_beam_search_nbest_oracle( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -397,7 +397,7 @@ def fast_beam_search( model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. 
+ Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -1219,13 +1219,15 @@ def fast_beam_search_with_nbest_rescoring( temperature: float = 1.0, ) -> Dict[str, List[List[int]]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. + A lattice is first obtained using fast beam search, then `num_paths` paths + are selected and rescored using a given language model. The shortest path + within the lattice is used as the final output. + Args: model: An instance of `Transducer`. decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: A tensor of shape (N, T, C) from the encoder. encoder_out_lens: @@ -1350,3 +1352,190 @@ def fast_beam_search_with_nbest_rescoring( ans[key] = hyps return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search. Then `num_paths` paths + are extracted from it and rescored using an n-gram LM and an RNN LM. + The path with the highest combined score is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. 
+ encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts per stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + An RNN LM model used for LM rescoring. + rnn_lm_scale_list: + A list of floats representing RNN score scales. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx_rnn_lm_scale_yy' and the value is the decoded results. + `xx` and `yy` are the n-gram LM and RNN LM scale values used, e.g., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. 
We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = 
y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8f55413e46..c3a03f2e16 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -112,6 +112,7 @@ fast_beam_search_nbest_oracle, fast_beam_search_one_best, fast_beam_search_with_nbest_rescoring, + fast_beam_search_with_nbest_rnn_rescoring, greedy_search, greedy_search_batch, modified_beam_search, @@ -125,8 +126,10 @@ load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, + load_averaged_model, setup_logger, store_transcripts, str2bool, @@ -342,6 +345,62 @@ def get_parser(): """, ) + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + add_model_arguments(parser) return parser @@ -355,6 +414,7 @@ def decode_one_batch( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, + rnn_lm_model: torch.nn.Module = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. 
The dict has the following format: @@ -526,6 +586,30 @@ def decode_one_batch( nbest_scale=params.nbest_scale, temperature=params.temperature, ) + elif params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": + ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0] + ngram_lm_scale_list += [0.01, 0.02, 0.05] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8] + ngram_lm_scale_list += [1.0, 1.5, 2.5, 3] + hyp_tokens = fast_beam_search_with_nbest_rnn_rescoring( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ngram_lm_scale_list=ngram_lm_scale_list, + num_paths=params.num_paths, + G=G, + sp=sp, + word_table=word_table, + rnn_lm_model=rnn_lm_model, + rnn_lm_scale_list=ngram_lm_scale_list, + use_double_scores=True, + nbest_scale=params.nbest_scale, + temperature=params.temperature, + ) else: batch_size = encoder_out.size(0) @@ -571,7 +655,10 @@ def decode_one_batch( f"temperature_{params.temperature}" ): hyps } - elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + elif params.decoding_method in [ + "fast_beam_search_with_nbest_rescoring", + "fast_beam_search_with_nbest_rnn_rescoring", + ]: prefix = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" @@ -612,6 +699,7 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, + rnn_lm_model: torch.nn.Module = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. 
@@ -666,6 +754,7 @@ def decode_dataset( decoding_graph=decoding_graph, batch=batch, G=G, + rnn_lm_model=rnn_lm_model, ) for name, hyps in hyps_dict.items(): @@ -816,6 +905,7 @@ def main(): "fast_beam_search_nbest_oracle", "modified_beam_search", "fast_beam_search_with_nbest_rescoring", + "fast_beam_search_with_nbest_rnn_rescoring", ) params.res_dir = params.exp_dir / params.decoding_method @@ -919,7 +1009,10 @@ def main(): torch.load(lg_filename, map_location=device) ) decoding_graph.scores *= params.ngram_lm_scale - elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + elif params.decoding_method in [ + "fast_beam_search_with_nbest_rescoring", + "fast_beam_search_with_nbest_rnn_rescoring", + ]: logging.info(f"Loading word symbol table from {params.words_txt}") word_table = k2.SymbolTable.from_file(params.words_txt) @@ -932,14 +1025,43 @@ def main(): params.vocab_size - 1, device=device ) logging.info(f"G properties_str: {G.properties_str}") + rnn_lm_model = None + if ( + params.decoding_method + == "fast_beam_search_with_nbest_rnn_rescoring" + ): + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() else: word_table = None decoding_graph = k2.trivial_graph( params.vocab_size - 1, device=device ) + rnn_lm_model = None else: decoding_graph = None word_table = None + rnn_lm_model = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -965,6 +1087,7 @@ def main(): word_table=word_table, 
decoding_graph=decoding_graph, G=G, + rnn_lm_model=rnn_lm_model, ) save_results( diff --git a/icefall/decode.py b/icefall/decode.py index 680e296198..e596876f40 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm( An FsaVec with axes [utt][state][arc]. num_paths: Number of paths to extract from the given lattice for rescoring. + rnn_lm_model: + A rnn-lm model used for LM rescoring model: A transformer model. See the class "Transformer" in conformer_ctc/transformer.py for its interface.