Skip to content

Commit 37aafd8

Browse files
committed
Merge branch 'master'
2 parents 18823fe + 5565bd4 commit 37aafd8

6 files changed

+367
-2
lines changed

k2/csrc/aux_labels.cc

+169
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,94 @@
66

77
#include "k2/csrc/aux_labels.h"
88

9+
#include <algorithm>
910
#include <numeric>
11+
#include <utility>
1012
#include <vector>
1113

1214
#include "glog/logging.h"
1315
#include "k2/csrc/fsa.h"
16+
#include "k2/csrc/fsa_util.h"
17+
#include "k2/csrc/properties.h"
18+
19+
namespace {
20+
21+
/*
22+
This function counts how many extra states we need to create for each state in
23+
the input FSA when we invert an FSA. Generally, if an entering arc of state
24+
`i` in the input FSA has n olabels, then we need to create n-1 extra states
25+
for state `i`.
26+
27+
@param [in] fsa_in Input FSA
28+
@param [in] labels_in Aux-label sequences for the input FSA
29+
@param [out] num_extra_states For state `i` in `fsa_in`, we need to create
30+
extra `num_extra_states[i]` states in the output
31+
inverted FSA.
32+
*/
33+
static void CountExtraStates(const k2::Fsa &fsa_in,
34+
const k2::AuxLabels &labels_in,
35+
std::vector<int32_t> *num_extra_states) {
36+
CHECK_EQ(num_extra_states->size(), fsa_in.NumStates());
37+
auto &states = *num_extra_states;
38+
for (int32_t i = 0; i != fsa_in.arcs.size(); ++i) {
39+
const auto &arc = fsa_in.arcs[i];
40+
int32_t pos_start = labels_in.start_pos[i];
41+
int32_t pos_end = labels_in.start_pos[i + 1];
42+
states[arc.dest_state] += std::max(0, pos_end - pos_start - 1);
43+
}
44+
}
45+
46+
/*
47+
Map the state in the input FSA to state in the output inverted FSA.
48+
49+
@param [in] num_extra_states Output of function `CountExtraStates`
50+
which gives how many extra states we need
51+
to create for each state in the input FSA.
52+
@param [out] state_map Map state `i` in the input FSA to state
53+
`state_map[i]` in the output FSA.
54+
At exit, it will be
55+
state_map[0] = 0,
56+
state_map[i] = state_map[i-1]
57+
+ num_extra_states[i]
58+
+ 1, for any i >=1
59+
@param [out] state_ids At exit, it will be
60+
state_ids[0] = 0,
61+
state_ids[i] = state_map[i-1], for any i >= 1.
62+
*/
63+
static void MapStates(const std::vector<int32_t> &num_extra_states,
64+
std::vector<int32_t> *state_map,
65+
std::vector<int32_t> *state_ids) {
66+
CHECK_EQ(state_map->size(), num_extra_states.size());
67+
CHECK_EQ(state_ids->size(), num_extra_states.size());
68+
auto &s_map = *state_map;
69+
auto &s_ids = *state_ids;
70+
// we suppose there's no arcs entering the start state (i.e. state id of the
71+
// start state in output FSA will be 0), otherwise we may need to create a new
72+
// state as the real start state.
73+
CHECK_EQ(num_extra_states[0], 0);
74+
auto num_states_in = num_extra_states.size();
75+
// process from the second state
76+
s_map[0] = 0;
77+
s_ids[0] = 0;
78+
int32_t num_states_out = 0;
79+
for (auto i = 1; i != num_states_in; ++i) {
80+
s_ids[i] = num_states_out;
81+
// `+1` as we did not count state `i` itself in `num_extra_states`
82+
num_states_out += num_extra_states[i] + 1;
83+
s_map[i] = num_states_out;
84+
}
85+
}
86+
} // namespace
1487

