
Commit 57cb611

[yesno] Remove padding in TDNN (#21)
* Disable SpecAug for yesno. Also replace Adam with SGD.

* Remove padding in the model to make the results reproducible.
1 parent 6c2c9b9 commit 57cb611

5 files changed: 20 additions, 52 deletions

.github/workflows/run-yesno-recipe.yml (+2, -13)
@@ -69,21 +69,10 @@ jobs:
       run: |
         export PYTHONPATH=$PWD:$PYTHONPATH
         echo $PYTHONPATH
-        ls -lh
 
-        # The following three lines are for macOS
-        lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
-        echo "lib_path: $lib_path"
-        export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
-        ls -lh $lib_path
 
         cd egs/yesno/ASR
         ./prepare.sh
-        python3 ./tdnn/train.py --num-epochs 100
-        python3 ./tdnn/decode.py --epoch 99
-        python3 ./tdnn/decode.py --epoch 95
-        python3 ./tdnn/decode.py --epoch 90
-        python3 ./tdnn/decode.py --epoch 80
-        python3 ./tdnn/decode.py --epoch 70
-        python3 ./tdnn/decode.py --epoch 60
+        python3 ./tdnn/train.py
+        python3 ./tdnn/decode.py
         # TODO: Check that the WER is less than some value

egs/yesno/ASR/tdnn/asr_datamodule.py (-12)
@@ -27,7 +27,6 @@
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
-    SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader
@@ -163,18 +162,8 @@ def train_dataloaders(self) -> DataLoader:
             )
         ] + transforms
 
-        input_transforms = [
-            SpecAugment(
-                num_frame_masks=2,
-                features_mask_size=27,
-                num_feature_masks=2,
-                frames_mask_size=100,
-            )
-        ]
-
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
-            input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
 
@@ -194,7 +183,6 @@ def train_dataloaders(self) -> DataLoader:
                 input_strategy=OnTheFlyFeatures(
                     Fbank(FbankConfig(num_mel_bins=23))
                 ),
-                input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
 
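
For reference, the removed augmentation can be re-created in isolation. The sketch below assumes SpecAugment is importable from lhotse.dataset, as in the grouped import this commit trims, and copies the parameters verbatim from the deleted lines. Roughly, it masked up to 2 random time stripes (at most 100 frames each) and up to 2 feature stripes (at most 27 mel bins each) per example, which is aggressive for a corpus as tiny as yesno.

import torch
from lhotse.dataset import SpecAugment

# The transform this commit disables, rebuilt as a standalone object
# (parameters copied from the deleted lines; exact masking semantics
# are as documented by lhotse, summarized above):
spec_augment = SpecAugment(
    num_frame_masks=2,      # number of time masks per example
    features_mask_size=27,  # max width of each feature (mel-bin) mask
    num_feature_masks=2,    # number of feature masks per example
    frames_mask_size=100,   # max width of each time mask, in frames
)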

egs/yesno/ASR/tdnn/decode.py (+7, -12)
@@ -32,14 +32,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=9,
+        default=14,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=15,
+        default=2,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -104,16 +104,11 @@ def decode_one_batch(
     nnet_output = model(feature)
     # nnet_output is [N, T, C]
 
-    supervisions = batch["supervisions"]
-
-    supervision_segments = torch.stack(
-        (
-            supervisions["sequence_idx"],
-            supervisions["start_frame"],
-            supervisions["num_frames"],
-        ),
-        1,
-    ).to(torch.int32)
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
 
     lattice = get_lattice(
         nnet_output=nnet_output,
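
The rewritten supervision_segments marks every utterance as a single segment of the form [sequence_idx, start_frame, num_frames] spanning the whole network output; this is only safe because the batch is padded to one common length and the model no longer pads internally. A minimal sketch with a hypothetical 3-utterance batch:

import torch

nnet_output = torch.randn(3, 50, 4)  # hypothetical [N, T, C] model output
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
    [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
    dtype=torch.int32,
)
print(supervision_segments)
# tensor([[ 0,  0, 50],
#         [ 1,  0, 50],
#         [ 2,  0, 50]], dtype=torch.int32)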

egs/yesno/ASR/tdnn/model.py (-3)
@@ -23,15 +23,13 @@ def __init__(self, num_features: int, num_classes: int):
                 in_channels=num_features,
                 out_channels=32,
                 kernel_size=3,
-                padding=1,
             ),
             nn.ReLU(inplace=True),
             nn.BatchNorm1d(num_features=32, affine=False),
             nn.Conv1d(
                 in_channels=32,
                 out_channels=32,
                 kernel_size=5,
-                padding=4,
                 dilation=2,
             ),
             nn.ReLU(inplace=True),
@@ -40,7 +38,6 @@ def __init__(self, num_features: int, num_classes: int):
                 in_channels=32,
                 out_channels=32,
                 kernel_size=5,
-                padding=8,
                 dilation=4,
             ),
             nn.ReLU(inplace=True),
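
With its padding gone, each Conv1d now shortens the time axis by dilation * (kernel_size - 1) frames, i.e. 2 + 8 + 16 = 26 frames across the three layers. A shape-only sketch (hypothetical batch and input length; the ReLU/BatchNorm layers are omitted because they leave shapes unchanged):

import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv1d(23, 32, kernel_size=3),              # T -> T - 2
    nn.Conv1d(32, 32, kernel_size=5, dilation=2),  # T -> T - 8
    nn.Conv1d(32, 32, kernel_size=5, dilation=4),  # T -> T - 16
)
x = torch.randn(1, 23, 100)  # [N, C, T]: 23 mel bins, 100 frames
print(convs(x).shape)        # torch.Size([1, 32, 74]), i.e. 100 - 26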

egs/yesno/ASR/tdnn/train.py (+11, -12)
@@ -24,12 +24,7 @@
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, setup_logger, str2bool
 
 
 def get_parser():
@@ -61,7 +56,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=50,
+        default=15,
         help="Number of epochs to train.",
     )
 
@@ -129,11 +124,10 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-2,
             "feature_dim": 23,
             "weight_decay": 1e-6,
             "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -278,9 +272,14 @@ def compute_loss(
     # different duration in decreasing order, required by
     # `k2.intersect_dense` called in `k2.ctc_loss`
     supervisions = batch["supervisions"]
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=1
+    texts = supervisions["text"]
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
+
     decoding_graph = graph_compiler.compile(texts)
 
     dense_fsa_vec = k2.DenseFsaVec(
@@ -491,7 +490,7 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = optim.AdamW(
+    optimizer = optim.SGD(
         model.parameters(),
         lr=params.lr,
         weight_decay=params.weight_decay,
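
The optimizer change in isolation: AdamW is replaced by plain SGD, with the learning rate in get_params() raised from 1e-3 to 1e-2 and weight decay kept at 1e-6. A minimal sketch with a stand-in module:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(23, 4)  # hypothetical stand-in for the TDNN
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,            # params.lr after this commit (was 1e-3 with AdamW)
    weight_decay=1e-6,  # params.weight_decay, unchanged
)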
