
Add recipe for the yes_no dataset. #16

Merged
merged 11 commits on Aug 23, 2021
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -23,7 +23,7 @@

 class LibriSpeechAsrDataModule(DataModule):
     """
-    DataModule for K2 ASR experiments.
+    DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
     but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
     and test-other).
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -333,7 +333,7 @@ def main():
     logging.info(f"device: {device}")

     HLG = k2.Fsa.from_dict(
-        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
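(The hunk above makes the HLG location configurable through params.lang_dir instead of a hard-coded path. A hypothetical sketch of the corresponding flag; the actual argument wiring in decode.py is outside this diff, though compile_hlg.py below adds the same option:)

import argparse

# Hypothetical: how a --lang-dir flag could populate params.lang_dir,
# defaulting to the previously hard-coded directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lang-dir",
    type=str,
    default="data/lang_phone",
    help="Directory where HLG.pt is stored.",
)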
134 changes: 134 additions & 0 deletions egs/yesno/ASR/local/compile_hlg.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3

"""
This script takes as input lang_dir and generates HLG from

    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt

    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_HLG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    logging.info("Loading G.fst.txt")
    with open("data/lm/G.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

Collaborator: Do we need to make the following k2 operations run on GPU if there are devices available?

Collaborator (author): For the yesno dataset, the graphs are tiny. It's ok to run them on CPU.

For the librispeech dataset, I think it's worthwhile to have some benchmarks. If GPU is faster, we can switch to it.
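(For illustration, a minimal sketch of what that could look like; hypothetical and not part of the merged script. k2.Fsa supports .to(), as the decode.py hunk above shows.)

# Hypothetical reviewer sketch, not in the merged code: place L and G on a
# GPU when one is available, so that the CUDA-capable k2 operations below
# run there instead of on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
L = L.to(device)
G = G.to(device)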

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0

    assert isinstance(LG.aux_labels, k2.RaggedInt)
    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
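For reference, the graph saved here round-trips through k2's dict serialization; decode.py (see the hunk at the top of this PR) loads it back essentially as follows, assuming lang_dir is data/lang_phone:

import k2
import torch

# Load the HLG graph that compile_hlg.py saved above.
HLG = k2.Fsa.from_dict(
    torch.load("data/lang_phone/HLG.pt", map_location="cpu")
)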
80 changes: 80 additions & 0 deletions egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

"""
This file computes fbank features of the yesno dataset.
Its looks for manifests in the directory data/manifests.
Collaborator: Its -> It ?

The generated fbank features are saved in data/fbank.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU
# and slows things down. Do this outside of main() in case it needs to take
# effect even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_yesno():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    # This dataset is rather small, so we use only one job
    num_jobs = min(1, os.cpu_count())
    num_mel_bins = 23

    dataset_parts = (
        "train",
        "test",
    )
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts, output_dir=src_dir
    )
    assert manifests is not None

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            if (output_dir / f"cuts_{partition}.json.gz").is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
                cut_set = (
                    cut_set
                    + cut_set.perturb_speed(0.9)
                    + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats_{partition}",
                # the dataset is tiny, so a single job suffices
                # even when an executor is available
                num_jobs=num_jobs if ex is None else 1,
                executor=ex,
                storage_type=LilcomHdf5Writer,
            )
            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_yesno()
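For context, a downstream data module would read these cuts back with lhotse before building dataloaders; a minimal sketch, assuming the data/fbank layout produced above:

from lhotse import CutSet

# Load the cuts written by compute_fbank_yesno().
cuts_train = CutSet.from_json("data/fbank/cuts_train.json.gz")
cuts_test = CutSet.from_json("data/fbank/cuts_test.json.gz")
print(f"train cuts: {len(cuts_train)}, test cuts: {len(cuts_test)}")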