From dbc76dbd85bb17217d387cc6005e9da718a2fe2e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 20 Aug 2021 23:44:22 +0800 Subject: [PATCH 1/3] WIP: Refactor asr_datamodule. --- .../ASR/conformer_ctc}/asr_datamodule.py | 160 ++++++-- egs/librispeech/ASR/conformer_ctc/decode.py | 19 +- egs/librispeech/ASR/conformer_ctc/train.py | 8 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 362 ++++++++++++++++++ egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 14 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 11 +- icefall/dataset/librispeech.py | 68 ---- 7 files changed, 528 insertions(+), 114 deletions(-) rename {icefall/dataset => egs/librispeech/ASR/conformer_ctc}/asr_datamodule.py (63%) create mode 100644 egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py delete mode 100644 icefall/dataset/librispeech.py diff --git a/icefall/dataset/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py similarity index 63% rename from icefall/dataset/asr_datamodule.py rename to egs/librispeech/ASR/conformer_ctc/asr_datamodule.py index 73eef9c31f..b3bd823ff7 100644 --- a/icefall/dataset/asr_datamodule.py +++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py @@ -1,17 +1,20 @@ import argparse import logging +from functools import lru_cache from pathlib import Path from typing import List, Union -from lhotse import Fbank, FbankConfig, load_manifest +from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, CutMix, K2SpeechRecognitionDataset, + PrecomputedFeatures, SingleCutSampler, SpecAugment, ) +from lhotse.dataset.dataloading import LhotseDataLoader from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader @@ -19,7 +22,7 @@ from icefall.utils import str2bool -class AsrDataModule(DataModule): +class LibriSpeechAsrDataModule(DataModule): """ DataModule for K2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -47,6 +50,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "effective batch sizes, sampling strategies, applied data " "augmentations, etc.", ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) group.add_argument( "--feature-dir", type=Path, @@ -104,6 +114,38 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "extraction. Will drop existing precomputed feature manifests " "if available.", ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--num-workers-inner", + type=int, + default=8, + help="The number of sub-workers (replicated for each of " + "training dataloader workers) that parallelize " + "the I/O to collect each batch.", + ) def train_dataloaders(self) -> DataLoader: logging.info("About to get train cuts") @@ -138,9 +180,9 @@ def train_dataloaders(self) -> DataLoader: ] train = K2SpeechRecognitionDataset( - cuts_train, cut_transforms=transforms, input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: @@ -154,14 +196,14 @@ def train_dataloaders(self) -> DataLoader: # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. - cuts_train = cuts_train.drop_features() train = K2SpeechRecognitionDataset( - cuts=cuts_train, cut_transforms=transforms, input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) + Fbank(FbankConfig(num_mel_bins=80)), + num_workers=self.args.num_workers_inner, ), input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.bucketing_sampler: @@ -169,9 +211,9 @@ def train_dataloaders(self) -> DataLoader: train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - bucket_method='equal_duration', + bucket_method="equal_duration", drop_last=True, ) else: @@ -179,45 +221,73 @@ def train_dataloaders(self) -> DataLoader: train_sampler = SingleCutSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") - train_dl = DataLoader( + + # train_dl = DataLoader( + # train, + # sampler=train_sampler, + # batch_size=None, + # num_workers=2, + # persistent_workers=False, + # ) + + train_dl = LhotseDataLoader( train, sampler=train_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, + num_workers=self.args.num_workers, + prefetch_factor=5, ) + return train_dl def valid_dataloaders(self) -> DataLoader: logging.info("About to get dev cuts") cuts_valid = self.valid_cuts() + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - cuts_valid = cuts_valid.drop_features() validate = K2SpeechRecognitionDataset( - cuts_valid.drop_features(), + cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) ), + return_cuts=self.args.return_cuts, ) else: - validate = K2SpeechRecognitionDataset(cuts_valid) + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) valid_sampler = SingleCutSampler( cuts_valid, max_duration=self.args.max_duration, + shuffle=False, ) logging.info("About to create dev dataloader") - valid_dl = DataLoader( + # valid_dl = DataLoader( + # validate, + # sampler=valid_sampler, + # batch_size=None, + # num_workers=2, + # persistent_workers=False, + # ) + + valid_dl = LhotseDataLoader( validate, sampler=valid_sampler, - batch_size=None, num_workers=2, - persistent_workers=False, ) + return valid_dl def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: @@ -230,21 +300,63 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: for cuts_test in cuts: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - cuts_test, input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) + Fbank(FbankConfig(num_mel_bins=80), num_workers=4) + if self.args.on_the_fly_feats + else PrecomputedFeatures() ), + return_cuts=self.args.return_cuts, ) sampler = SingleCutSampler( cuts_test, max_duration=self.args.max_duration ) logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 - ) + # test_dl = DataLoader( + # test, batch_size=None, sampler=sampler, num_workers=1 + # ) + test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2) test_loaders.append(test_dl) if is_list: return test_loaders else: return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json.gz" + ) + if self.args.full_libri: + cuts_train = ( + cuts_train + + load_manifest( + self.args.feature_dir / "cuts_train-clean-360.json.gz" + ) + + load_manifest( + self.args.feature_dir / "cuts_train-other-500.json.gz" + ) + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev-clean.json.gz" + ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test-clean", "test-other"] + cuts = [] + for test_set in test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index c17a8b284a..77f35253eb 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -13,11 +13,11 @@ import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, @@ -222,7 +222,7 @@ def decode_one_batch( use_double_scores=params.use_double_scores, scale=params.lattice_score_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] @@ -317,7 +317,11 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_batches = len(dl) + + try: + num_batches = len(dl) + except TypeError: + num_batches = None results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -346,10 +350,13 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + if num_batches is not None: + batch_str = f"{batch_idx}/{num_batches}" + else: + batch_str = f"{batch_idx}" + logging.info( - f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is " - f"{num_cuts}" - f"batch {batch_idx}, cuts processed until now is {num_cuts}" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index d3ea8efb01..d17ee61642 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -13,10 +13,10 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_value_ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam @@ -24,7 +24,6 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.utils import ( @@ -61,9 +60,6 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. return parser @@ -463,7 +459,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_value_(model.parameters(), 5.0) + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py new file mode 100644 index 0000000000..b3bd823ff7 --- /dev/null +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -0,0 +1,362 @@ +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List, Union + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.dataloading import LhotseDataLoader +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class LibriSpeechAsrDataModule(DataModule): + """ + DataModule for K2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=500.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=True, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--num-workers-inner", + type=int, + default=8, + help="The number of sub-workers (replicated for each of " + "training dataloader workers) that parallelize " + "the I/O to collect each batch.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") + + logging.info("About to create train dataset") + transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [ + SpecAugment( + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ] + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)), + num_workers=self.args.num_workers_inner, + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + # train_dl = DataLoader( + # train, + # sampler=train_sampler, + # batch_size=None, + # num_workers=2, + # persistent_workers=False, + # ) + + train_dl = LhotseDataLoader( + train, + sampler=train_sampler, + num_workers=self.args.num_workers, + prefetch_factor=5, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get dev cuts") + cuts_valid = self.valid_cuts() + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = SingleCutSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + # valid_dl = DataLoader( + # validate, + # sampler=valid_sampler, + # batch_size=None, + # num_workers=2, + # persistent_workers=False, + # ) + + valid_dl = LhotseDataLoader( + validate, + sampler=valid_sampler, + num_workers=2, + ) + + return valid_dl + + def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: + cuts = self.test_cuts() + is_list = isinstance(cuts, list) + test_loaders = [] + if not is_list: + cuts = [cuts] + + for cuts_test in cuts: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80), num_workers=4) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), + return_cuts=self.args.return_cuts, + ) + sampler = SingleCutSampler( + cuts_test, max_duration=self.args.max_duration + ) + logging.debug("About to create test dataloader") + # test_dl = DataLoader( + # test, batch_size=None, sampler=sampler, num_workers=1 + # ) + test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2) + test_loaders.append(test_dl) + + if is_list: + return test_loaders + else: + return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json.gz" + ) + if self.args.full_libri: + cuts_train = ( + cuts_train + + load_manifest( + self.args.feature_dir / "cuts_train-clean-360.json.gz" + ) + + load_manifest( + self.args.feature_dir / "cuts_train-other-500.json.gz" + ) + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev-clean.json.gz" + ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test-clean", "test-other"] + cuts = [] + for test_set in test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 9a1aad5790..2aca804fae 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -10,10 +10,10 @@ import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from model import TdnnLstm from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, @@ -237,6 +237,11 @@ def decode_dataset( num_cuts = 0 + try: + num_batches = len(dl) + except TypeError: + num_batches = None + results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] @@ -262,8 +267,13 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + if num_batches is not None: + batch_str = f"{batch_idx}/{num_batches}" + else: + batch_str = f"{batch_idx}" + logging.info( - f"batch {batch_idx}, cuts processed until now is {num_cuts}" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index dbb9f64ecf..4adb988a07 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 -# This is just at the very beginning ... - import argparse import logging from pathlib import Path @@ -14,16 +12,16 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from asr_datamodule import LibriSpeechAsrDataModule from lhotse.utils import fix_random_seed from model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_value_ +from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon @@ -61,9 +59,6 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. return parser @@ -406,7 +401,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_value_(model.parameters(), 5.0) + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() diff --git a/icefall/dataset/librispeech.py b/icefall/dataset/librispeech.py deleted file mode 100644 index 5c18041ed8..0000000000 --- a/icefall/dataset/librispeech.py +++ /dev/null @@ -1,68 +0,0 @@ -import argparse -import logging -from functools import lru_cache -from typing import List - -from lhotse import CutSet, load_manifest - -from icefall.dataset.asr_datamodule import AsrDataModule -from icefall.utils import str2bool - - -class LibriSpeechAsrDataModule(AsrDataModule): - """ - LibriSpeech ASR data module. Can be used for 100h subset - (``--full-libri false``) or full 960h set. - The train and valid cuts for standard Libri splits are - concatenated into a single CutSet/DataLoader. - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group(title="LibriSpeech specific options") - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech.", - ) - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" - ) - if self.args.full_libri: - cuts_train = ( - cuts_train - + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" - ) - + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" - ) - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev-clean.json.gz" - ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test-clean", "test-other"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts From 8a8bf67faf201469b71ff6a87b8dde6aee39f910 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 21 Aug 2021 08:15:40 +0800 Subject: [PATCH 2/3] Fixes after review. --- .../ASR/conformer_ctc/asr_datamodule.py | 363 +----------------- egs/librispeech/ASR/conformer_ctc/decode.py | 7 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 48 +-- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 7 +- icefall/checkpoint.py | 2 +- 5 files changed, 22 insertions(+), 405 deletions(-) mode change 100644 => 120000 egs/librispeech/ASR/conformer_ctc/asr_datamodule.py diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py deleted file mode 100644 index b3bd823ff7..0000000000 --- a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1,362 +0,0 @@ -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import List, Union - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest -from lhotse.dataset import ( - BucketingSampler, - CutConcatenate, - CutMix, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, - SpecAugment, -) -from lhotse.dataset.dataloading import LhotseDataLoader -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class LibriSpeechAsrDataModule(DataModule): - """ - DataModule for K2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=500.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=False, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=True, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--num-workers-inner", - type=int, - default=8, - help="The number of sub-workers (replicated for each of " - "training dataloader workers) that parallelize " - "the I/O to collect each batch.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") - - logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [ - SpecAugment( - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ] - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)), - num_workers=self.args.num_workers_inner, - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using BucketingSampler.") - train_sampler = BucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - bucket_method="equal_duration", - drop_last=True, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # train_dl = DataLoader( - # train, - # sampler=train_sampler, - # batch_size=None, - # num_workers=2, - # persistent_workers=False, - # ) - - train_dl = LhotseDataLoader( - train, - sampler=train_sampler, - num_workers=self.args.num_workers, - prefetch_factor=5, - ) - - return train_dl - - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get dev cuts") - cuts_valid = self.valid_cuts() - - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = SingleCutSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - # valid_dl = DataLoader( - # validate, - # sampler=valid_sampler, - # batch_size=None, - # num_workers=2, - # persistent_workers=False, - # ) - - valid_dl = LhotseDataLoader( - validate, - sampler=valid_sampler, - num_workers=2, - ) - - return valid_dl - - def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: - cuts = self.test_cuts() - is_list = isinstance(cuts, list) - test_loaders = [] - if not is_list: - cuts = [cuts] - - for cuts_test in cuts: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80), num_workers=4) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), - return_cuts=self.args.return_cuts, - ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration - ) - logging.debug("About to create test dataloader") - # test_dl = DataLoader( - # test, batch_size=None, sampler=sampler, num_workers=1 - # ) - test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2) - test_loaders.append(test_dl) - - if is_list: - return test_loaders - else: - return test_loaders[0] - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" - ) - if self.args.full_libri: - cuts_train = ( - cuts_train - + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" - ) - + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" - ) - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev-clean.json.gz" - ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test-clean", "test-other"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py new file mode 120000 index 0000000000..fa1b8cca3c --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 77f35253eb..c540b1ea1e 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -321,7 +321,7 @@ def decode_dataset( try: num_batches = len(dl) except TypeError: - num_batches = None + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -350,10 +350,7 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: - if num_batches is not None: - batch_str = f"{batch_idx}/{num_batches}" - else: - batch_str = f"{batch_idx}" + batch_str = f"{batch_idx}/{num_batches}" logging.info( f"batch {batch_str}, cuts processed until now is {num_cuts}" diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index b3bd823ff7..748b9541cc 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -14,7 +14,6 @@ SingleCutSampler, SpecAugment, ) -from lhotse.dataset.dataloading import LhotseDataLoader from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader @@ -87,7 +86,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--concatenate-cuts", type=str2bool, - default=True, + default=False, help="When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.", ) @@ -199,8 +198,7 @@ def train_dataloaders(self) -> DataLoader: train = K2SpeechRecognitionDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)), - num_workers=self.args.num_workers_inner, + Fbank(FbankConfig(num_mel_bins=80)) ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, @@ -225,19 +223,12 @@ def train_dataloaders(self) -> DataLoader: ) logging.info("About to create train dataloader") - # train_dl = DataLoader( - # train, - # sampler=train_sampler, - # batch_size=None, - # num_workers=2, - # persistent_workers=False, - # ) - - train_dl = LhotseDataLoader( + train_dl = DataLoader( train, sampler=train_sampler, - num_workers=self.args.num_workers, - prefetch_factor=5, + batch_size=None, + num_workers=2, + persistent_workers=False, ) return train_dl @@ -274,18 +265,12 @@ def valid_dataloaders(self) -> DataLoader: shuffle=False, ) logging.info("About to create dev dataloader") - # valid_dl = DataLoader( - # validate, - # sampler=valid_sampler, - # batch_size=None, - # num_workers=2, - # persistent_workers=False, - # ) - - valid_dl = LhotseDataLoader( + valid_dl = DataLoader( validate, sampler=valid_sampler, + batch_size=None, num_workers=2, + persistent_workers=False, ) return valid_dl @@ -301,20 +286,19 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80), num_workers=4) - if self.args.on_the_fly_feats - else PrecomputedFeatures() - ), + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) sampler = SingleCutSampler( cuts_test, max_duration=self.args.max_duration ) logging.debug("About to create test dataloader") - # test_dl = DataLoader( - # test, batch_size=None, sampler=sampler, num_workers=1 - # ) - test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2) + test_dl = DataLoader( + test, batch_size=None, sampler=sampler, num_workers=1 + ) test_loaders.append(test_dl) if is_list: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 2aca804fae..72f39ef408 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -240,7 +240,7 @@ def decode_dataset( try: num_batches = len(dl) except TypeError: - num_batches = None + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -267,10 +267,7 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: - if num_batches is not None: - batch_str = f"{batch_idx}/{num_batches}" - else: - batch_str = f"{batch_idx}" + batch_str = f"{batch_idx}/{num_batches}" logging.info( f"batch {batch_str}, cuts processed until now is {num_cuts}" diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index e45df4fe47..a64ecfcf67 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -91,7 +91,7 @@ def load_checkpoint( checkpoint.pop("model") def load(name, obj): - s = checkpoint[name] + s = checkpoint.get(name, None) if obj and s: obj.load_state_dict(s) checkpoint.pop(name) From ed16585c58ea8413d5d26dc92a9bb4b9dc5d79f0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 21 Aug 2021 08:25:34 +0800 Subject: [PATCH 3/3] Minor fixes. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 748b9541cc..8d8c7a3667 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -137,15 +137,6 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "collect the batches.", ) - group.add_argument( - "--num-workers-inner", - type=int, - default=8, - help="The number of sub-workers (replicated for each of " - "training dataloader workers) that parallelize " - "the I/O to collect each batch.", - ) - def train_dataloaders(self) -> DataLoader: logging.info("About to get train cuts") cuts_train = self.train_cuts() @@ -227,7 +218,7 @@ def train_dataloaders(self) -> DataLoader: train, sampler=train_sampler, batch_size=None, - num_workers=2, + num_workers=self.args.num_workers, persistent_workers=False, )