Skip to content

Commit 0f65420

Browse files
authored
Implement linear_fsa_with_self_loops. (#940)
* Implement linear_fsa_with_self_loops.
1 parent 846c39c commit 0f65420

File tree

4 files changed

+104
-5
lines changed

4 files changed

+104
-5
lines changed

k2/python/k2/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .fsa_algo import levenshtein_alignment
4242
from .fsa_algo import levenshtein_graph
4343
from .fsa_algo import linear_fsa
44+
from .fsa_algo import linear_fsa_with_self_loops
4445
from .fsa_algo import linear_fst
4546
from .fsa_algo import prune_on_arc_post
4647
from .fsa_algo import random_paths

k2/python/k2/fsa_algo.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,37 @@ def linear_fsa(labels: Union[List[int], List[List[int]], k2.RaggedTensor],
6868
return fsa
6969

7070

71+
def linear_fsa_with_self_loops(fsas: k2.Fsa):
72+
'''Create a linear FSA with epsilon self-loops by first removing epsilon
73+
transitions from the input linear FSA.
74+
75+
Args:
76+
fsas:
77+
An FSA or an FsaVec. It MUST be a linear FSA or a vector of linear FSAs.
78+
Returns:
79+
Return an FSA or FsaVec, where each FSA contains epsilon self-loops but
80+
contains no epsilon transitions for arcs that are not self-loops.
81+
'''
82+
if len(fsas.shape) == 2:
83+
# A single FSA
84+
device = fsas.device
85+
shape0 = _k2.RaggedShape.regular_ragged_shape(dim0=1,
86+
dim1=fsas.shape[0])
87+
shape = shape0.to(device).compose(fsas.arcs.shape())
88+
else:
89+
shape = fsas.arcs.shape()
90+
91+
shape = shape.remove_axis(1) # remove the state axis
92+
93+
labels = k2.RaggedTensor(shape, fsas.labels.contiguous())
94+
labels = labels.remove_values_leq(0)
95+
ans = add_epsilon_self_loops(linear_fsa(labels))
96+
97+
if len(fsas.shape) == 2:
98+
ans = ans[0]
99+
return ans
100+
101+
71102
def linear_fst(labels: Union[List[int], List[List[int]]],
72103
aux_labels: Union[List[int], List[List[int]]]) -> Fsa:
73104
'''Construct a linear FST from labels and its corresponding
@@ -1192,16 +1223,18 @@ def levenshtein_alignment(
11921223

11931224
hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")
11941225

1195-
lattice = k2.intersect_device(
1196-
refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref)
1226+
lattice = k2.intersect_device(refs,
1227+
hyps,
1228+
b_to_a_map=hyp_to_ref_map,
1229+
sorted_match_a=sorted_match_ref)
11971230
lattice = k2.remove_epsilon_self_loops(lattice)
11981231

11991232
alignment = k2.shortest_path(lattice, use_double_scores=True).invert_()
12001233
alignment.rename_tensor_attribute_("labels", "ref_labels")
12011234
alignment.rename_tensor_attribute_("aux_labels", "labels")
12021235

1203-
alignment.scores -= getattr(
1204-
alignment, "__ins_del_score_offset_internal_attr_")
1236+
alignment.scores -= getattr(alignment,
1237+
"__ins_del_score_offset_internal_attr_")
12051238

12061239
return alignment
12071240

@@ -1223,5 +1256,6 @@ def union(fsas: Fsa) -> Fsa:
12231256
need_arc_map = True
12241257
ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)
12251258

1226-
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map)
1259+
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc,
1260+
arc_map)
12271261
return out_fsa

k2/python/tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ set(py_test_files
5252
levenshtein_alignment_test.py
5353
levenshtein_graph_test.py
5454
linear_fsa_test.py
55+
linear_fsa_with_self_loops_test.py
5556
linear_fst_test.py
5657
multi_gpu_test.py
5758
mutual_information_test.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python3
2+
#
3+
# Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
4+
#
5+
# See ../../../LICENSE for clarification regarding multiple authors
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
19+
# To run this single test, use
20+
#
21+
# ctest --verbose -R linear_fsa_self_loops_test_py
22+
23+
import torch
24+
import k2
25+
import unittest
26+
27+
28+
class TestLinearFsa(unittest.TestCase):
29+
30+
@classmethod
31+
def setUpClass(cls):
32+
cls.devices = [torch.device('cpu')]
33+
if torch.cuda.is_available() and k2.with_cuda:
34+
cls.devices.append(torch.device('cuda', 0))
35+
if torch.cuda.device_count() > 1:
36+
torch.cuda.set_device(1)
37+
cls.devices.append(torch.device('cuda', 1))
38+
39+
def test_single_fsa(self):
40+
for device in self.devices:
41+
labels = [2, 0, 0, 0, 5, 8]
42+
src = k2.linear_fsa(labels, device)
43+
dst = k2.linear_fsa_with_self_loops(src)
44+
assert src.device == dst.device
45+
expected_labels = [0, 2, 0, 5, 0, 8, 0, -1]
46+
assert dst.labels.tolist() == expected_labels
47+
48+
def test_multiple_fsa(self):
49+
for device in self.devices:
50+
labels = [[2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0], [1, 2],
51+
[0, 0, 0, 3, 0, 2]]
52+
src = k2.linear_fsa(labels, device)
53+
dst = k2.linear_fsa_with_self_loops(src)
54+
assert src.device == dst.device
55+
expected_labels0 = [0, 2, 0, 5, 0, 8, 0, -1]
56+
expected_labels1 = [0, 1, 0, 2, 0, -1]
57+
expected_labels2 = [0, 3, 0, 2, 0, -1]
58+
expected_labels = expected_labels0 + expected_labels1 + expected_labels2
59+
assert dst.labels.tolist() == expected_labels
60+
61+
62+
if __name__ == '__main__':
63+
unittest.main()

0 commit comments

Comments
 (0)