
Commit 3a9ff31

word level lm rescore
1 parent ac9655c commit 3a9ff31

4 files changed: +137 -6 lines

4 files changed

+137
-6
lines changed

egs/librispeech/ASR/lm.sh

+35
@@ -0,0 +1,35 @@
stage=1

text=data/local/lm/librispeech-lm-norm.txt.gz
text_dir=data/lm/text
all_train_text=$text_dir/librispeech.txt
# There are 40,398,052 lines in all_train_text, which take about 50 minutes
# to tokenize with a single process.
# Use $train_pieces lines of data to validate the pipeline.
# train_pieces=300000  # 15 times the size of dev.txt
# Leave train_pieces empty to use all of all_train_text.
train_pieces=
dev_text=$text_dir/dev.txt
if [ $stage -le 0 ]; then
  # Reference:
  # https://github.com/kaldi-asr/kaldi/blob/pybind11/egs/librispeech/s5/local/rnnlm/tuning/run_tdnn_lstm_1a.sh#L75
  # Use the same data separation method as Kaldi so its result can serve as a baseline.
  if [ ! -f $text ]; then
    wget http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -P data/local/lm
  fi
  echo -n >$text_dir/dev.txt
  # Hold out one in every 2000 lines as dev data.
  gunzip -c $text | cut -d ' ' -f2- | awk -v text_dir=$text_dir '{if(NR%2000 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$all_train_text
fi

if [ $stage -eq 1 ]; then
  # for text_file in dev.txt librispeech.txt; do
  #   python ./vq_pruned_transducer_stateless2/tokenize_text.py \
  #     --tokenizer-path ./data/lang_bpe_500/bpe.model \
  #     --text-file ./data/lm/text/$text_file
  # done
  lmplz -o 4 --text data/lm/text/librispeech.txt --arpa train.arpa -S 10%
  # lmplz -o 4 --text data/lm/text/librispeech.txt --arpa discount_train.arpa -S 10% \
  #   --discount_fallback
  # lmplz -o 4 --text data/lm/text/librispeech.txt.tokens --arpa token_train.arpa -S 10% \
  #   --discount_fallback 0.5
fi
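For a quick sanity check of the ARPA model produced by the stage-1 lmplz command, a minimal sketch using KenLM's Python bindings; train.arpa is the file written by this script, and KenLM's build_binary tool converts it to a binary model (such as the train.bin loaded by decode.py below), which loads the same way:

import kenlm

# Load the 4-gram ARPA model written by lmplz above.
lm = kenlm.LanguageModel("train.arpa")
print(lm.order)  # expected: 4

# score() returns a log10 probability; bos/eos add sentence boundaries.
print(lm.score("THE QUICK BROWN FOX", bos=True, eos=True))

# Per-word scores; the oov flag helps spot words missing from the LM text.
for log10_prob, ngram_len, oov in lm.full_scores("THE QUICK BROWN FOX"):
    print(log10_prob, ngram_len, oov)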

egs/librispeech/ASR/vq_pruned_transducer_stateless2/beam_search.py

+13 -3

@@ -18,6 +18,7 @@
 from typing import Dict, List, Optional
 
 import k2
+import kenlm
 import torch
 from model import Transducer
 
@@ -267,6 +268,9 @@ class Hypothesis:
     # It contains only one entry.
     log_prob: torch.Tensor
 
+    last_start_idx: int
+    state: None  # lm state
+
     @property
     def key(self) -> str:
         """Return a string representation of self.ys"""
@@ -637,6 +641,7 @@ def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    lmr=None,  # lm rescorer
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -677,7 +682,10 @@ def beam_search(
     t = 0
 
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+
+    start_state = kenlm.State()
+    lmr.lm.BeginSentenceWrite(start_state)
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, last_start_idx=0, state=start_state))
 
     max_sym_per_utt = 20000
 
@@ -738,7 +746,7 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob
 
             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob, last_start_idx=y_star.last_start_idx, state=y_star.state))
 
             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -747,7 +755,9 @@ def beam_search(
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v
-            A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+            tmp = Hypothesis(ys=new_ys, log_prob=new_log_prob, last_start_idx=y_star.last_start_idx, state=y_star.state)
+            lmr.rescore(tmp)
+            A.add(tmp)
 
             # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
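Each Hypothesis now threads a KenLM state through the search, so a finished word is scored incrementally from the previous word's context instead of rescoring the whole prefix. A minimal sketch of that state-threading pattern, outside the beam search (the model path is only an example):

import kenlm

lm = kenlm.LanguageModel("train.bin")  # example path; an ARPA model also works

# Initialize the state with the <s> context, as beam_search() does for the
# initial hypothesis via lmr.lm.BeginSentenceWrite(start_state).
state = kenlm.State()
lm.BeginSentenceWrite(state)

total = 0.0
for word in ["THE", "QUICK", "BROWN", "FOX"]:
    out_state = kenlm.State()
    # BaseScore returns log10 P(word | context in state) and writes the new
    # context into out_state, mirroring what LMRescorer.rescore() does per
    # completed word.
    total += lm.BaseScore(state, word, out_state)
    state = out_state

print(total)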

egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py

+22 -3

@@ -63,6 +63,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import k2
+import kenlm
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -88,12 +89,19 @@
     write_error_stats,
 )
 
+from lm import LMRescorer
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--lm-weight",
+        type=float,
+        default=0.0,
+    )
+
     parser.add_argument(
         "--epoch",
         type=int,
@@ -206,6 +214,7 @@ def decode_one_batch(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    lmr=None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -298,6 +307,7 @@ def decode_one_batch(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    lmr=lmr,
                 )
             else:
                 raise ValueError(
@@ -325,6 +335,7 @@ def decode_dataset(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     decoding_graph: Optional[k2.Fsa] = None,
+    lmr=None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -369,6 +380,7 @@ def decode_dataset(
             sp=sp,
             decoding_graph=decoding_graph,
             batch=batch,
+            lmr=lmr,
         )
 
         for name, hyps in hyps_dict.items():
@@ -399,19 +411,19 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"lm_weight-{params.lm_weight}-recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"lm_weight-{params.lm_weight}-errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"lm_weight-{params.lm_weight}-{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
 
@@ -479,6 +491,8 @@ def main():
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
+    LM = "/ceph-data2/ly/kenlm/train_lm/train.bin"
+    params.lm_path = f'{LM}'
 
     logging.info(params)
 
@@ -506,6 +520,10 @@ def main():
     model.eval()
     model.device = device
 
+    lm_model = kenlm.LanguageModel(LM)
+
+    lmr = LMRescorer(Path("./data/lang_bpe_500/"), blank_id=model.decoder.blank_id, lm=lm_model, weight=params.lm_weight)
+
     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
@@ -532,6 +550,7 @@ def main():
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            lmr=lmr,
         )
 
         save_results(
egs/librispeech/ASR/vq_pruned_transducer_stateless2/lm.py

+67
@@ -0,0 +1,67 @@
from pathlib import Path
from icefall.lexicon import read_lexicon, write_lexicon
import sentencepiece as spm
import kenlm


def extract_start_tokens(lang_dir: Path = Path("./data/lang_bpe_500/")):
    tokens = read_lexicon(lang_dir / "tokens.txt")

    # Get the leading "underscore" of the entry '▁THE 4' in tokens.txt.
    # It is not actually an underscore: it is the sentencepiece word-boundary
    # symbol '▁', which only looks similar to one.
    word_start_char = tokens[4][0][0]

    word_start_token = []
    non_start_token = []

    aux = ['<sos/eos>', '<unk>']
    for t in tokens:
        leading_char = t[0][0]
        if leading_char == word_start_char or t[0] in aux:
            word_start_token.append(t)
        else:
            non_start_token.append(t)

    write_lexicon(lang_dir / "word_start_tokens.txt", word_start_token)
    write_lexicon(lang_dir / "non_start_tokens.txt", non_start_token)


def lexicon_to_dict(lexicon):
    token2idx = {}
    idx2token = {}
    for token, idx in lexicon:
        assert len(idx) == 1
        idx = idx[0]
        token2idx[token] = int(idx)
        idx2token[int(idx)] = token
    return token2idx, idx2token


class LMRescorer:
    def __init__(self, lang_dir, blank_id, lm, weight):
        self.lm = lm
        self.start_token2idx, self.start_idx2token = lexicon_to_dict(read_lexicon(lang_dir / "word_start_tokens.txt"))
        self.nonstart_token2idx, self.nonstart_idx2token = lexicon_to_dict(read_lexicon(lang_dir / "non_start_tokens.txt"))
        self.token2idx, self.idx2token = lexicon_to_dict(read_lexicon(lang_dir / "tokens.txt"))
        self.sp = spm.SentencePieceProcessor()
        self.sp.load(str(lang_dir / "bpe.model"))
        self.blank_id = blank_id
        self.weight = weight

    def rescore(self, hyp):
        # Only rescore when the newest token starts a new word: at that point
        # the previous word is complete and can be scored by the n-gram LM.
        if self.weight > 0 and hyp.ys[-1] in self.start_idx2token:
            word = self.previous_word(hyp)
            output_state = kenlm.State()
            lm_score = self.lm.BaseScore(hyp.state, word, output_state)
            hyp.state = output_state
            hyp.log_prob += self.weight * lm_score
        return hyp

    def previous_word(self, hyp):
        # The tokens between the previous word start and the newest token
        # form the word that has just been completed.
        last_start_idx = hyp.last_start_idx
        tokens_seq = hyp.ys[last_start_idx:-1]
        tokens_seq = [t for t in tokens_seq if t != self.blank_id]
        word = self.sp.decode(tokens_seq)
        hyp.last_start_idx = len(hyp.ys) - 1
        return word
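A rough end-to-end usage sketch of lm.py: extract_start_tokens is the one-off step that writes word_start_tokens.txt / non_start_tokens.txt, and LMRescorer.rescore is what beam_search calls for each expanded hypothesis. The Hypothesis below is a stand-in for the dataclass in beam_search.py, the token ids and weight are arbitrary example values, and the LM path is the one hard-coded in decode.py (substitute your own model):

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import kenlm

from lm import LMRescorer, extract_start_tokens

lang_dir = Path("./data/lang_bpe_500/")

# One-off preparation: split tokens.txt into word-start / non-start lexicons
# that LMRescorer loads in its constructor.
extract_start_tokens(lang_dir)


@dataclass
class Hypothesis:
    # Stand-in for the extended dataclass in beam_search.py.
    ys: List[int]
    log_prob: float
    last_start_idx: int = 0
    state: Optional[kenlm.State] = None


lm = kenlm.LanguageModel("/ceph-data2/ly/kenlm/train_lm/train.bin")
lmr = LMRescorer(lang_dir, blank_id=0, lm=lm, weight=0.3)  # <blk> is typically id 0

state = kenlm.State()
lm.BeginSentenceWrite(state)
hyp = Hypothesis(ys=[0, 0, 25, 31, 17], log_prob=-1.2, last_start_idx=2, state=state)

# If ys[-1] is a word-start token, the tokens since last_start_idx form a
# completed word; it is decoded with sentencepiece and scored by the LM.
lmr.rescore(hyp)
print(hyp.log_prob, hyp.last_start_idx)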
