
Commit 5b6699a

Minor fixes to the RNN-T Conformer model (#152)

* Disable weight decay.
* Remove input feature batchnorm.
* Replace BatchNorm in the Conformer model with LayerNorm.
* Use tanh in the joint network.
* Remove sos ID.
* Reduce the number of decoder layers from 4 to 2.
* Minor fixes.
* Fix typos.

1 parent fb6a57e commit 5b6699a

19 files changed: +147 -86 lines
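Two of the bullets above, "Use tanh in the joint network" and "Remove sos ID", touch files that are not excerpted below. For the first, here is a minimal sketch of an RNN-T joint network that applies tanh before the vocabulary projection; the class name, dimensions, and shapes are illustrative stand-ins, not the recipe's exact `transducer/joiner.py`:

```python
import torch
import torch.nn as nn


class Joiner(nn.Module):
    """Minimal RNN-T joiner sketch: combine encoder and decoder outputs,
    squash with tanh, then project to the vocabulary."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # encoder_out: (N, T, C), decoder_out: (N, U, C)
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U, C)
        logit = torch.tanh(logit)  # the nonlinearity this commit switches to
        return self.output_linear(logit)  # (N, T, U, vocab_size)


# Toy usage: one logit per (time frame, label position) pair, i.e. the RNN-T lattice.
joiner = Joiner(input_dim=512, output_dim=500)
logits = joiner(torch.randn(1, 20, 512), torch.randn(1, 5, 512))
assert logits.shape == (1, 20, 5, 500)
```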

.github/workflows/run-pretrained-transducer-stateless.yml

+1 -1

```diff
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-name: run-pre-trained-tranducer-stateless
+name: run-pre-trained-transducer-stateless

 on:
   push:
```
.github/workflows/run-pretrained-transducer.yml (new file)

+109

@@ -0,0 +1,109 @@
```yaml
# Copyright      2021  Fangjun Kuang (csukuangfj@gmail.com)

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-pre-trained-transducer

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

jobs:
  run_pre_trained_transducer:
    if: github.event.label.name == 'ready' || github.event_name == 'push'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.7, 3.8, 3.9]
        torch: ["1.10.0"]
        torchaudio: ["0.10.0"]
        k2-version: ["1.9.dev20211101"]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip pytest
          # numpy 1.20.x does not support python 3.6
          pip install numpy==1.19
          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/

          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
          python3 -m pip install kaldifeat
          # We are in ./icefall and there is a file: requirements.txt in it
          pip install -r requirements.txt

      - name: Install graphviz
        shell: bash
        run: |
          python3 -m pip install -qq graphviz
          sudo apt-get -qq install graphviz

      - name: Download pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23

          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav

      - name: Run greedy search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method greedy_search \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
```

README.md

+1 -1

```diff
@@ -71,7 +71,7 @@ The best WER with greedy search is:

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |

 We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
```

egs/librispeech/ASR/RESULTS.md

+13 -10

````diff
@@ -2,7 +2,10 @@

 ### LibriSpeech BPE training results (Transducer)

-#### 2021-12-22
+#### Conformer encoder + embedding decoder
+
+Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`.
+
 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).

@@ -60,18 +63,18 @@ avg=10
 ```


-#### 2021-12-17
-Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
+#### Conformer encoder + LSTM decoder
+Using commit `TODO`.

 Conformer encoder + LSTM decoder.

 The best WER is

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |

-using `--epoch 26 --avg 12` with **greedy search**.
+using `--epoch 34 --avg 11` with **greedy search**.

 The training command to reproduce the above WER is:

@@ -80,19 +83,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp-lr-2.5-full \
   --full-libri 1 \
-  --max-duration 250 \
+  --max-duration 180 \
   --lr-factor 2.5
 ```

 The decoding command is:

 ```
-epoch=26
-avg=12
+epoch=34
+avg=11

 ./transducer/decode.py \
   --epoch $epoch \
@@ -102,7 +105,7 @@ avg=12
   --max-duration 100
 ```

-You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/>
+You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>


 ### LibriSpeech BPE training results (Conformer-CTC)
````

egs/librispeech/ASR/transducer/beam_search.py

+1 -2

```diff
@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device

     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(

     # Second, choose other labels
     for i, v in enumerate(log_prob.tolist()):
-        if i in (blank_id, sos_id):
+        if i == blank_id:
             continue
         new_ys = y_star.ys + [i]
         new_log_prob = y_star.log_prob + v
```
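With `sos_id` gone, decoding primes the prediction network with the blank symbol, as the `sos = torch.tensor([blank_id], ...)` line above already did, and the expansion loop now only has to skip blank. A small runnable sketch of that priming convention, using toy modules in place of the recipe's LSTM decoder (the `(output, (h, c))` return shape is the standard `nn.LSTM` contract):

```python
import torch
import torch.nn as nn


def prime_with_blank(embedding: nn.Embedding, rnn: nn.LSTM, blank_id: int):
    """Sketch: start RNN-T decoding from blank instead of a dedicated
    <sos> token (toy modules; the recipe wraps these in its Decoder)."""
    sos = torch.tensor([[blank_id]])  # (batch=1, U=1), mirroring the line above
    emb = embedding(sos)              # (1, 1, embedding_dim)
    decoder_out, (h, c) = rnn(emb)    # LSTM returns output plus hidden states
    return decoder_out, (h, c)


embedding = nn.Embedding(num_embeddings=500, embedding_dim=16)
rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
out, states = prime_with_blank(embedding, rnn, blank_id=0)
assert out.shape == (1, 1, 32)
```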

egs/librispeech/ASR/transducer/conformer.py

+7 -9

```diff
@@ -56,7 +56,6 @@ def __init__(
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ def __init__(
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ def forward(
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ def __init__(
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ def forward(self, x: Tensor) -> Tensor:

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)
```
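The permute pair around `self.norm` exists because `nn.LayerNorm(channels)` normalizes over the last dimension, while the surrounding `Conv1d` modules keep the `(batch, channels, time)` layout. A self-contained sketch of the pattern (tensor sizes are illustrative):

```python
import torch
import torch.nn as nn

batch, channels, time = 2, 8, 50
x = torch.randn(batch, channels, time)  # layout used between Conv1d layers

# BatchNorm1d (the old code) normalizes each channel over (batch, time)
# and accepts (batch, channels, time) directly:
y_bn = nn.BatchNorm1d(channels)(x)

# LayerNorm (the new code) normalizes over the trailing dimension, so the
# channel axis must be moved last and then moved back:
ln = nn.LayerNorm(channels)
y_ln = ln(x.permute(0, 2, 1)).permute(0, 2, 1)

assert y_bn.shape == y_ln.shape == (batch, channels, time)
```

Unlike BatchNorm, LayerNorm keeps no running statistics, so it behaves identically in `train()` and `eval()` and is insensitive to batch size.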

egs/librispeech/ASR/transducer/decode.py

+3 -7

```diff
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

@@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -401,7 +398,6 @@ def main():

     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
```
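The `--avg` help text describes checkpoint averaging: with the new defaults `--epoch 34 --avg 11`, decoding averages the weights of `epoch-24.pt` through `epoch-34.pt`. A hedged sketch of what that averaging amounts to (icefall ships its own helper for this; the `"model"` key in the checkpoint dict is an assumption):

```python
import torch


def average_checkpoints(filenames):
    """Elementwise average of model state_dicts (illustrative sketch)."""
    n = len(filenames)
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.to(torch.float64).clone() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].to(torch.float64)
    return {k: (v / n).to(torch.float32) for k, v in avg.items()}


# --epoch 34 --avg 11 selects the 11 consecutive checkpoints up to epoch 34:
filenames = [f"epoch-{i}.pt" for i in range(24, 35)]
```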

egs/librispeech/ASR/transducer/decoder.py

-4

```diff
@@ -27,7 +27,6 @@ def __init__(
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ def __init__(
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ def __init__(
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)

     def forward(
```
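Taken together with `decode.py` above, the decoder is now constructed without `sos_id` and with two LSTM layers. A construction sketch against the updated signature, run from `egs/librispeech/ASR` with the recipe on `PYTHONPATH` (the `vocab_size`, `blank_id`, and `output_dim` values are illustrative, and any dropout arguments are assumed to keep their defaults):

```python
from transducer.decoder import Decoder

decoder = Decoder(
    vocab_size=500,      # BPE-500 model (illustrative)
    embedding_dim=1024,  # "decoder_embedding_dim" in get_params()
    blank_id=0,          # sp.piece_to_id("<blk>") (illustrative)
    num_layers=2,        # reduced from 4 in this commit
    hidden_dim=512,      # "decoder_hidden_dim"
    output_dim=512,      # params.encoder_out_dim (illustrative)
)
# No sos_id argument anymore: decoding starts from the blank symbol instead.
```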