1588
namespace k2 {
1689

90+
// Exchanges the contents of `*labels1` and `*labels2` by swapping their two
// member vectors; no label data is copied. Both pointers must be non-null.
void Swap(AuxLabels *labels1, AuxLabels *labels2) {
  CHECK_NOTNULL(labels1);
  CHECK_NOTNULL(labels2);
  labels1->labels.swap(labels2->labels);
  labels1->start_pos.swap(labels2->start_pos);
}
96+
1797
void MapAuxLabels1(const AuxLabels &labels_in,
1898
const std::vector<int32_t> &arc_map, AuxLabels *labels_out) {
1999
CHECK_NOTNULL(labels_out);
@@ -61,4 +141,93 @@ void MapAuxLabels2(const AuxLabels &labels_in,
61141
start_pos.push_back(num_labels);
62142
}
63143

144+
void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
145+
AuxLabels *aux_labels_out) {
146+
CHECK_NOTNULL(fsa_out);
147+
CHECK_NOTNULL(aux_labels_out);
148+
fsa_out->arc_indexes.clear();
149+
fsa_out->arcs.clear();
150+
aux_labels_out->start_pos.clear();
151+
aux_labels_out->labels.clear();
152+
153+
if (IsEmpty(fsa_in)) {
154+
aux_labels_out->start_pos.push_back(0);
155+
return;
156+
}
157+
158+
auto num_states_in = fsa_in.NumStates();
159+
// get the number of extra states we need to create for each state
160+
// in fsa_in when inverting
161+
std::vector<int32_t> num_extra_states(num_states_in, 0);
162+
CountExtraStates(fsa_in, labels_in, &num_extra_states);
163+
164+
// map state in fsa_in to state in fsa_out
165+
std::vector<int32_t> state_map(num_states_in, 0);
166+
std::vector<int32_t> state_ids(num_states_in, 0);
167+
MapStates(num_extra_states, &state_map, &state_ids);
168+
169+
// a maximal approximation
170+
int32_t num_arcs_out = labels_in.labels.size() + fsa_in.arcs.size();
171+
std::vector<Arc> arcs;
172+
arcs.reserve(num_arcs_out);
173+
// `+1` for the end position of the last arc's olabel sequence
174+
std::vector<int32_t> start_pos;
175+
start_pos.reserve(num_arcs_out + 1);
176+
std::vector<int32_t> labels;
177+
labels.reserve(fsa_in.arcs.size());
178+
int32_t final_state_in = fsa_in.FinalState();
179+
180+
int32_t num_non_eps_ilabel_processed = 0;
181+
start_pos.push_back(0);
182+
for (auto i = 0; i != fsa_in.arcs.size(); ++i) {
183+
const auto &arc = fsa_in.arcs[i];
184+
int32_t pos_start = labels_in.start_pos[i];
185+
int32_t pos_end = labels_in.start_pos[i + 1];
186+
int32_t src_state = arc.src_state;
187+
int32_t dest_state = arc.dest_state;
188+
if (dest_state == final_state_in) {
189+
// every arc entering the final state must have exactly
190+
// one olabel == kFinalSymbol
191+
CHECK_EQ(pos_start + 1, pos_end);
192+
CHECK_EQ(labels_in.labels[pos_start], kFinalSymbol);
193+
}
194+
if (pos_end - pos_start <= 1) {
195+
int32_t curr_label =
196+
(pos_end - pos_start == 0) ? kEpsilon : labels_in.labels[pos_start];
197+
arcs.emplace_back(state_map[src_state], state_map[dest_state],
198+
curr_label);
199+
} else {
200+
// expand arcs with olabels
201+
arcs.emplace_back(state_map[src_state], state_ids[dest_state] + 1,
202+
labels_in.labels[pos_start]);
203+
start_pos.push_back(num_non_eps_ilabel_processed);
204+
for (int32_t pos = pos_start + 1; pos < pos_end - 1; ++pos) {
205+
++state_ids[dest_state];
206+
arcs.emplace_back(state_ids[dest_state], state_ids[dest_state] + 1,
207+
labels_in.labels[pos]);
208+
start_pos.push_back(num_non_eps_ilabel_processed);
209+
}
210+
++state_ids[dest_state];
211+
arcs.emplace_back(state_ids[dest_state], state_map[arc.dest_state],
212+
labels_in.labels[pos_end - 1]);
213+
}
214+
// push non-epsilon ilabel in fsa_in as olabel of fsa_out
215+
if (arc.label != kEpsilon) {
216+
labels.push_back(arc.label);
217+
++num_non_eps_ilabel_processed;
218+
}
219+
start_pos.push_back(num_non_eps_ilabel_processed);
220+
}
221+
222+
labels.resize(labels.size());
223+
arcs.resize(arcs.size());
224+
start_pos.resize(start_pos.size());
225+
226+
std::vector<int32_t> arc_map;
227+
ReorderArcs(arcs, fsa_out, &arc_map);
228+
AuxLabels labels_tmp;
229+
labels_tmp.start_pos = std::move(start_pos);
230+
labels_tmp.labels = std::move(labels);
231+
MapAuxLabels1(labels_tmp, arc_map, aux_labels_out);
232+
}
64233
} // namespace k2

