Skip to content

Commit

Permalink
Support computing the mean symbol delay with word-level alignments
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozengwei committed Oct 6, 2022
1 parent c0379c6 commit 6494e0f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 25 deletions.
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/local/add_alignment_librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

"""
This file adds alignments from https://github.com/CorentinJ/librispeech-alignments # noqa
to the existing fbank features dir data/fbank.
to the existing fbank features dir (e.g., data/fbank)
and saves the cuts to a new dir (e.g., data/fbank_ali).
"""

import argparse
Expand Down
56 changes: 48 additions & 8 deletions egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,9 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float]]]]:
) -> Dict[
str, List[Tuple[str, List[str], List[str], List[float], List[float]]]
]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -546,6 +548,18 @@ def decode_dataset(
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

timestamps_ref = []
for cut in batch["supervisions"]["cut"]:
for s in cut.supervisions:
time = []
if s.alignment is not None and "word" in s.alignment:
time = [
aliword.start
for aliword in s.alignment["word"]
if aliword.symbol != ""
]
timestamps_ref.append(time)

hyps_dict = decode_one_batch(
params=params,
model=model,
Expand All @@ -555,14 +569,18 @@ def decode_dataset(
batch=batch,
)

for name, (hyps, timestamps) in hyps_dict.items():
for name, (hyps, timestamps_hyp) in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts) and len(timestamps) == len(texts)
for cut_id, hyp_words, ref_text, time in zip(
cut_ids, hyps, texts, timestamps
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
timestamps_ref
)
for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words, time))
this_batch.append(
(cut_id, ref_words, hyp_words, time_ref, time_hyp)
)

results[name].extend(this_batch)

Expand All @@ -581,10 +599,12 @@ def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[
str, List[Tuple[List[str], List[str], List[str], List[float]]]
str,
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
],
):
test_set_wers = dict()
test_set_delays = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
Expand All @@ -599,10 +619,11 @@ def save_results(
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats_with_timestamps(
wer, delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
test_set_delays[key] = delay

logging.info("Wrote detailed error stats to {}".format(errs_filename))

Expand All @@ -616,13 +637,32 @@ def save_results(
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)

test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1])
delays_info = (
params.res_dir
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(delays_info, "w") as f:
print("settings\tsymbol-delay", file=f)
for key, val in test_set_delays:
print("{}\t{}".format(key, val), file=f)

s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)

s = "\nFor {}, symbol-delay of different settings are:\n".format(
test_set_name
)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)


@torch.no_grad()
def main():
Expand Down
71 changes: 55 additions & 16 deletions icefall/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,9 @@ def store_transcripts(

def store_transcripts_and_timestamps(
filename: Pathlike,
texts: Iterable[Tuple[str, List[str], List[str], List[float]]],
texts: Iterable[Tuple[str, List[str], List[str], List[float], List[float]]],
) -> None:
"""Save predicted results with timestamps and reference transcripts
"""Save predicted results and reference transcripts as well as their timestamps
to a file.
Args:
Expand All @@ -450,10 +450,14 @@ def store_transcripts_and_timestamps(
Return None.
"""
with open(filename, "w") as f:
for cut_id, ref, hyp, timestamp in texts:
for cut_id, ref, hyp, time_ref, time_hyp in texts:
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
print(f"{cut_id}:\ttimestamp={timestamp}", file=f)
if len(time_ref) > 0:
s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)


def write_error_stats(
Expand Down Expand Up @@ -626,11 +630,11 @@ def write_error_stats(
def write_error_stats_with_timestamps(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, List[str], List[str], List[float]]],
results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
enable_log: bool = True,
) -> float:
"""Write statistics based on predicted results with timestamps
and reference transcripts.
) -> Tuple[float, float]:
"""Write statistics based on predicted results and reference transcripts
as well as their timestamps.
It will write the following to the given file:
Expand Down Expand Up @@ -661,8 +665,9 @@ def write_error_stats_with_timestamps(
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
Return total word error rate and mean delay.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
Expand All @@ -673,35 +678,68 @@ def write_error_stats_with_timestamps(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for cut_id, ref, hyp, timestamp in results:
# Compute mean alignment delay on the correct words
all_delay = []
for cut_id, ref, hyp, time_ref, time_hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
has_time_ref = len(time_ref) > 0
if has_time_ref:
# pointer to timestamp_hyp
p_hyp = 0
# pointer to timestamp_ref
p_ref = 0
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
if has_time_ref:
p_hyp += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
if has_time_ref:
p_ref += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
if has_time_ref:
p_hyp += 1
p_ref += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _, _ in results])
if has_time_ref:
all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
p_hyp += 1
p_ref += 1
if has_time_ref:
assert p_hyp == len(hyp), (p_hyp, len(hyp))
assert p_ref == len(ref), (p_ref, len(ref))

ref_len = sum([len(r) for _, r, _, _, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

mean_delay = "inf"
sum_delay = sum(all_delay)
num_delay = len(all_delay)
if num_delay > 0:
mean_delay = "%.3f" % (sum_delay / num_delay)

if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
logging.info(
f"[{test_set_name}] %symbol-delay {mean_delay} "
f"computed on {num_delay} words"
)

print(f"%WER = {tot_err_rate}", file=f)
print(
Expand All @@ -718,7 +756,7 @@ def write_error_stats_with_timestamps(

print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp, timestamp in results:
for cut_id, ref, hyp, _, _ in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
Expand Down Expand Up @@ -788,7 +826,7 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins

print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
return float(tot_err_rate), float(mean_delay)


class MetricsTracker(collections.defaultdict):
Expand Down Expand Up @@ -1292,9 +1330,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
start_token = b"\xe2\x96\x81".decode() # '_'
assert len(tokens) == len(timestamp)
ans = []
for token, start_time in zip(tokens, timestamp):
if token.startswith(start_token):
ans.append(start_time)
for i in range(len(tokens)):
if i == 0 or tokens[i].startswith(start_token):
ans.append(timestamp[i])
return ans


Expand Down Expand Up @@ -1362,6 +1400,7 @@ def parse_hyp_and_timestamp(
res.timestamps[i], subsampling_factor, frame_shift_ms
)
time = parse_timestamp(tokens, time)
assert len(time) == len(words), (tokens, words)

hyps.append(words)
timestamps.append(time)
Expand Down

0 comments on commit 6494e0f

Please sign in to comment.