Refactor asr_datamodule. #15

Merged · 3 commits · Aug 21, 2021
Changes from all commits
1 change: 1 addition & 0 deletions egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
16 changes: 10 additions & 6 deletions egs/librispeech/ASR/conformer_ctc/decode.py
@@ -13,11 +13,11 @@
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
@@ -222,7 +222,7 @@ def decode_one_batch(
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa

hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -317,7 +317,11 @@ def decode_dataset(
results = []

num_cuts = 0
tot_num_batches = len(dl)

try:
num_batches = len(dl)
except TypeError:
num_batches = "?"

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@@ -346,10 +350,10 @@
num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(
f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
f"{num_cuts}"
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results

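A side note on the logging change above (an illustration, not part of the diff): len(dl) is wrapped in try/except because a DataLoader built on an iterable-style dataset, or on a sampler without __len__, raises TypeError when its length is requested. A minimal sketch of the pattern, with an illustrative helper name:

from torch.utils.data import DataLoader


def describe_batch_count(dl: DataLoader) -> str:
    """Return the number of batches as a string, or '?' if unknown."""
    try:
        return str(len(dl))
    except TypeError:
        # Iterable datasets / dynamic samplers have no predefined length,
        # so log messages fall back to "batch 100/?".
        return "?"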
8 changes: 2 additions & 6 deletions egs/librispeech/ASR/conformer_ctc/train.py
@@ -13,18 +13,17 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -61,9 +60,6 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)

# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser


@@ -463,7 +459,7 @@ def train_one_epoch(

optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()

loss_cpu = loss.detach().cpu().item()
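A note on the clipping change above (illustration only, not part of the diff): clip_grad_value_(params, 5.0) clamps every gradient element to [-5, 5] independently, while clip_grad_norm_(params, 5.0, 2.0) rescales the whole gradient only when its global L2 norm exceeds 5.0, which preserves the gradient direction. A small sketch of the difference:

import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([3.0, 4.0, 12.0])  # global L2 norm = 13

# Old behaviour: element-wise clamp to [-5, 5] -> [3, 4, 5]; direction changes.
# clip_grad_value_([w], 5.0)

# New behaviour: rescale by 5/13 -> roughly [1.15, 1.54, 4.62]; direction kept.
clip_grad_norm_([w], max_norm=5.0, norm_type=2.0)
print(w.grad)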
@@ -1,14 +1,16 @@
import argparse
Collaborator

Not sure if it makes sense -- but maybe it's sufficient to have a single copy of this script one directory level up, and any recipe that requires non-standard processing would make its own copy at the "current" directory level?

Collaborator Author

How about putting a symlink in the other model directories that points to this file?
I was thinking that each model should be as self-contained as possible.
If someone wants to modify this file, they can replace the symlink with a copy of it.

Collaborator

Yeah, that makes sense to me.
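For illustration only (the paths are assumptions, not pinned down by this PR), the symlink arrangement discussed above could look like this: one recipe directory keeps the real file and the others carry a relative link to it.

from pathlib import Path

# Hypothetical paths; the PR itself does not fix the exact layout.
real_file = Path("tdnn_lstm_ctc/asr_datamodule.py")
link = Path("conformer_ctc/asr_datamodule.py")

# Each recipe keeps a local name that points at the shared implementation.
link.symlink_to(Path("..") / real_file)

# A recipe needing non-standard processing drops the link and keeps a real copy:
# link.unlink()
# link.write_text(real_file.read_text())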

import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union

from lhotse import Fbank, FbankConfig, load_manifest
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
@@ -19,7 +21,7 @@
from icefall.utils import str2bool


class AsrDataModule(DataModule):
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for K2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -47,6 +49,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
type=Path,
@@ -77,7 +86,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=True,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
@@ -104,6 +113,29 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)

group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)

def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
@@ -138,9 +170,9 @@ def train_dataloaders(self) -> DataLoader:
]

train = K2SpeechRecognitionDataset(
cuts_train,
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

if self.args.on_the_fly_feats:
@@ -154,61 +186,74 @@ def train_dataloaders(self) -> DataLoader:
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
cuts_train = cuts_train.drop_features()
train = K2SpeechRecognitionDataset(
cuts=cuts_train,
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method='equal_duration',
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")

train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=2,
num_workers=self.args.num_workers,
persistent_workers=False,
)

return train_dl

def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()

transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms

logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
cuts_valid = cuts_valid.drop_features()
validate = K2SpeechRecognitionDataset(
cuts_valid.drop_features(),
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(cuts_valid)
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@@ -218,6 +263,7 @@ def valid_dataloaders(self) -> DataLoader:
num_workers=2,
persistent_workers=False,
)

return valid_dl

def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
@@ -230,10 +276,12 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
cuts_test,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
@@ -248,3 +296,42 @@ def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
return test_loaders
else:
return test_loaders[0]

@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train

@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid

@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts
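Taken together, the refactor means each recipe imports the data module locally and lets it register its own CLI options. A rough usage sketch based on the methods visible in this diff; the constructor taking the parsed args is an assumption from the DataModule base class rather than something shown here:

import argparse

from asr_datamodule import LibriSpeechAsrDataModule

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)  # adds --full-libri, --shuffle, ...
args = parser.parse_args()

librispeech = LibriSpeechAsrDataModule(args)  # assumed constructor signature

train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()

for batch in train_dl:
    texts = batch["supervisions"]["text"]
    # With --return-cuts=True, each batch also carries the cuts it was built from:
    cuts = batch["supervisions"]["cut"]
    break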
11 changes: 9 additions & 2 deletions egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -10,10 +10,10 @@
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from model import TdnnLstm

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
@@ -237,6 +237,11 @@ def decode_dataset(

num_cuts = 0

try:
num_batches = len(dl)
except TypeError:
num_batches = "?"

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
@@ -262,8 +267,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(
f"batch {batch_idx}, cuts processed until now is {num_cuts}"
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