k2/csrc/aux_labels.h

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ struct AuxLabels {
4646
std::vector<int32_t> labels;
4747
};
4848

49+
// Swap AuxLabels; it's cheap to do this as we are actually doing a shallow swap.
50+
void Swap(AuxLabels *labels1, AuxLabels *labels2);
51+
4952
/*
5053
Maps auxiliary labels after an FSA operation where each arc in the output
5154
FSA corresponds to exactly one arc in the input FSA.

k2/csrc/aux_labels_test.cc

+103
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "gmock/gmock.h"
1313
#include "gtest/gtest.h"
1414
#include "k2/csrc/fsa.h"
15+
#include "k2/csrc/properties.h"
1516

1617
namespace k2 {
1718

@@ -97,4 +98,106 @@ TEST_F(AuxLablesTest, MapAuxLabels2) {
9798
}
9899
}
99100

101+
// Checks InvertFst on three inputs: an empty FSA (with outputs that already
// hold data), a top-sorted FSA and a non-top-sorted FSA.
TEST(AuxLabels, InvertFst) {
  // Case 1: empty input FSA. Whatever the outputs held before the call must
  // be discarded.
  {
    Fsa empty_fsa;
    AuxLabels in_labels;
    in_labels.start_pos = {0, 1, 3, 6, 7};
    in_labels.labels = {1, 2, 3, 4, 5, 6, 7};

    // outputs deliberately pre-filled with unrelated data
    std::vector<Arc> stale_arcs = {{0, 1, 1}, {1, 2, -1}};
    Fsa out_fsa(std::move(stale_arcs), 2);
    AuxLabels out_labels;
    out_labels.start_pos = {1, 2, 3};
    out_labels.labels = {4, 5};

    InvertFst(empty_fsa, in_labels, &out_fsa, &out_labels);

    EXPECT_TRUE(IsEmpty(out_fsa));
    EXPECT_TRUE(out_labels.labels.empty());
    ASSERT_EQ(out_labels.start_pos.size(), 1);
    EXPECT_EQ(out_labels.start_pos[0], 0);
  }

  // Case 2: top-sorted input FSA.
  {
    std::vector<Arc> in_arcs = {{0, 1, 1}, {0, 1, 0},  {0, 3, 2},
                                {1, 2, 3}, {1, 3, 4},  {1, 5, -1},
                                {2, 3, 0}, {2, 5, -1}, {4, 5, -1}};
    Fsa in_fsa(std::move(in_arcs), 5);
    EXPECT_TRUE(IsTopSorted(in_fsa));

    AuxLabels in_labels;
    in_labels.start_pos = {0, 2, 3, 3, 6, 6, 7, 7, 8, 9};
    // one start position per arc plus the trailing end position
    EXPECT_EQ(in_labels.start_pos.size(), in_fsa.arcs.size() + 1);
    in_labels.labels = {1, 2, 3, 5, 6, 7, -1, -1, -1};

    Fsa out_fsa;
    AuxLabels out_labels;
    InvertFst(in_fsa, in_labels, &out_fsa, &out_labels);

    EXPECT_TRUE(IsTopSorted(out_fsa));
    const std::vector<Arc> expected_arcs = {
        {0, 1, 1}, {0, 2, 3},  {0, 6, 0}, {1, 2, 2},
        {2, 3, 5}, {2, 6, 0},  {2, 8, -1}, {3, 4, 6},
        {4, 5, 7}, {5, 6, 0},  {5, 8, -1}, {7, 8, -1},
    };
    ASSERT_EQ(out_fsa.arcs.size(), expected_arcs.size());
    for (std::size_t k = 0; k < expected_arcs.size(); ++k) {
      EXPECT_EQ(out_fsa.arcs[k], expected_arcs[k]);
    }
    ASSERT_EQ(out_fsa.arc_indexes.size(), 10);
    EXPECT_THAT(out_fsa.arc_indexes,
                ::testing::ElementsAre(0, 3, 4, 7, 8, 9, 11, 11, 12, 12));
    ASSERT_EQ(out_labels.labels.size(), 7);
    EXPECT_THAT(out_labels.labels,
                ::testing::ElementsAre(2, 1, 4, -1, 3, -1, -1));
    ASSERT_EQ(out_labels.start_pos.size(), 13);
    EXPECT_THAT(out_labels.start_pos,
                ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6, 7));
  }

  // Case 3: non-top-sorted input FSA.
  {
    std::vector<Arc> in_arcs = {{0, 1, 1},  {0, 1, 0}, {0, 3, 2},
                                {1, 2, 3},  {1, 3, 4}, {2, 1, 5},
                                {2, 5, -1}, {3, 1, 6}, {4, 5, -1}};
    Fsa in_fsa(std::move(in_arcs), 5);
    EXPECT_FALSE(IsTopSorted(in_fsa));

    AuxLabels in_labels;
    in_labels.start_pos = {0, 2, 3, 3, 6, 6, 7, 8, 10, 11};
    // one start position per arc plus the trailing end position
    EXPECT_EQ(in_labels.start_pos.size(), in_fsa.arcs.size() + 1);
    in_labels.labels = {1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1};

    Fsa out_fsa;
    AuxLabels out_labels;
    InvertFst(in_fsa, in_labels, &out_fsa, &out_labels);

    EXPECT_FALSE(IsTopSorted(out_fsa));
    const std::vector<Arc> expected_arcs = {
        {0, 1, 1},  {0, 3, 3}, {0, 7, 0}, {1, 3, 2}, {2, 3, 10},
        {3, 4, 5},  {3, 7, 0}, {4, 5, 6}, {5, 6, 7}, {6, 3, 8},
        {6, 9, -1}, {7, 2, 9}, {8, 9, -1}};
    ASSERT_EQ(out_fsa.arcs.size(), expected_arcs.size());
    for (std::size_t k = 0; k < expected_arcs.size(); ++k) {
      EXPECT_EQ(out_fsa.arcs[k], expected_arcs[k]);
    }
    ASSERT_EQ(out_fsa.arc_indexes.size(), 11);
    EXPECT_THAT(out_fsa.arc_indexes,
                ::testing::ElementsAre(0, 3, 4, 5, 7, 8, 9, 11, 12, 13, 13));
    ASSERT_EQ(out_labels.labels.size(), 8);
    EXPECT_THAT(out_labels.labels,
                ::testing::ElementsAre(2, 1, 6, 4, 3, 5, -1, -1));
    ASSERT_EQ(out_labels.start_pos.size(), 14);
    EXPECT_THAT(
        out_labels.start_pos,
        ::testing::ElementsAre(0, 0, 0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 7, 8));
  }
}
202+
100203
} // namespace k2

k2/csrc/fsa_util.cc

+37
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,43 @@ void GetArcWeights(const float *arc_weights_in,
177177
}
178178
}
179179

180+
void ReorderArcs(const std::vector<Arc> &arcs, Fsa *fsa,
181+
std::vector<int32_t> *arc_map /*= nullptr*/) {
182+
CHECK_NOTNULL(fsa);
183+
fsa->arc_indexes.clear();
184+
fsa->arcs.clear();
185+
if (arc_map != nullptr) arc_map->clear();
186+
187+
if (arcs.empty()) return;
188+
189+
using ArcWithIndex = std::pair<Arc, int32_t>;
190+
int arc_id = 0;
191+
std::vector<std::vector<ArcWithIndex>> vec;
192+
for (const auto &arc : arcs) {
193+
auto src_state = arc.src_state;
194+
auto dest_state = arc.dest_state;
195+
auto new_size = std::max(src_state, dest_state);
196+
if (new_size >= vec.size()) vec.resize(new_size + 1);
197+
vec[src_state].push_back({arc, arc_id++});
198+
}
199+
200+
std::size_t num_states = vec.size();
201+
fsa->arc_indexes.resize(num_states + 1);
202+
fsa->arcs.reserve(arcs.size());
203+
std::vector<int32_t> arc_map_out;
204+
arc_map_out.reserve(arcs.size());
205+
206+
for (auto i = 0; i != num_states; ++i) {
207+
fsa->arc_indexes[i] = static_cast<int32_t>(fsa->arcs.size());
208+
for (auto arc_with_index : vec[i]) {
209+
fsa->arcs.emplace_back(arc_with_index.first);
210+
arc_map_out.push_back(arc_with_index.second);
211+
}
212+
}
213+
fsa->arc_indexes.back() = static_cast<int32_t>(fsa->arcs.size());
214+
if (arc_map != nullptr) arc_map->swap(arc_map_out);
215+
}
216+
180217
void Swap(Fsa *a, Fsa *b) {
181218
CHECK_NOTNULL(a);
182219
CHECK_NOTNULL(b);

k2/csrc/fsa_util.h

+19-2
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,27 @@ void GetArcWeights(const float *arc_weights_in,
5959
const std::vector<std::vector<int32_t>> &arc_map,
6060
float *arc_weights_out);
6161

62-
// Version of GetArcWeights where arc_map maps each arc in the output FSA to one
63-
// arc (instead of a sequence of arcs) in the input FSA; see its documentation.
62+
// Version of GetArcWeights where arc_map maps each arc in the output FSA to
63+
// one arc (instead of a sequence of arcs) in the input FSA; see its
64+
// documentation.
6465
void GetArcWeights(const float *arc_weights_in,
6566
const std::vector<int32_t> &arc_map, float *arc_weights_out);
67+
68+
/* Reorder a list of arcs to get a valid FSA. This function will be used in a
69+
situation where the input list of arcs is not sorted by src_state; we'll
70+
reorder the arcs and generate the corresponding valid FSA. Note that we don't
71+
remap any state index here, it is supposed that the start state is 0 and the
72+
final state is the largest state number in the input arcs.
73+
74+
@param [in] arcs A list of arcs.
75+
@param [out] fsa Output fsa.
76+
@param [out] arc_map If non-NULL, this function will
77+
output a map from the arc-index in `fsa` to
78+
the corresponding arc-index in input `arcs`.
79+
*/
80+
void ReorderArcs(const std::vector<Arc> &arcs, Fsa *fsa,
81+
std::vector<int32_t> *arc_map = nullptr);
82+
6683
/*
6784
Convert indexes (typically arc-mapping indexes, e.g. as output by Compose())
6885
from int32 to int64; this will be needed for conversion to LongTensor.

0 commit comments

Comments
 (0)