Skip to content

Commit cb04c8a

Browse files
authored
Limit the number of symbols per frame in RNN-T decoding. (#151)
1 parent 1d44da8 commit cb04c8a

File tree

8 files changed

+501
-21
lines changed

8 files changed

+501
-21
lines changed

egs/librispeech/ASR/transducer/beam_search.py

+13 -5
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
4343
T = encoder_out.size(1)
4444
t = 0
4545
hyp = []
46-
max_u = 1000 # terminate after this number of steps
47-
u = 0
4846

49-
while t < T and u < max_u:
47+
sym_per_frame = 0
48+
sym_per_utt = 0
49+
50+
max_sym_per_utt = 1000
51+
max_sym_per_frame = 3
52+
53+
while t < T and sym_per_utt < max_sym_per_utt:
5054
# fmt: off
5155
current_encoder_out = encoder_out[:, t:t+1, :]
5256
# fmt: on
@@ -61,8 +65,12 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
6165
hyp.append(y.item())
6266
y = y.reshape(1, 1)
6367
decoder_out, (h, c) = model.decoder(y, (h, c))
64-
u += 1
65-
else:
68+
69+
sym_per_utt += 1
70+
sym_per_frame += 1
71+
72+
if y == blank_id or sym_per_frame > max_sym_per_frame:
73+
sym_per_frame = 0
6674
t += 1
6775

6876
return hyp

egs/librispeech/ASR/transducer_lstm/__init__.py

Whitespace-only changes.

egs/librispeech/ASR/transducer_lstm/beam_search.py

+13 -5
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
4343
T = encoder_out.size(1)
4444
t = 0
4545
hyp = []
46-
max_u = 1000 # terminate after this number of steps
47-
u = 0
4846

49-
while t < T and u < max_u:
47+
sym_per_frame = 0
48+
sym_per_utt = 0
49+
50+
max_sym_per_utt = 1000
51+
max_sym_per_frame = 3
52+
53+
while t < T and sym_per_utt < max_sym_per_utt:
5054
# fmt: off
5155
current_encoder_out = encoder_out[:, t:t+1, :]
5256
# fmt: on
@@ -61,8 +65,12 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
6165
hyp.append(y.item())
6266
y = y.reshape(1, 1)
6367
decoder_out, (h, c) = model.decoder(y, (h, c))
64-
u += 1
65-
else:
68+
69+
sym_per_utt += 1
70+
sym_per_frame += 1
71+
72+
if y == blank_id or sym_per_frame > max_sym_per_frame:
73+
sym_per_frame = 0
6674
t += 1
6775

6876
return hyp

egs/librispeech/ASR/transducer_stateless/beam_search.py

+15 -7
Original file line number | Diff line number | Diff line change
@@ -39,14 +39,18 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
3939
device = model.device
4040

4141
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
42-
decoder_out, (h, c) = model.decoder(sos)
42+
decoder_out = model.decoder(sos)
4343
T = encoder_out.size(1)
4444
t = 0
4545
hyp = []
46-
max_u = 1000 # terminate after this number of steps
47-
u = 0
4846

49-
while t < T and u < max_u:
47+
sym_per_frame = 0
48+
sym_per_utt = 0
49+
50+
max_sym_per_utt = 1000
51+
max_sym_per_frame = 3
52+
53+
while t < T and sym_per_utt < max_sym_per_utt:
5054
# fmt: off
5155
current_encoder_out = encoder_out[:, t:t+1, :]
5256
# fmt: on
@@ -60,9 +64,13 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
6064
if y != blank_id:
6165
hyp.append(y.item())
6266
y = y.reshape(1, 1)
63-
decoder_out, (h, c) = model.decoder(y, (h, c))
64-
u += 1
65-
else:
67+
decoder_out = model.decoder(y)
68+
69+
sym_per_utt += 1
70+
sym_per_frame += 1
71+
72+
if y == blank_id or sym_per_frame > max_sym_per_frame:
73+
sym_per_frame = 0
6674
t += 1
6775

6876
return hyp

egs/librispeech/ASR/transducer_stateless/conformer.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222

2323
import torch
2424
from torch import Tensor, nn
25-
from transducer.transformer import Transformer
25+
from transformer import Transformer
2626

2727
from icefall.utils import make_pad_mask
2828

0 commit comments

Comments
 (0)