Skip to content

Commit 7178d67

Browse files
authored
Fix test cases for k2.union() (#853)
1 parent d061bc6 commit 7178d67

File tree

4 files changed

+53
-68
lines changed

4 files changed

+53
-68
lines changed

k2/python/k2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#
2222
from .autograd import intersect_dense
2323
from .autograd import intersect_dense_pruned
24-
from .autograd import union
2524
from .ctc_loss import CtcLoss
2625
from .ctc_loss import ctc_loss
2726
from .dense_fsa_vec import DenseFsaVec
@@ -51,6 +50,7 @@
5150
from .fsa_algo import replace_fsa
5251
from .fsa_algo import shortest_path
5352
from .fsa_algo import top_sort
53+
from .fsa_algo import union
5454
from .fsa_properties import to_str as properties_to_str
5555
from .nbest import Nbest
5656
from .ops import cat

k2/python/k2/autograd.py

-67
Original file line numberDiff line numberDiff line change
@@ -645,53 +645,6 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \
645645
)
646646

647647

648-
class _UnionFunction(torch.autograd.Function):
649-
650-
@staticmethod
651-
def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
652-
unused_fsas_scores: torch.Tensor) -> torch.Tensor:
653-
'''Compute the union of all fsas in a FsaVec.
654-
655-
Args:
656-
fsas:
657-
The input FsaVec. Caution: We require that each fsa in the FsaVec
658-
is non-empty (i.e., with at least two states).
659-
out_fsa:
660-
A list containing one entry. Since this function can only return
661-
values of type `torch.Tensor`, we return the union result in the
662-
list.
663-
unused_fsas_scores:
664-
It is the same as `fsas.scores`, whose sole purpose is for autograd.
665-
It is not used in this function.
666-
'''
667-
need_arc_map = True
668-
ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)
669-
out_fsa[0] = Fsa(ragged_arc)
670-
671-
for name, value in fsas.named_tensor_attr(include_scores=False):
672-
value = k2.index(value, arc_map)
673-
setattr(out_fsa[0], name, value)
674-
675-
for name, value in fsas.named_non_tensor_attr():
676-
setattr(out_fsa[0], name, value)
677-
ctx.arc_map = arc_map
678-
ctx.save_for_backward(unused_fsas_scores)
679-
680-
return out_fsa[0].scores # the return value will be discarded
681-
682-
@staticmethod
683-
def backward(ctx, out_fsa_grad: torch.Tensor
684-
) -> Tuple[None, None, torch.Tensor]: # noqa
685-
arc_map = ctx.arc_map
686-
fsas_scores, = ctx.saved_tensors
687-
ans = torch.zeros(fsas_scores.size(0),
688-
dtype=torch.float32,
689-
device=fsas_scores.device,
690-
requires_grad=False)
691-
_k2.index_add(arc_map, out_fsa_grad, ans)
692-
return None, None, ans
693-
694-
695648
def intersect_dense_pruned(a_fsas: Fsa,
696649
b_fsas: DenseFsaVec,
697650
search_beam: float,
@@ -843,23 +796,3 @@ def intersect_dense(a_fsas: Fsa,
843796
a_fsas.scores, b_fsas.scores, a_to_b_map,
844797
seqframe_idx_name, frame_idx_name)
845798
return out_fsa[0]
846-
847-
848-
def union(fsas: Fsa) -> Fsa:
849-
'''Compute the union of a FsaVec.
850-
851-
Caution:
852-
We require that every fsa in fsas is non-empty, i.e.,
853-
contains at least two states
854-
855-
Args:
856-
fsas:
857-
A FsaVec. That is, len(fsas.shape) == 3.
858-
859-
Returns:
860-
A single Fsa that is the union of the input fsas.
861-
'''
862-
863-
out_fsa = [0] # as a placeholder
864-
_UnionFunction.apply(fsas, out_fsa, fsas.scores)
865-
return out_fsa[0]

k2/python/k2/fsa_algo.py

+21
Original file line numberDiff line numberDiff line change
@@ -1179,3 +1179,24 @@ def levenshtein_alignment(
11791179
alignment, "__ins_del_score_offset_internal_attr_")
11801180

11811181
return alignment
1182+
1183+
1184+
def union(fsas: Fsa) -> Fsa:
1185+
'''Compute the union of a FsaVec.
1186+
1187+
Caution:
1188+
We require that every fsa in fsas is non-empty, i.e.,
1189+
contains at least two states
1190+
1191+
Args:
1192+
fsas:
1193+
A FsaVec. That is, len(fsas.shape) == 3.
1194+
1195+
Returns:
1196+
A single Fsa that is the union of the input fsas.
1197+
'''
1198+
need_arc_map = True
1199+
ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)
1200+
1201+
out_fsa = k2.utils.fsa_from_unary_function_tensor(fsas, ragged_arc, arc_map)
1202+
return out_fsa

k2/python/tests/union_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,40 @@ def test(self):
6363
fsa1 = k2.Fsa.from_str(s1)
6464
fsa2 = k2.Fsa.from_str(s2)
6565

66+
fsa0.tensor_attr = torch.tensor([1, 2, 3, 4, 5, 6],
67+
dtype=torch.int32,
68+
device=device)
69+
fsa0.ragged_tensor_attr = k2.RaggedTensor(
70+
fsa0.tensor_attr.unsqueeze(-1))
71+
72+
fsa1.tensor_attr = torch.tensor([7],
73+
dtype=torch.int32,
74+
device=device)
75+
76+
fsa1.ragged_tensor_attr = k2.RaggedTensor(
77+
fsa1.tensor_attr.unsqueeze(-1))
78+
79+
fsa2.tensor_attr = torch.tensor([8, 9, 10, 11],
80+
dtype=torch.int32,
81+
device=device)
82+
83+
fsa2.ragged_tensor_attr = k2.RaggedTensor(
84+
fsa2.tensor_attr.unsqueeze(-1))
85+
6686
fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2]).to(device)
6787

6888
fsa = k2.union(fsa_vec)
89+
90+
expected_tensor_attr = torch.tensor(
91+
[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
92+
11]).to(fsa.tensor_attr)
93+
assert torch.all(torch.eq(fsa.tensor_attr, expected_tensor_attr))
94+
95+
expected_ragged_tensor_attr = k2.RaggedTensor(
96+
expected_tensor_attr.unsqueeze(-1)).remove_values_eq(0)
97+
assert str(expected_ragged_tensor_attr) == str(
98+
fsa.ragged_tensor_attr)
99+
69100
assert torch.allclose(
70101
fsa.arcs.values()[:, :3],
71102
torch.tensor([

0 commit comments

Comments
 (0)