Skip to content

Commit d860c7e

Browse files
authored
add beam in RandEquivalent (#46)
* add beam in RandEquivalent * changed to check the total weights
1 parent 7094857 commit d860c7e

8 files changed

+184
-75
lines changed

k2/csrc/fsa_algo_test.cc

+8-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "gmock/gmock.h"
1515
#include "gtest/gtest.h"
16+
#include "k2/csrc/fsa_equivalent.h"
1617
#include "k2/csrc/fsa_renderer.h"
1718
#include "k2/csrc/fsa_util.h"
1819
#include "k2/csrc/properties.h"
@@ -593,22 +594,23 @@ TEST_F(DeterminizeTest, DeterminizePrunedMax) {
593594
Fsa b;
594595
std::vector<float> b_arc_weights;
595596
std::vector<std::vector<int32_t>> arc_derivs;
596-
DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
597+
DeterminizePrunedMax(*max_wfsa_, 10.0, 100, &b, &b_arc_weights, &arc_derivs);
597598

598599
EXPECT_TRUE(IsDeterministic(b));
599-
600-
// TODO(haowen) as the type of `label_to_state` is `unordered_map` (instead of
601-
// `map`), the output `state_id` and `arc_id` may differ under different STL
602-
// implementations, we need to check the equivalence automatically
600+
EXPECT_TRUE(IsRandEquivalent<kMaxWeight>(
601+
max_wfsa_->fsa, max_wfsa_->arc_weights, b, b_arc_weights.data(), 10.0));
603602
}
604603

605604
TEST_F(DeterminizeTest, DeterminizePrunedLogSum) {
606605
Fsa b;
607606
std::vector<float> b_arc_weights;
608607
std::vector<std::vector<std::pair<int32_t, float>>> arc_derivs;
609-
DeterminizePrunedLogSum(*log_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
608+
DeterminizePrunedLogSum(*log_wfsa_, 10.0, 100, &b, &b_arc_weights,
609+
&arc_derivs);
610610

611611
EXPECT_TRUE(IsDeterministic(b));
612+
EXPECT_TRUE(IsRandEquivalent<kLogSumWeight>(
613+
log_wfsa_->fsa, log_wfsa_->arc_weights, b, b_arc_weights.data(), 10.0));
612614

613615
// TODO(haowen): how to check `arc_derivs_out` here, may return `num_steps` to
614616
// check the sum of `derivs_out` for each output arc?

k2/csrc/fsa_equivalent.cc

+31-6
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,12 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {
169169

170170
template <FbWeightType Type>
171171
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
172-
const float *b_weights, bool top_sorted /*=true*/,
172+
const float *b_weights, float beam /*=kFloatInfinity*/,
173+
float delta /*=1e-6*/, bool top_sorted /*=true*/,
173174
std::size_t npath /*= 100*/) {
175+
CHECK_GT(beam, 0);
176+
CHECK_NOTNULL(a_weights);
177+
CHECK_NOTNULL(b_weights);
174178
Fsa connected_a, connected_b, valid_a, valid_b;
175179
std::vector<int32_t> connected_a_arc_map, connected_b_arc_map,
176180
valid_a_arc_map, valid_b_arc_map;
@@ -199,10 +203,25 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
199203
(*(labels_difference.begin())) != kEpsilon))
200204
return false;
201205

206+
double loglike_cutoff_a, loglike_cutoff_b;
207+
if (beam != kFloatInfinity) {
208+
loglike_cutoff_a =
209+
ShortestDistance<Type>(valid_a, valid_a_weights.data()) - beam;
210+
loglike_cutoff_b =
211+
ShortestDistance<Type>(valid_b, valid_b_weights.data()) - beam;
212+
if (Type == kMaxWeight &&
213+
!DoubleApproxEqual(loglike_cutoff_a, loglike_cutoff_b))
214+
return false;
215+
} else {
216+
loglike_cutoff_a = kDoubleNegativeInfinity;
217+
loglike_cutoff_b = kDoubleNegativeInfinity;
218+
}
219+
202220
std::random_device rd;
203221
std::mt19937 gen(rd());
204222
std::bernoulli_distribution coin(0.5);
205-
for (auto i = 0; i != npath; ++i) {
223+
std::size_t n = 0;
224+
while (n < npath) {
206225
const auto &fsa = coin(gen) ? valid_a : valid_b;
207226
Fsa path, valid_path;
208227
RandomPathWithoutEpsilonArc(fsa, &path); // path is already connected
@@ -220,22 +239,28 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
220239
// find out that we don't need that version, we will remove flag
221240
// `top_sorted` and add requirements as comments in the header file.
222241
CHECK(top_sorted);
223-
double sum_a =
242+
double cost_a =
224243
ShortestDistance<Type>(a_compose_path, a_compose_weights.data());
225-
double sum_b =
244+
double cost_b =
226245
ShortestDistance<Type>(b_compose_path, b_compose_weights.data());
227-
if (!DoubleApproxEqual(sum_a, sum_b)) return false;
246+
if (cost_a < loglike_cutoff_a && cost_b < loglike_cutoff_b) {
247+
continue;
248+
} else {
249+
if (!DoubleApproxEqual(cost_a, cost_b, delta)) return false;
250+
++n;
251+
}
228252
}
229253
return true;
230254
}
231255

232256
// explicit instantiation here
233257
template bool IsRandEquivalent<kMaxWeight>(const Fsa &a, const float *a_weights,
234258
const Fsa &b, const float *b_weights,
259+
float beam, float delta,
235260
bool top_sorted, std::size_t npath);
236261
template bool IsRandEquivalent<kLogSumWeight>(
237262
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
238-
bool top_sorted, std::size_t npath);
263+
float beam, float delta, bool top_sorted, std::size_t npath);
239264

240265
bool RandomPath(const Fsa &a, Fsa *b,
241266
std::vector<int32_t> *state_map /*=nullptr*/) {

k2/csrc/fsa_equivalent.h

+24-9
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,32 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath = 100);
3232
@param [in] a_weights Arc weights of `a`
3333
@param [in] b The other FSA to be checked for equivalence
3434
@param [in] b_weights Arc weights of `b`
35-
@param [in] top_sorted If both `a` and `b` are topological sorted or not.
36-
We may remove this flag if we finally find out that
37-
input FSAs in all scenarios are top-sorted.
35+
@param [in] beam beam > 0 that affects pruning; the algorithm
36+
will only check paths within `beam` of the
37+
best path (for tropical semiring, it's max
38+
weight over all paths from start state to
39+
final state; for log semiring, it's log-sum probs
40+
over all paths) in `a` or `b`. That is,
41+
any symbol sequence, whose total weights
42+
over all paths are within `beam` of the best
43+
path (either in `a` or `b`), must have
44+
the same weights in `a` and `b`.
45+
There is no requirement on symbol sequences
46+
whose total weights over paths are outside `beam`.
47+
Just keep `kFloatInfinity` if you don't want pruning.
48+
@param [in] delta Tolerance for path weights to check the equivalence.
49+
If abs(weights_a - weights_b) <= delta, we say the two
50+
paths are equivalent.
51+
@param [in] top_sorted The user may set this to true if both `a` and `b` are
52+
topologically sorted; this makes this function faster.
53+
Otherwise it must be set to false.
3854
@param [in] npath The number of paths that will be generated to check the
3955
equivalence of `a` and `b`
4056
*/
4157
template <FbWeightType Type>
4258
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
43-
const float *b_weights, bool top_sorted = true,
59+
const float *b_weights, float beam = kFloatInfinity,
60+
float delta = 1e-6, bool top_sorted = true,
4461
std::size_t npath = 100);
4562

