Skip to content

Commit f75a615

Browse files
authored
Merge pull request #6 from csukuangfj/fangjun-fsa-util
implement GetEnteringArcs.
2 parents 474e352 + 0b25697 commit f75a615

11 files changed

+276
-88
lines changed

.clang-format

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ Language: Cpp
55
Cpp11BracedListStyle: true
66
Standard: Cpp11
77
DerivePointerAlignment: false
8+
PointerAlignment: Right
89
---

k2/csrc/CMakeLists.txt

+20
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ add_library(properties properties.cc)
22
target_include_directories(properties PUBLIC ${CMAKE_SOURCE_DIR})
33
target_compile_features(properties PUBLIC cxx_std_11)
44

5+
add_library(fsa_util fsa_util.cc)
6+
target_include_directories(fsa_util PUBLIC ${CMAKE_SOURCE_DIR})
7+
target_compile_features(fsa_util PUBLIC cxx_std_11)
8+
59
add_executable(properties_test properties_test.cc)
610

711
target_link_libraries(properties_test
@@ -15,3 +19,19 @@ add_test(NAME Test.properties_test
1519
COMMAND
1620
$<TARGET_FILE:properties_test>
1721
)
22+
23+
add_executable(fsa_util_test fsa_util_test.cc)
24+
25+
target_link_libraries(fsa_util_test
26+
PRIVATE
27+
fsa_util
28+
gtest
29+
gtest_main
30+
)
31+
32+
add_test(NAME Test.fsa_util_test
33+
COMMAND
34+
$<TARGET_FILE:fsa_util_test>
35+
)
36+
37+
# TODO(fangjun): write some helper functions to create targets.

k2/csrc/fsa.h

+5-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define K2_CSRC_FSA_H_
99

1010
#include <cstdint>
11+
#include <utility>
1112
#include <vector>
1213

1314
namespace k2 {
@@ -48,7 +49,7 @@ struct Arc {
4849
};
4950

5051
struct ArcLabelCompare {
51-
bool operator()(const Arc& a, const Arc& b) const {
52+
bool operator()(const Arc &a, const Arc &b) const {
5253
return a.label < b.label;
5354
}
5455
};
@@ -92,11 +93,9 @@ struct Fsa {
9293
more state). For 0 <= t < T, we have an arc with symbol n on it for
9394
each 0 <= n < N, from state t to state t+1, with weight equal to
9495
weights[t,n].
95-
96-
9796
*/
9897
struct DenseFsa {
99-
Weight* weights; // Would typically be a log-prob or unnormalized log-prob
98+
Weight *weights; // Would typically be a log-prob or unnormalized log-prob
10099
int32_t T; // The number of time steps == rows in the matrix `weights`;
101100
// this FSA has T + 2 states, see explanation above.
102101
int32_t num_symbols; // The number of symbols == columns in the matrix
@@ -110,7 +109,7 @@ struct DenseFsa {
110109
CAUTION: we may later enforce that stride == num_symbols, in order to
111110
be able to know the layout of a phantom matrix of arcs. (?)
112111
*/
113-
DenseFsa(Weight* data, int32_t T, int32_t num_symbols, int32_t stride);
112+
DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride);
114113
};
115114

116115
/*
@@ -120,7 +119,7 @@ struct DenseFsa {
120119
*/
121120
struct VecOfVec {
122121
std::vector<Range> ranges;
123-
std::vector<int32_t> values;
122+
std::vector<std::pair<Label, StateId>> values;
124123
};
125124

126125
struct Fst {

k2/csrc/fsa_algo.h

+26-26
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace k2 {
2626
so the output will be topologically sorted if the input
2727
was.
2828
*/
29-
void ConnectCore(const Fsa& fsa, std::vector<int32>* state_map);
29+
void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map);
3030

3131
/*
3232
Removes states that are not accessible (from the start state) or are not
@@ -44,13 +44,13 @@ void ConnectCore(const Fsa& fsa, std::vector<int32>* state_map);
4444
4545
Notes:
4646
- If `a` admitted a topological sorting, b will be topologically
47-
sorted. TODO: maybe just leave in the same order as a??
47+
sorted. TODO(Dan): maybe just leave in the same order as a??
4848
- If `a` was deterministic, `b` will be deterministic; same for
4949
epsilon free, obviously.
5050
- `b` will be arc-sorted (arcs sorted by label)
5151
- `b` will (obviously) be connected
5252
*/
53-
void Connect(const Fsa& a, Fsa* b, std::vector<int32>* arc_map = nullptr);
53+
void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);
5454

5555
/**
5656
Output an Fsa that is equivalent to the input but which has no epsilons.
@@ -61,10 +61,10 @@ void Connect(const Fsa& a, Fsa* b, std::vector<int32>* arc_map = nullptr);
6161
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
6262
the arc-indexes in `a` that contributed to that arc
6363
(e.g. its cost would be a sum of their costs).
64-
TODO: make it a VecOfVec, maybe?
64+
TODO(Dan): make it a VecOfVec, maybe?
6565
*/
66-
void RmEpsilons(const Fsa& a, Fsa* b,
67-
std::vector<std::vector>* arc_map = nullptr);
66+
void RmEpsilons(const Fsa &a, Fsa *b,
67+
std::vector<std::vector> *arc_map = nullptr);
6868

6969
/**
7070
Pruned version of RmEpsilons, which also uses a pruning beam.
@@ -77,12 +77,12 @@ void RmEpsilons(const Fsa& a, Fsa* b,
7777
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
7878
the arc-indexes in `a` that contributed to that arc
7979
(e.g. its cost would be a sum of their costs).
80-
TODO: make it a VecOfVec, maybe?
80+
TODO(Dan): make it a VecOfVec, maybe?
8181
*/
82-
void RmEpsilonsPruned(const Fsa& a, const float* a_state_forward_costs,
83-
const float* a_state_backward_costs,
84-
const float* a_arc_costs, float cutoff, Fsa* b,
85-
std::vector<std::vector>* arc_map = nullptr);
82+
void RmEpsilonsPruned(const Fsa &a, const float *a_state_forward_costs,
83+
const float *a_state_backward_costs,
84+
const float *a_arc_costs, float cutoff, Fsa *b,
85+
std::vector<std::vector> *arc_map = nullptr);
8686

8787
/*
8888
Compute the intersection of two FSAs; this is the equivalent of composition
@@ -104,16 +104,16 @@ void RmEpsilonsPruned(const Fsa& a, const float* a_state_forward_costs,
104104
size c->arcs.size(), saying for each arc in
105105
`c` what the source arc in `b` was.
106106
*/
107-
void Intersect(const Fsa& a, const Fsa& b, Fsa* c,
108-
std::vector<int32>* arc_map_a = nullptr,
109-
std::vector<int32>* arc_map_b = nullptr);
107+
void Intersect(const Fsa &a, const Fsa &b, Fsa *c,
108+
std::vector<int32_t> *arc_map_a = nullptr,
109+
std::vector<int32_t> *arc_map_b = nullptr);
110110

111111
/*
112112
Version of Intersect where `a` is dense?
113113
*/
114-
void Intersect(const DenseFsa& a, const Fsa& b, Fsa* c,
115-
std::vector<int32>* arc_map_a = nullptr,
116-
std::vector<int32>* arc_map_b = nullptr);
114+
void Intersect(const DenseFsa &a, const Fsa &b, Fsa *c,
115+
std::vector<int32_t> *arc_map_a = nullptr,
116+
std::vector<int32_t> *arc_map_b = nullptr);
117117

118118
/*
119119
Version of Intersect where `a` is dense, pruned with pruning beam `beam`.
@@ -124,9 +124,9 @@ void Intersect(const DenseFsa& a, const Fsa& b, Fsa* c,
124124
125125
This is the same as time-synchronous Viterbi beam pruning.
126126
*/
127-
void IntersectPruned(const DenseFsa& a, const Fsa& b, float beam, Fsa* c,
128-
std::vector<int32>* arc_map_a = nullptr,
129-
std::vector<int32>* arc_map_b = nullptr);
127+
void IntersectPruned(const DenseFsa &a, const Fsa &b, float beam, Fsa *c,
128+
std::vector<int32_t> *arc_map_a = nullptr,
129+
std::vector<int32_t> *arc_map_b = nullptr);
130130

131131
/**
132132
Intersection of two weighted FSA's: the same as Intersect(), but it prunes
@@ -152,13 +152,13 @@ void IntersectPruned(const DenseFsa& a, const Fsa& b, float beam, Fsa* c,
152152
@param [out] state_map_b Maps from arc-index in c to the corresponding
153153
arc-index in b
154154
*/
155-
void IntersectPruned2(const Fsa& a, const float* a_cost, const Fsa& b,
156-
const float* b_cost, float cutoff, Fsa* c,
157-
std::vector<int32>* state_map_a,
158-
std::vector<int32>* state_map_b);
155+
void IntersectPruned2(const Fsa &a, const float *a_cost, const Fsa &b,
156+
const float *b_cost, float cutoff, Fsa *c,
157+
std::vector<int32_t> *state_map_a,
158+
std::vector<int32_t> *state_map_b);
159159

160-
void RandomPath(const Fsa& a, const float* a_cost, Fsa* b,
161-
std::vector<int32>* state_map = nullptr);
160+
void RandomPath(const Fsa &a, const float *a_cost, Fsa *b,
161+
std::vector<int32_t> *state_map = nullptr);
162162

163163
} // namespace k2
164164

k2/csrc/fsa_util.cc

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// k2/csrc/fsa_util.cc
2+
3+
// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com)
4+
5+
// See ../../LICENSE for clarification regarding multiple authors
6+
7+
#include "k2/csrc/fsa_util.h"
8+
9+
#include <utility>
10+
#include <vector>
11+
12+
namespace k2 {
13+
14+
void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs) {
15+
// CHECK(CheckProperties(fsa, KTopSorted));
16+
17+
int num_states = fsa.NumStates();
18+
std::vector<std::vector<std::pair<Label, StateId>>> vec(num_states);
19+
int num_arcs = 0;
20+
for (const auto &arc : fsa.arcs) {
21+
auto src_state = arc.src_state;
22+
auto dest_state = arc.dest_state;
23+
auto label = arc.label;
24+
vec[dest_state].emplace_back(label, src_state);
25+
++num_arcs;
26+
}
27+
28+
auto &ranges = entering_arcs->ranges;
29+
auto &values = entering_arcs->values;
30+
ranges.reserve(num_states);
31+
values.reserve(num_arcs);
32+
33+
int32_t start = 0;
34+
int32_t end = 0;
35+
for (const auto &label_state : vec) {
36+
values.insert(values.end(), label_state.begin(), label_state.end());
37+
start = end;
38+
end += static_cast<int32_t>(label_state.size());
39+
ranges.push_back({start, end});
40+
}
41+
}
42+
43+
} // namespace k2

k2/csrc/fsa_util.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace k2 {
1818
Requires that `fsa` be valid and top-sorted, i.e.
1919
CheckProperties(fsa, KTopSorted) == true.
2020
*/
21-
void GetEnteringArcs(const Fsa& fsa, VecOfVec* entering_arcs);
21+
void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs);
2222

2323
} // namespace k2
2424

