Skip to content

Commit 7094857

Browse files
authored
Implement RandEquivalent and ShortestDistance for weighted FSA (#44)
1 parent 75a83c9 commit 7094857

7 files changed

+669
-103
lines changed

k2/csrc/fsa_equivalent.cc

+287-49
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,119 @@
77
#include "k2/csrc/fsa_equivalent.h"
88

99
#include <algorithm>
10+
#include <queue>
1011
#include <random>
1112
#include <unordered_map>
1213
#include <unordered_set>
14+
#include <utility>
1315
#include <vector>
1416

1517
#include "k2/csrc/fsa.h"
1618
#include "k2/csrc/fsa_algo.h"
1719
#include "k2/csrc/properties.h"
20+
#include "k2/csrc/util.h"
21+
#include "k2/csrc/weights.h"
22+
23+
namespace {
24+
// out_weights[i] = weights[arc_map1[arc_map2[i]]]
25+
static void GetArcWeights(const float *weights,
26+
const std::vector<int32_t> &arc_map1,
27+
const std::vector<int32_t> &arc_map2,
28+
std::vector<float> *out_weights) {
29+
CHECK_NOTNULL(out_weights);
30+
auto &arc_weights = *out_weights;
31+
for (auto i = 0; i != arc_weights.size(); ++i) {
32+
arc_weights[i] = weights[arc_map1[arc_map2[i]]];
33+
}
34+
}
35+
36+
// c = (a - b) + (b-a)
37+
static void SetDifference(const std::unordered_set<int32_t> &a,
38+
const std::unordered_set<int32_t> &b,
39+
std::unordered_set<int32_t> *c) {
40+
CHECK_NOTNULL(c);
41+
c->clear();
42+
for (const auto &v : a) {
43+
if (b.find(v) == b.end()) c->insert(v);
44+
}
45+
for (const auto &v : b) {
46+
if (a.find(v) == a.end()) c->insert(v);
47+
}
48+
}
49+
50+
// Performs a random walk through `a`, starting at state 0 and stopping at the
// last state (`a.NumStates() - 1`), and stores in `*b` the FSA made of the
// states and arcs visited by that walk.
//
// @param [in]  a               Input FSA. It must be non-empty and connected;
//                              otherwise the function returns false.
// @param [out] b               Receives the generated path. Must be non-null,
//                              otherwise the function returns false. It is
//                              only written when the function returns true.
// @param [in]  no_epsilon_arc  If true, an arc labelled `kEpsilon` is never
//                              emitted: the walk keeps following arcs until it
//                              takes a non-epsilon one, and the arc stored in
//                              `b` goes from the state where the step began to
//                              the destination of that non-epsilon arc.
// @param [out] state_map       If non-null, on success `(*state_map)[i]` is
//                              the state in `a` that state `i` of `b`
//                              corresponds to.
// @return true if a path was generated. Returns false if `a` is empty or not
//         connected, if `b` is null, or if the walk reaches a state without
//         leaving arcs before it can complete a step (possible only while
//         skipping epsilon arcs).
static bool RandomPathHelper(const k2::Fsa &a, k2::Fsa *b, bool no_epsilon_arc,
                             std::vector<int32_t> *state_map = nullptr) {
  using k2::Arc;
  using k2::ArcHash;
  using k2::kEpsilon;
  if (IsEmpty(a) || b == nullptr) return false;
  // We cannot do `connect` on `a` here to get a connected fsa, as `state_map`
  // would then map to states in the connected fsa instead of states in `a`.
  if (!IsConnected(a)) return false;

  const int32_t num_states = a.NumStates();
  const int32_t final_state = num_states - 1;
  std::vector<int32_t> state_map_b2a;
  std::vector<int32_t> state_map_a2b(num_states, -1);
  // `visited_arcs[i]` stores the arcs leaving state `i` of `b`. Their source
  // and destination states are still expressed as state ids of `a`.
  std::vector<std::unordered_set<Arc, ArcHash>> visited_arcs;

  std::random_device rd;
  std::mt19937 generator(rd());

  int32_t num_visited_arcs = 0;
  int32_t num_visited_state = 0;
  int32_t state = 0;
  while (true) {
    if (state_map_a2b[state] == -1) {
      // First visit of `state`: give it the next free state id in `b`.
      state_map_a2b[state] = num_visited_state;
      state_map_b2a.push_back(state);
      visited_arcs.emplace_back();
      ++num_visited_state;
    }
    if (state == final_state) break;

    const Arc *curr_arc = nullptr;
    int32_t curr_state = state;
    do {
      const int32_t begin = a.arc_indexes[curr_state];
      const int32_t end = a.arc_indexes[curr_state + 1];
      // A state without leaving arcs can be reached here while skipping
      // epsilon arcs (e.g. an epsilon arc entering the final state). There is
      // nothing to sample from, so give up instead of dividing by zero.
      if (begin >= end) return false;
      // Sample directly from [begin, end - 1]; reducing a full-range draw
      // with `%` would be slightly biased.
      std::uniform_int_distribution<int32_t> pick_arc(begin, end - 1);
      curr_arc = &a.arcs[pick_arc(generator)];
      curr_state = curr_arc->dest_state;
    } while (no_epsilon_arc && curr_arc->label == kEpsilon);

    const int32_t state_id_in_b = state_map_a2b[state];
    // The set de-duplicates arcs when the walk passes through a cycle.
    if (visited_arcs[state_id_in_b]
            .insert({state, curr_arc->dest_state, curr_arc->label})
            .second)
      ++num_visited_arcs;
    state = curr_arc->dest_state;
  }

  // create `b`
  b->arc_indexes.resize(num_visited_state);
  b->arcs.resize(num_visited_arcs);
  int32_t n = 0;
  for (int32_t i = 0; i < num_visited_state; ++i) {
    b->arc_indexes[i] = n;
    for (const auto &arc : visited_arcs[i]) {
      auto &b_arc = b->arcs[n];
      b_arc.src_state = i;
      b_arc.dest_state = state_map_a2b[arc.dest_state];
      b_arc.label = arc.label;
      ++n;
    }
  }
  if (state_map != nullptr) {
    state_map->swap(state_map_b2a);
  }
  // Append one more entry equal to the last one, so that
  // `arc_indexes[state + 1]` can also be read for the last state.
  b->arc_indexes.emplace_back(b->arc_indexes.back());
  return true;
}
121+
122+
} // namespace
18123

19124
namespace k2 {
20125

@@ -62,65 +167,198 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {
62167
return true;
63168
}
64169

170+
template <FbWeightType Type>
171+
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
172+
const float *b_weights, bool top_sorted /*=true*/,
173+
std::size_t npath /*= 100*/) {
174+
Fsa connected_a, connected_b, valid_a, valid_b;
175+
std::vector<int32_t> connected_a_arc_map, connected_b_arc_map,
176+
valid_a_arc_map, valid_b_arc_map;
177+
Connect(a, &connected_a, &connected_a_arc_map);
178+
Connect(b, &connected_b, &connected_b_arc_map);
179+
ArcSort(connected_a, &valid_a, &valid_a_arc_map); // required by `intersect`
180+
ArcSort(connected_b, &valid_b, &valid_b_arc_map);
181+
if (IsEmpty(valid_a) && IsEmpty(valid_b)) return true;
182+
if (IsEmpty(valid_a) || IsEmpty(valid_b)) return false;
183+
184+
// Get arc weights
185+
std::vector<float> valid_a_weights(valid_a.arcs.size());
186+
std::vector<float> valid_b_weights(valid_b.arcs.size());
187+
GetArcWeights(a_weights, connected_a_arc_map, valid_a_arc_map,
188+
&valid_a_weights);
189+
GetArcWeights(b_weights, connected_b_arc_map, valid_b_arc_map,
190+
&valid_b_weights);
191+
192+
// Check that arc labels are compatible.
193+
std::unordered_set<int32_t> labels_a, labels_b, labels_difference;
194+
for (const auto &arc : valid_a.arcs) labels_a.insert(arc.label);
195+
for (const auto &arc : valid_b.arcs) labels_b.insert(arc.label);
196+
SetDifference(labels_a, labels_b, &labels_difference);
197+
if (labels_difference.size() >= 2 ||
198+
(labels_difference.size() == 1 &&
199+
(*(labels_difference.begin())) != kEpsilon))
200+
return false;
201+
202+
std::random_device rd;
203+
std::mt19937 gen(rd());
204+
std::bernoulli_distribution coin(0.5);
205+
for (auto i = 0; i != npath; ++i) {
206+
const auto &fsa = coin(gen) ? valid_a : valid_b;
207+
Fsa path, valid_path;
208+
RandomPathWithoutEpsilonArc(fsa, &path); // path is already connected
209+
ArcSort(path, &valid_path);
210+
211+
Fsa a_compose_path, b_compose_path;
212+
std::vector<float> a_compose_weights, b_compose_weights;
213+
Intersect(valid_a, valid_a_weights.data(), path, &a_compose_path,
214+
&a_compose_weights);
215+
Intersect(valid_b, valid_b_weights.data(), path, &b_compose_path,
216+
&b_compose_weights);
217+
// TODO(haowen): we may need to implement a version of `ShortestDistance`
218+
// for non-top-sorted FSAs, but we prefer to decide this later as there's no
219+
// such scenarios (input FSAs are not top-sorted) currently. If we finally
220+
// find out that we don't need that version, we will remove flag
221+
// `top_sorted` and add requirements as comments in the header file.
222+
CHECK(top_sorted);
223+
double sum_a =
224+
ShortestDistance<Type>(a_compose_path, a_compose_weights.data());
225+
double sum_b =
226+
ShortestDistance<Type>(b_compose_path, b_compose_weights.data());
227+
if (!DoubleApproxEqual(sum_a, sum_b)) return false;
228+
}
229+
return true;
230+
}
231+
232+
// explicit instantiation here
233+
template bool IsRandEquivalent<kMaxWeight>(const Fsa &a, const float *a_weights,
234+
const Fsa &b, const float *b_weights,
235+
bool top_sorted, std::size_t npath);
236+
template bool IsRandEquivalent<kLogSumWeight>(
237+
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
238+
bool top_sorted, std::size_t npath);
239+
65240
bool RandomPath(const Fsa &a, Fsa *b,
66241
std::vector<int32_t> *state_map /*=nullptr*/) {
67-
if (IsEmpty(a) || b == nullptr) return false;
68-
// we cannot do `connect` on `a` here to get a connected fsa
69-
// as `state_map` will map to states in the connected fsa
70-
// instead of in `a` if we do that.
71-
if (!IsConnected(a)) return false;
242+
return RandomPathHelper(a, b, false, state_map);
243+
}
72244

73-
int32_t num_states = a.NumStates();
74-
std::vector<int32_t> state_map_b2a;
75-
std::vector<int32_t> state_map_a2b(num_states, -1);
76-
// `visited_arcs[i]` stores `arcs` leaving from state `i` in `b`
77-
std::vector<std::unordered_set<Arc, ArcHash>> visited_arcs;
245+
bool RandomPathWithoutEpsilonArc(
246+
const Fsa &a, Fsa *b, std::vector<int32_t> *state_map /*= nullptr*/) {
247+
return RandomPathHelper(a, b, true, state_map);
248+
}
78249

79-
std::random_device rd;
80-
std::mt19937 generator(rd());
81-
std::uniform_int_distribution<int32_t> distribution(0);
250+
void Intersect(const Fsa &a, const float *a_weights, const Fsa &b, Fsa *c,
251+
std::vector<float> *c_weights,
252+
std::vector<int32_t> *arc_map_a /*= nullptr*/,
253+
std::vector<int32_t> *arc_map_b /*= nullptr*/) {
254+
CHECK_NOTNULL(c);
255+
CHECK_NOTNULL(c_weights);
256+
c->arc_indexes.clear();
257+
c->arcs.clear();
258+
c_weights->clear();
259+
if (arc_map_a != nullptr) arc_map_a->clear();
260+
if (arc_map_b != nullptr) arc_map_b->clear();
82261

83-
int32_t num_visited_arcs = 0;
84-
int32_t num_visited_state = 0;
85-
int32_t state = 0;
86-
int32_t final_state = num_states - 1;
87-
while (true) {
88-
if (state_map_a2b[state] == -1) {
89-
state_map_a2b[state] = num_visited_state;
90-
state_map_b2a.push_back(state);
91-
visited_arcs.emplace_back(std::unordered_set<Arc, ArcHash>());
92-
++num_visited_state;
262+
if (IsEmpty(a) || IsEmpty(b)) return;
263+
CHECK(IsArcSorted(a));
264+
CHECK(IsArcSorted(b));
265+
CHECK(IsEpsilonFree(b));
266+
267+
int32_t final_state_a = a.NumStates() - 1;
268+
int32_t final_state_b = b.NumStates() - 1;
269+
const auto arc_a_begin = a.arcs.begin();
270+
const auto arc_b_begin = b.arcs.begin();
271+
using ArcIterator = std::vector<Arc>::const_iterator;
272+
273+
const int32_t kFinalStateC = -1; // just as a placeholder
274+
// no corresponding arc mapping from `c` to `a` or `c` to `b`
275+
const int32_t kArcMapNone = -1;
276+
auto &arc_indexes_c = c->arc_indexes;
277+
auto &arcs_c = c->arcs;
278+
279+
using StatePair = std::pair<int32_t, int32_t>;
280+
// map state pair to unique id
281+
std::unordered_map<StatePair, int32_t, PairHash> state_pair_map;
282+
std::queue<StatePair> qstates;
283+
qstates.push({0, 0});
284+
state_pair_map.insert({{0, 0}, 0});
285+
state_pair_map.insert({{final_state_a, final_state_b}, kFinalStateC});
286+
int32_t state_index_c = 0;
287+
while (!qstates.empty()) {
288+
arc_indexes_c.push_back(static_cast<int32_t>(arcs_c.size()));
289+
290+
auto curr_state_pair = qstates.front();
291+
qstates.pop();
292+
// as we have inserted `curr_state_pair` before.
293+
int32_t curr_state_index = state_pair_map[curr_state_pair];
294+
295+
auto state_a = curr_state_pair.first;
296+
ArcIterator a_arc_iter_begin = arc_a_begin + a.arc_indexes[state_a];
297+
ArcIterator a_arc_iter_end = arc_a_begin + a.arc_indexes[state_a + 1];
298+
auto state_b = curr_state_pair.second;
299+
ArcIterator b_arc_iter_begin = arc_b_begin + b.arc_indexes[state_b];
300+
ArcIterator b_arc_iter_end = arc_b_begin + b.arc_indexes[state_b + 1];
301+
302+
// As both `a` and `b` are arc-sorted, we first process epsilon arcs in `a`.
303+
for (; a_arc_iter_begin != a_arc_iter_end; ++a_arc_iter_begin) {
304+
if (kEpsilon != a_arc_iter_begin->label) break;
305+
306+
StatePair new_state{a_arc_iter_begin->dest_state, state_b};
307+
auto result = state_pair_map.insert({new_state, state_index_c + 1});
308+
if (result.second) {
309+
// we have not visited `new_state` before.
310+
qstates.push(new_state);
311+
++state_index_c;
312+
}
313+
int32_t new_state_index = result.first->second;
314+
arcs_c.push_back({curr_state_index, new_state_index, kEpsilon});
315+
c_weights->push_back(a_weights[a_arc_iter_begin - arc_a_begin]);
316+
if (arc_map_a != nullptr)
317+
arc_map_a->push_back(
318+
static_cast<int32_t>(a_arc_iter_begin - arc_a_begin));
319+
if (arc_map_b != nullptr) arc_map_b->push_back(kArcMapNone);
93320
}
94-
if (state == final_state) break;
95-
int32_t begin = a.arc_indexes[state];
96-
int32_t end = a.arc_indexes[state + 1];
97-
// since `a` is valid, so every states contains at least one arc.
98-
int32_t arc_index = begin + (distribution(generator) % (end - begin));
99-
int32_t state_id_in_b = state_map_a2b[state];
100-
const auto &curr_arc = a.arcs[arc_index];
101-
if (visited_arcs[state_id_in_b].insert(curr_arc).second) ++num_visited_arcs;
102-
state = curr_arc.dest_state;
103-
}
104321

105-
// create `b`
106-
b->arc_indexes.resize(num_visited_state);
107-
b->arcs.resize(num_visited_arcs);
108-
int32_t n = 0;
109-
for (int32_t i = 0; i < num_visited_state; ++i) {
110-
b->arc_indexes[i] = n;
111-
for (const auto &arc : visited_arcs[i]) {
112-
auto &b_arc = b->arcs[n];
113-
b_arc.src_state = i;
114-
b_arc.dest_state = state_map_a2b[arc.dest_state];
115-
b_arc.label = arc.label;
116-
++n;
322+
// `b` is usually a path generated from `RandNonEpsilonPath`, it may hold
323+
// less number of arcs in each state, so we iterate over `b` here to save
324+
// time.
325+
for (; b_arc_iter_begin != b_arc_iter_end; ++b_arc_iter_begin) {
326+
const Arc &curr_b_arc = *b_arc_iter_begin;
327+
auto a_arc_range =
328+
std::equal_range(a_arc_iter_begin, a_arc_iter_end, curr_b_arc,
329+
[](const Arc &left, const Arc &right) {
330+
return left.label < right.label;
331+
});
332+
for (auto it_a = a_arc_range.first; it_a != a_arc_range.second; ++it_a) {
333+
const Arc &curr_a_arc = *it_a;
334+
StatePair new_state{curr_a_arc.dest_state, curr_b_arc.dest_state};
335+
auto result = state_pair_map.insert({new_state, state_index_c + 1});
336+
if (result.second) {
337+
qstates.push(new_state);
338+
++state_index_c;
339+
}
340+
int32_t new_state_index = result.first->second;
341+
arcs_c.push_back({curr_state_index, new_state_index, curr_a_arc.label});
342+
c_weights->push_back(a_weights[it_a - arc_a_begin]);
343+
if (arc_map_a != nullptr)
344+
arc_map_a->push_back(static_cast<int32_t>(it_a - arc_a_begin));
345+
if (arc_map_b != nullptr)
346+
arc_map_b->push_back(
347+
static_cast<int32_t>(b_arc_iter_begin - arc_b_begin));
348+
}
117349
}
118350
}
119-
if (state_map != nullptr) {
120-
state_map->swap(state_map_b2a);
351+
352+
// push final state
353+
arc_indexes_c.push_back(static_cast<int32_t>(arcs_c.size()));
354+
++state_index_c;
355+
// then replace `kFinalStateC` with the real index of final state of `c`
356+
for (auto &arc : arcs_c) {
357+
if (arc.dest_state == kFinalStateC) arc.dest_state = state_index_c;
121358
}
122-
b->arc_indexes.emplace_back(b->arc_indexes.back());
123-
return true;
359+
// push a duplicate of final state, see the constructor of `Fsa` in
360+
// `k2/csrc/fsa.h`
361+
arc_indexes_c.emplace_back(arc_indexes_c.back());
124362
}
125363

126364
} // namespace k2

0 commit comments

Comments
 (0)