Add RNN-LM rescoring in fast beam search (#475)
ezerhouni authored Jul 18, 2022
1 parent aec222e commit 608473b
Showing 3 changed files with 325 additions and 11 deletions.
207 changes: 198 additions & 9 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -24,7 +24,7 @@
from model import Transducer

from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.utils import add_eos, add_sos, get_texts


def fast_beam_search_one_best(
@@ -46,7 +46,7 @@ def fast_beam_search_one_best(
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -106,7 +106,7 @@ def fast_beam_search_nbest_LG(
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -226,7 +226,7 @@ def fast_beam_search_nbest(
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -311,7 +311,7 @@ def fast_beam_search_nbest_oracle(
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -397,7 +397,7 @@ def fast_beam_search(
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -1219,13 +1219,15 @@ def fast_beam_search_with_nbest_rescoring(
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
-A lattice is first obtained using modified beam search, and then
-the shortest path within the lattice is used as the final output.
+A lattice is first obtained using fast beam search; num_paths paths are
+selected and rescored using a given language model. The shortest path
+within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
-Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@@ -1350,3 +1352,190 @@ def fast_beam_search_with_nbest_rescoring(
ans[key] = hyps

return ans


def fast_beam_search_with_nbest_rnn_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
rnn_lm_model: torch.nn.Module,
rnn_lm_scale_list: List[float],
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_path are selected
and rescored using a given language model and a rnn-lm.
The shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
rnn_lm_model:
An RNN-LM model used for LM rescoring.
rnn_lm_scale_list:
A list of floats representing RNN-LM score scales.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where each key has the form
'ngram_lm_scale_xx_rnn_lm_scale_yy' and the value is the decoded
results. `xx` and `yy` are the n-gram LM and RNN-LM scale values
used during decoding, e.g., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
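# fast_beam_search returns an FsaVec containing one lattice per utterance.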

nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.

nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores

am_scores = nbest.tot_scores()

# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each path. We assume the decoding_graph
# is an acceptor, i.e., the lattice is also an acceptor.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]

tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0

token_list: List[List[int]] = tokens.tolist()
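# sp.decode() maps each path's BPE token IDs back to a space-separated
# word string.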
word_list: List[str] = sp.decode(token_list)

assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]

word_ids_list: List[List[int]] = []

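# Map each word to its integer ID, falling back to the OOV word ID for
# words not present in the word table.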
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)

word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
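# Adding epsilon self-loops lets the linear word FSAs match the epsilon
# (back-off) arcs in G during intersection.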

num_unique_paths = len(word_ids_list)

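# G is an FsaVec with a single FSA, so every word FSA maps to index 0.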
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)

rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)

rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
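# The tot score of each rescored word FSA is the n-gram LM log-score of
# the corresponding path.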
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)

# Now RNN-LM
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("sos_id")
eos_id = sp.piece_to_id("eos_id")

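# Build the RNN-LM inputs: x is each path prefixed with SOS and y is the
# same path suffixed with EOS; both are padded to rectangular tensors
# below.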
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
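# Per-path sentence lengths (including the SOS token) are derived from
# the ragged row splits.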
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]

x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)

x_tokens = x_tokens.to(torch.int64)
y_tokens = y_tokens.to(torch.int64)
sentence_lengths = sentence_lengths.to(torch.int64)

rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
assert rnn_lm_nll.ndim == 2
assert rnn_lm_nll.shape[0] == len(token_list)
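# rnn_lm_nll holds per-token negative log-likelihoods; summing over the
# time axis and negating gives each path's RNN-LM log-score.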
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)

ans: Dict[str, List[List[int]]] = {}
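# For each (n-gram scale, RNN-LM scale) pair, combine acoustic, n-gram LM
# and RNN-LM scores, then keep the best path of each utterance.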
for n_scale in ngram_lm_scale_list:
for rnn_scale in rnn_lm_scale_list:
key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores
+ rnn_scale * rnn_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)

ans[key] = hyps

return ans
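For context, here is a minimal sketch of how the new function might be called from a decode script. The names features and feature_lens and all beam, path-count, and scale values below are illustrative assumptions; model, decoding_graph, G, sp, word_table, and rnn_lm_model are assumed to be prepared as in the icefall decode recipes.

# Hypothetical usage sketch (names and values are assumptions, see above).
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)

hyps_dict = fast_beam_search_with_nbest_rnn_rescoring(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=20.0,
    max_states=64,
    max_contexts=8,
    ngram_lm_scale_list=[0.1, 0.3, 0.5],
    num_paths=200,
    G=G,
    sp=sp,
    word_table=word_table,
    rnn_lm_model=rnn_lm_model,
    rnn_lm_scale_list=[0.1, 0.3, 0.5],
)

# Each key pairs one n-gram scale with one RNN-LM scale, e.g.
# hyps_dict["ngram_lm_scale_0.3_rnn_lm_scale_0.5"] is a list of token-ID
# hypotheses, one per utterance in the batch.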