k2/csrc/fsa_util_test.cc

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// k2/csrc/fsa_util_test.cc
2+
3+
// Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com)
4+
5+
// See ../../LICENSE for clarification regarding multiple authors
6+
7+
#include "k2/csrc/fsa_util.h"
8+
9+
#include <utility>
10+
#include <vector>
11+
12+
#include "gtest/gtest.h"
13+
14+
namespace k2 {
15+
16+
TEST(FsaUtil, GetEnteringArcs) {
17+
std::vector<Arc> arcs = {
18+
{0, 1, 2}, {0, 2, 1}, {1, 2, 0}, {1, 3, 5}, {2, 3, 6},
19+
};
20+
std::vector<Range> leaving_arcs = {
21+
{0, 2}, {2, 4}, {4, 5}, {0, 0}, // the last state has no entering arcs
22+
};
23+
24+
Fsa fsa;
25+
fsa.leaving_arcs = std::move(leaving_arcs);
26+
fsa.arcs = std::move(arcs);
27+
28+
VecOfVec entering_arcs;
29+
GetEnteringArcs(fsa, &entering_arcs);
30+
31+
const auto &ranges = entering_arcs.ranges;
32+
const auto &values = entering_arcs.values;
33+
EXPECT_EQ(ranges.size(), 4u); // there are 4 states
34+
EXPECT_EQ(values.size(), 5u); // there are 5 arcs
35+
36+
// state 0, no entering arcs
37+
EXPECT_EQ(ranges[0].begin, ranges[0].end);
38+
39+
// state 1 has one entering arc from state 0 with label 2
40+
EXPECT_EQ(ranges[1].begin, 0);
41+
EXPECT_EQ(ranges[1].end, 1);
42+
EXPECT_EQ(values[0].first, 2); // label is 2
43+
EXPECT_EQ(values[0].second, 0); // state is 0
44+
45+
// state 2 has two entering arcs
46+
// the first one: from state 0 with label 1
47+
// the second one: from state 1 with label 0
48+
EXPECT_EQ(ranges[2].begin, 1);
49+
EXPECT_EQ(ranges[2].end, 3);
50+
EXPECT_EQ(values[1].first, 1); // label is 1
51+
EXPECT_EQ(values[1].second, 0); // state is 0
52+
53+
EXPECT_EQ(values[2].first, 0); // label is 0
54+
EXPECT_EQ(values[2].second, 1); // state is 1
55+
56+
// state 3 has two entering arcs
57+
// the first one: from state 1 with label 5
58+
// the second one: from state 2 with label 6
59+
EXPECT_EQ(ranges[3].begin, 3);
60+
EXPECT_EQ(ranges[3].end, 5);
61+
EXPECT_EQ(values[3].first, 5); // label is 5
62+
EXPECT_EQ(values[3].second, 1); // state is 1
63+
64+
EXPECT_EQ(values[4].first, 6); // label is 6
65+
EXPECT_EQ(values[4].second, 2); // state is 2
66+
}
67+
68+
} // namespace k2

0 commit comments

Comments
 (0)