4663
/*
@@ -61,18 +78,16 @@ bool RandomPath(const Fsa &a, Fsa *b,
6178
bool RandomPathWithoutEpsilonArc(const Fsa &a, Fsa *b,
6279
std::vector<int32_t> *state_map = nullptr);
6380
/*
64-
Computes the intersection of two FSAs where one FSA has weights on arc. This
65-
function will be called in the version of `IsRandEquivalent` for Wfsa.
81+
Computes the intersection of two FSAs where one FSA has weights on arc.
6682
6783
@param [in] a One of the FSAs to be intersected. Must satisfy
6884
ArcSorted(a)
6985
@param [in] a_weights Arc weights of `a`
7086
@param [in] b The other FSA to be intersected. Must satisfy
7187
ArcSorted(b) and IsEpsilonFree(b). It is usually a path
72-
generated from `RandomNonEpsilonPath`
88+
generated from `RandomPathWithoutEpsilonArc`
7389
@param [out] c The composed FSA will be output to here.
74-
@param [out] c_weights Arc weights of output FSA `c` which are corresponding
75-
arc weights in `a`
90+
@param [out] c_weights Arc weights of output FSA `c`.
7691
@param [out] arc_map_a If non-NULL, at exit will be a vector of
7792
size c->arcs.size(), saying for each arc in
7893
`c` what the source arc in `a` was, `-1` represents

k2/csrc/fsa_equivalent_test.cc

+13
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ TEST(FsaEquivalent, IsWfsaRandEquivalent) {
149149
c_weights.data());
150150
EXPECT_FALSE(status);
151151
}
152+
153+
// check equivalence with beam
154+
{
155+
bool status = IsRandEquivalent<kMaxWeight>(a, a_weights.data(), b,
156+
b_weights.data(), 4.0);
157+
EXPECT_TRUE(status);
158+
}
159+
// check equivalence with beam
160+
{
161+
bool status = IsRandEquivalent<kMaxWeight>(a, a_weights.data(), c,
162+
c_weights.data(), 6.0);
163+
EXPECT_FALSE(status);
164+
}
152165
}
153166

154167
TEST(FsaEquivalent, RandomPathFail) {

k2/csrc/properties.cc

+2-6
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ bool IsValid(const Fsa &fsa) {
4545
if (arc.src_state == state) {
4646
++num_arcs;
4747
} else {
48-
// every state contains at least one arc.
49-
if (arc.src_state != state + 1) return false;
5048
// `arc_indexes` and `arcs` in this state are not consistent.
5149
if ((fsa.arc_indexes[state + 1] - fsa.arc_indexes[state]) != num_arcs)
5250
return false;
@@ -55,7 +53,6 @@ bool IsValid(const Fsa &fsa) {
5553
}
5654
}
5755
// check the last state
58-
if (final_state != state + 1) return false;
5956
if ((fsa.arc_indexes[final_state] - fsa.arc_indexes[state]) != num_arcs)
6057
return false;
6158
return true;
@@ -106,8 +103,7 @@ bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order /*= nullptr*/) {
106103
if (current_state.arc_begin == current_state.arc_end) {
107104
// we have finished visiting this state
108105
state_status[current_state.state] = kVisited;
109-
if (order != nullptr)
110-
order->push_back(current_state.state);
106+
if (order != nullptr) order->push_back(current_state.state);
111107
stack.pop();
112108
continue;
113109
}
@@ -119,7 +115,7 @@ bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order /*= nullptr*/) {
119115
// a new discovered node
120116
state_status[next_state] = kVisiting;
121117
stack.push({next_state, fsa.arc_indexes[next_state],
122-
fsa.arc_indexes[next_state + 1]});
118+
fsa.arc_indexes[next_state + 1]});
123119
++current_state.arc_begin;
124120
break;
125121
}

k2/csrc/properties.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ enum Properties {
3232
`fsa` is valid if:
3333
1. it is empty, if not, it contains at least two states.
3434
2. only kFinalSymbol arcs enter the final state.
35-
3. every state contains at least one arc except the final state.
36-
4. `arcs_indexes` and `arcs` in this state are not consistent.
35+
3. `arc_indexes` and `arcs` in this state are consistent.
3736
TODO(haowen): add more rules?
3837
*/
3938
bool IsValid(const Fsa &fsa);
@@ -59,7 +58,7 @@ bool HasSelfLoops(const Fsa &fsa);
5958
accessible (i.e. from the start state) are not considered.
6059
The optional argument order, assigns the order in which visiting states is
6160
finished in DFS traversal. State 0 has the largest order (num_states - 1) and
62-
the final state has the smallest order (0).
61+
the final state has the smallest order (0).
6362
*/
6463
bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order = nullptr);
6564

0 commit comments

Comments
 (0)