Skip to content

Commit

Permalink
fix type hints for decode.py (#623)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 17, 2022
1 parent a66e74b commit d1f16a0
Show file tree
Hide file tree
Showing 43 changed files with 93 additions and 93 deletions.
4 changes: 2 additions & 2 deletions egs/aishell/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def decode_dataset(
lexicon: Lexicon,
sos_id: int,
eos_id: int,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -410,7 +410,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/conformer_mmi/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def decode_dataset(
lexicon: Lexicon,
sos_id: int,
eos_id: int,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -422,7 +422,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/pruned_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -396,7 +396,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/pruned_transducer_stateless3/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -410,7 +410,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/tdnn_lstm_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def decode_dataset(
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -274,7 +274,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/transducer_stateless/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -328,7 +328,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/transducer_stateless_modified-2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -374,7 +374,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell/ASR/transducer_stateless_modified/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -378,7 +378,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def decode_dataset(
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -547,7 +547,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def decode_dataset(
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -410,7 +410,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def decode_dataset(
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -399,7 +399,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
12 changes: 6 additions & 6 deletions egs/gigaspeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ def get_params() -> AttributeDict:


def post_processing(
results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for ref, hyp in results:
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((new_ref, new_hyp))
new_results.append((key, new_ref, new_hyp))
return new_results


Expand Down Expand Up @@ -408,7 +408,7 @@ def decode_dataset(
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -502,7 +502,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
Expand Down
12 changes: 6 additions & 6 deletions egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,13 @@ def get_parser():


def post_processing(
results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for ref, hyp in results:
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((new_ref, new_hyp))
new_results.append((key, new_ref, new_hyp))
return new_results


Expand Down Expand Up @@ -340,7 +340,7 @@ def decode_dataset(
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -407,7 +407,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def decode_dataset(
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -577,7 +577,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]],
):
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/conformer_ctc2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def decode_dataset(
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -684,7 +684,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method in ("attention-decoder", "rnn-lm"):
# Set it to False since there are too many logs.
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/conformer_mmi/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def decode_dataset(
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -487,7 +487,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def decode_dataset(
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -436,7 +436,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def decode_dataset(
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -436,7 +436,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/lstm_transducer_stateless/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -570,7 +570,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -571,7 +571,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
Expand Down
Loading

0 comments on commit d1f16a0

Please sign in to comment.