
Commit d68b8e9

Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes. (#554)

* Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes.
* minor fixes
1 parent 235eb07 commit d68b8e9
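
For context: CUDA_LAUNCH_BLOCKING=1 forces every CUDA kernel launch to run synchronously, so a failing kernel is reported at its actual launch site instead of at some later synchronization point. That is handy for debugging but slows training, which is why the recipes stop hard-coding it. A minimal sketch of opting in per run instead, assuming a hypothetical DEBUG_CUDA convention (not part of icefall):

```python
import os

# Hypothetical opt-in: run `DEBUG_CUDA=1 python train.py ...` when you
# need synchronous launches and precise CUDA error stack traces.
# Set it before the CUDA context is created, i.e. before any tensor
# is placed on a GPU.
if os.environ.get("DEBUG_CUDA") == "1":
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported after the environment variable is decided

if torch.cuda.is_available():
    x = torch.randn(2, 2, device="cuda")
    print(x.sum().item())
```

Equivalently, prefixing the training command with `CUDA_LAUNCH_BLOCKING=1` enables it for a single run without touching the code at all.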

File tree

3 files changed (+22 −22 lines)

egs/wenetspeech/ASR/local/preprocess_wenetspeech.py (+16 −9)

@@ -23,6 +23,8 @@
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 
+from icefall import setup_logger
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
 
@@ -48,13 +50,17 @@ def preprocess_wenet_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
+    # Note: By default, we preprocess all sub-parts.
+    # You can delete those that you don't need.
+    # For instance, if you don't want to use the L subpart, just remove
+    # the line below containing "L"
     dataset_parts = (
-        "L",
-        "M",
-        "S",
         "DEV",
         "TEST_NET",
         "TEST_MEETING",
+        "S",
+        "M",
+        "L",
     )
 
     logging.info("Loading manifest (may take 10 minutes)")
@@ -81,10 +87,13 @@ def preprocess_wenet_speech():
         logging.info(f"Normalizing text in {partition}")
         for sup in m["supervisions"]:
             text = str(sup.text)
-            logging.info(f"Original text: {text}")
+            orig_text = text
             sup.text = normalize_text(sup.text)
             text = str(sup.text)
-            logging.info(f"Normalize text: {text}")
+            if len(orig_text) != len(text):
+                logging.info(
+                    f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+                )
 
         # Create long-recording cut manifests.
         logging.info(f"Processing {partition}")
@@ -109,12 +118,10 @@ def preprocess_wenet_speech():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    setup_logger(log_filename="./log-preprocess-wenetspeech")
 
     preprocess_wenet_speech()
+    logging.info("Done")
 
 
 if __name__ == "__main__":
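
Two things change above: main() now routes output through icefall's setup_logger (so the preprocessing log also lands in ./log-preprocess-wenetspeech), and the normalization loop logs a before/after pair only when the text length changes, instead of printing every utterance twice. A standalone sketch of that logging pattern, with a stub normalize_text standing in for the recipe's real filtering (which follows the WenetSpeech Kaldi data-prep script linked in the diff):

```python
import logging
import re

logging.basicConfig(level=logging.INFO)


def normalize_text(text: str) -> str:
    # Stub standing in for the recipe's normalization; here it only
    # strips CJK punctuation, so the length changes exactly when it fires.
    return re.sub(r"[、。，！？]", "", text)


for orig_text in ["你好，世界！", "已经干净的文本"]:
    text = normalize_text(orig_text)
    # Log only utterances that normalization actually changed.
    if len(orig_text) != len(text):
        logging.info(
            f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
        )
```

Comparing lengths rather than full strings is cheap on millions of supervisions; it can miss same-length substitutions, but a normalization that mostly deletes characters will always shorten the text when it fires.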

egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py (+5 −9)

@@ -81,7 +81,6 @@
 
 import argparse
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -120,8 +119,6 @@
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -162,7 +159,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
         """,
     )
 
@@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
-            "batch_idx_train": 10,
-            "log_interval": 1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
             "reset_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
@@ -545,7 +542,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.
     Args:
       params:
         Parameters for training. See :func:`get_params`.
@@ -573,7 +570,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
 
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
@@ -697,7 +694,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
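
Beyond dropping CUDA_LAUNCH_BLOCKING, this file fixes two defaults that look like leftover debugging values: batch_idx_train now starts at 0 (so a fresh run does not pretend ten batches already happened) and log_interval goes from 1 back to 50 (logging every batch floods the log). A sketch of the interval-gated logging these fields drive, with a minimal stand-in for icefall's AttributeDict (the loop shape is an assumption, not copied from train.py):

```python
class AttributeDict(dict):
    # Minimal stand-in for icefall's AttributeDict: attribute-style access.
    def __getattr__(self, key):
        return self[key]


params = AttributeDict({"batch_idx_train": 0, "log_interval": 50})

for batch_idx in range(120):
    params["batch_idx_train"] += 1
    # With log_interval=1 every batch logged; 50 keeps the log readable
    # while batch_idx_train still counts every batch for checkpointing.
    if batch_idx % params.log_interval == 0:
        print(f"batch {batch_idx}, batch_idx_train {params.batch_idx_train}")
```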

egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py (+1 −4)

@@ -61,7 +61,6 @@
 import argparse
 import copy
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -103,8 +102,6 @@
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -684,7 +681,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
 
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
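
The same type(y) == list fix lands here as in the stateless2 recipe. isinstance is preferred because it also accepts subclasses of list, whereas an exact type comparison silently rejects them. A self-contained sketch of the dispatch (RaggedTensor below is a toy stand-in for k2.RaggedTensor, just so the example runs without k2):

```python
from typing import List, Union


class RaggedTensor:
    # Toy stand-in for k2.RaggedTensor, for illustration only.
    def __init__(self, data: List[List[int]]):
        self.data = data


class TokenIds(list):
    """A list subclass: passes isinstance(y, list), fails type(y) == list."""


def to_ragged(y: Union[List[List[int]], RaggedTensor]) -> RaggedTensor:
    # isinstance handles list and any of its subclasses uniformly.
    if isinstance(y, list):
        return RaggedTensor(y)
    return y


print(type(to_ragged([[1, 2], [3]])).__name__)         # RaggedTensor
print(type(to_ragged(TokenIds([[4], [5]]))).__name__)  # RaggedTensor
```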
