Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement linear_fsa_with_self_loops. #940

Merged
merged 2 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .fsa_algo import levenshtein_alignment
from .fsa_algo import levenshtein_graph
from .fsa_algo import linear_fsa
from .fsa_algo import linear_fsa_with_self_loops
from .fsa_algo import linear_fst
from .fsa_algo import prune_on_arc_post
from .fsa_algo import random_paths
Expand Down
44 changes: 39 additions & 5 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ def linear_fsa(labels: Union[List[int], List[List[int]], k2.RaggedTensor],
return fsa


def linear_fsa_with_self_loops(fsas: k2.Fsa):
    '''Create a linear FSA with epsilon self-loops by first removing epsilon
    transitions from the input linear FSA.

    Args:
      fsas:
        An FSA or an FsaVec. It MUST be a linear FSA or a vector of linear FSAs.
    Returns:
      Return an FSA or FsaVec, where each FSA contains epsilon self-loops but
      contains no epsilon transitions for arcs that are not self-loops.
    '''
    is_single_fsa = len(fsas.shape) == 2
    if is_single_fsa:
        # Prepend an outer axis so the single FSA is shaped like an
        # FsaVec containing exactly one FSA.
        outer = _k2.RaggedShape.regular_ragged_shape(
            dim0=1, dim1=fsas.shape[0]).to(fsas.device)
        arcs_shape = outer.compose(fsas.arcs.shape())
    else:
        arcs_shape = fsas.arcs.shape()

    # Drop the state axis so the labels are grouped per FSA.
    per_fsa_shape = arcs_shape.remove_axis(1)

    ragged_labels = k2.RaggedTensor(per_fsa_shape, fsas.labels.contiguous())
    # remove_values_leq(0) drops both epsilons (0) and the -1 on final
    # arcs; linear_fsa() below rebuilds the final arcs.
    ragged_labels = ragged_labels.remove_values_leq(0)
    ans = add_epsilon_self_loops(linear_fsa(ragged_labels))

    # Unwrap back to a single FSA if that is what we were given.
    return ans[0] if is_single_fsa else ans


def linear_fst(labels: Union[List[int], List[List[int]]],
aux_labels: Union[List[int], List[List[int]]]) -> Fsa:
'''Construct a linear FST from labels and its corresponding
Expand Down Expand Up @@ -1191,16 +1222,18 @@ def levenshtein_alignment(

hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")

lattice = k2.intersect_device(
refs, hyps, b_to_a_map=hyp_to_ref_map, sorted_match_a=sorted_match_ref)
lattice = k2.intersect_device(refs,
hyps,
b_to_a_map=hyp_to_ref_map,
sorted_match_a=sorted_match_ref)
lattice = k2.remove_epsilon_self_loops(lattice)

alignment = k2.shortest_path(lattice, use_double_scores=True).invert_()
alignment.rename_tensor_attribute_("labels", "ref_labels")
alignment.rename_tensor_attribute_("aux_labels", "labels")

alignment.scores -= getattr(
alignment, "__ins_del_score_offset_internal_attr_")
alignment.scores -= getattr(alignment,
"__ins_del_score_offset_internal_attr_")

return alignment

Expand All @@ -1222,5 +1255,6 @@ def union(fsas: Fsa) -> Fsa:
need_arc_map = True
ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)

out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map)
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc,
arc_map)
return out_fsa
1 change: 1 addition & 0 deletions k2/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ set(py_test_files
levenshtein_alignment_test.py
levenshtein_graph_test.py
linear_fsa_test.py
linear_fsa_with_self_loops_test.py
linear_fst_test.py
multi_gpu_test.py
mutual_information_test.py
Expand Down
63 changes: 63 additions & 0 deletions k2/python/tests/linear_fsa_with_self_loops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
#  ctest --verbose -R linear_fsa_with_self_loops_test_py

import torch
import k2
import unittest


class TestLinearFsa(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Always test on CPU; add CUDA devices when both torch and k2
        # were built with CUDA support.
        devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            devices.append(torch.device('cuda', 0))
            if torch.cuda.device_count() > 1:
                torch.cuda.set_device(1)
                devices.append(torch.device('cuda', 1))
        cls.devices = devices

    def test_single_fsa(self):
        # A single linear FSA: interior epsilons are removed and an
        # epsilon self-loop is inserted before every remaining label.
        labels = [2, 0, 0, 0, 5, 8]
        expected = [0, 2, 0, 5, 0, 8, 0, -1]
        for device in self.devices:
            src = k2.linear_fsa(labels, device)
            dst = k2.linear_fsa_with_self_loops(src)
            self.assertEqual(src.device, dst.device)
            self.assertEqual(dst.labels.tolist(), expected)

    def test_multiple_fsa(self):
        # An FsaVec of three linear FSAs; each one is processed
        # independently and the labels are concatenated in order.
        labels = [[2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0], [1, 2],
                  [0, 0, 0, 3, 0, 2]]
        expected = ([0, 2, 0, 5, 0, 8, 0, -1] + [0, 1, 0, 2, 0, -1] +
                    [0, 3, 0, 2, 0, -1])
        for device in self.devices:
            src = k2.linear_fsa(labels, device)
            dst = k2.linear_fsa_with_self_loops(src)
            self.assertEqual(src.device, dst.device)
            self.assertEqual(dst.labels.tolist(), expected)


# Run all tests in this file when it is executed as a script.
if __name__ == '__main__':
    unittest.main()