diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 1091f4348..930affb18 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -41,6 +41,7 @@ from .fsa_algo import levenshtein_alignment from .fsa_algo import levenshtein_graph from .fsa_algo import linear_fsa +from .fsa_algo import linear_fsa_with_self_loops from .fsa_algo import linear_fst from .fsa_algo import prune_on_arc_post from .fsa_algo import random_paths diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 8e44f2e93..34ff9ca06 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -68,6 +68,37 @@ def linear_fsa(labels: Union[List[int], List[List[int]], k2.RaggedTensor], return fsa +def linear_fsa_with_self_loops(fsas: k2.Fsa): + '''Create a linear FSA with epsilon self-loops by first removing epsilon + transitions from the input linear FSA. + + Args: + fsas: + An FSA or an FsaVec. It MUST be a linear FSA or a vector of linear FSAs. + Returns: + Return an FSA or FsaVec, where each FSA contains epsilon self-loops but + contains no epsilon transitions for arcs that are not self-loops. 
+ ''' + if len(fsas.shape) == 2: + # A single FSA + device = fsas.device + shape0 = _k2.RaggedShape.regular_ragged_shape(dim0=1, + dim1=fsas.shape[0]) + shape = shape0.to(device).compose(fsas.arcs.shape()) + else: + shape = fsas.arcs.shape() + + shape = shape.remove_axis(1) # remove the state axis + + labels = k2.RaggedTensor(shape, fsas.labels.contiguous()) + labels = labels.remove_values_leq(0) + ans = add_epsilon_self_loops(linear_fsa(labels)) + + if len(fsas.shape) == 2: + ans = ans[0] + return ans + + def linear_fst(labels: Union[List[int], List[List[int]]], aux_labels: Union[List[int], List[List[int]]]) -> Fsa: '''Construct a linear FST from labels and its corresponding @@ -1191,16 +1222,18 @@ def levenshtein_alignment( hyps.rename_tensor_attribute_("aux_labels", "hyp_labels") - lattice = k2.intersect_device( - refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref) + lattice = k2.intersect_device(refs, + hyps, + b_to_a_map=hyp_to_ref_map, + sorted_match_a=sorted_match_ref) lattice = k2.remove_epsilon_self_loops(lattice) alignment = k2.shortest_path(lattice, use_double_scores=True).invert_() alignment.rename_tensor_attribute_("labels", "ref_labels") alignment.rename_tensor_attribute_("aux_labels", "labels") - alignment.scores -= getattr( - alignment, "__ins_del_score_offset_internal_attr_") + alignment.scores -= getattr(alignment, + "__ins_del_score_offset_internal_attr_") return alignment @@ -1222,5 +1255,6 @@ def union(fsas: Fsa) -> Fsa: need_arc_map = True ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map) - out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map) + out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, + arc_map) return out_fsa diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 57525979e..cde9a5382 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -52,6 +52,7 @@ set(py_test_files levenshtein_alignment_test.py 
levenshtein_graph_test.py linear_fsa_test.py + linear_fsa_with_self_loops_test.py linear_fst_test.py multi_gpu_test.py mutual_information_test.py diff --git a/k2/python/tests/linear_fsa_with_self_loops_test.py b/k2/python/tests/linear_fsa_with_self_loops_test.py new file mode 100644 index 000000000..1e331bbbc --- /dev/null +++ b/k2/python/tests/linear_fsa_with_self_loops_test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# To run this single test, use +# +#  ctest --verbose -R linear_fsa_with_self_loops_test_py + +import torch +import k2 +import unittest + + +class TestLinearFsa(unittest.TestCase): + +    @classmethod +    def setUpClass(cls): +        cls.devices = [torch.device('cpu')] +        if torch.cuda.is_available() and k2.with_cuda: +            cls.devices.append(torch.device('cuda', 0)) +            if torch.cuda.device_count() > 1: +                torch.cuda.set_device(1) +                cls.devices.append(torch.device('cuda', 1)) + +    def test_single_fsa(self): +        for device in self.devices: +            labels = [2, 0, 0, 0, 5, 8] +            src = k2.linear_fsa(labels, device) +            dst = k2.linear_fsa_with_self_loops(src) +            assert src.device == dst.device +            expected_labels = [0, 2, 0, 5, 0, 8, 0, -1] +            assert dst.labels.tolist() == expected_labels + +    def test_multiple_fsa(self): +        for device in self.devices: +            labels = [[2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0], [1, 2], +                      [0, 0, 0, 3, 0, 2]] +            src = k2.linear_fsa(labels, device) +            dst = k2.linear_fsa_with_self_loops(src) +            assert src.device == dst.device +            expected_labels0 = [0, 2, 0, 5, 0, 8, 0, -1] +            expected_labels1 = [0, 1, 0, 2, 0, -1] +            expected_labels2 = [0, 3, 0, 2, 0, -1] +            expected_labels = expected_labels0 + expected_labels1 + expected_labels2 +            assert dst.labels.tolist() == expected_labels + + +if __name__ == '__main__': +    unittest.main()