Skip to content

Commit 39291d4

Browse files
committed
add beam in RandEquivalent
1 parent 7094857 commit 39291d4

File tree

4 files changed

+36
-11
lines changed

4 files changed

+36
-11
lines changed

k2/csrc/fsa_equivalent.cc

+16-4
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,11 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {
169169

170170
template <FbWeightType Type>
171171
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
172-
const float *b_weights, bool top_sorted /*=true*/,
173-
std::size_t npath /*= 100*/) {
172+
const float *b_weights, float beam /*=kFloatInfinity*/,
173+
bool top_sorted /*=true*/, std::size_t npath /*= 100*/) {
174+
CHECK_GT(beam, 0);
175+
CHECK_NOTNULL(a_weights);
176+
CHECK_NOTNULL(b_weights);
174177
Fsa connected_a, connected_b, valid_a, valid_b;
175178
std::vector<int32_t> connected_a_arc_map, connected_b_arc_map,
176179
valid_a_arc_map, valid_b_arc_map;
@@ -199,6 +202,13 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
199202
(*(labels_difference.begin())) != kEpsilon))
200203
return false;
201204

205+
double loglike_cutoff;
206+
if (beam != kFloatInfinity)
207+
loglike_cutoff =
208+
ShortestDistance<Type>(valid_a, valid_a_weights.data()) - beam;
209+
else
210+
loglike_cutoff = kDoubleNegativeInfinity;
211+
202212
std::random_device rd;
203213
std::mt19937 gen(rd());
204214
std::bernoulli_distribution coin(0.5);
@@ -222,6 +232,7 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
222232
CHECK(top_sorted);
223233
double sum_a =
224234
ShortestDistance<Type>(a_compose_path, a_compose_weights.data());
235+
if (sum_a < loglike_cutoff) sum_a = kDoubleNegativeInfinity;
225236
double sum_b =
226237
ShortestDistance<Type>(b_compose_path, b_compose_weights.data());
227238
if (!DoubleApproxEqual(sum_a, sum_b)) return false;
@@ -232,10 +243,11 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
232243
// explicit instantiation here
233244
template bool IsRandEquivalent<kMaxWeight>(const Fsa &a, const float *a_weights,
234245
const Fsa &b, const float *b_weights,
235-
bool top_sorted, std::size_t npath);
246+
float beam, bool top_sorted,
247+
std::size_t npath);
236248
template bool IsRandEquivalent<kLogSumWeight>(
237249
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
238-
bool top_sorted, std::size_t npath);
250+
float beam, bool top_sorted, std::size_t npath);
239251

240252
bool RandomPath(const Fsa &a, Fsa *b,
241253
std::vector<int32_t> *state_map /*=nullptr*/) {

k2/csrc/fsa_equivalent.h

+11-7
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,20 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath = 100);
3232
@param [in] a_weights Arc weights of `a`
3333
@param [in] b The other FSA to be checked for equivalence
3434
@param [in] b_weights Arc weights of `b`
35-
@param [in] top_sorted If both `a` and `b` are topological sorted or not.
36-
We may remove this flag if we finally find out that
37-
input FSAs in all scenarios are top-sorted.
35+
@param [in] beam beam > 0 that affects pruning in `a`; the algorithm
36+
will keep paths that are within `beam` of the best
37+
path in `a` to check the equivalence. Just keep
38+
`kFloatInfinity` if you don't want pruning.
39+
@param [in] top_sorted The user may set this to true if both `a` and `b` are
40+
topologically sorted; this makes this function faster.
41+
Otherwise it must be set to false.
3842
@param [in] npath The number of paths that will be generated to check the
3943
equivalence of `a` and `b`
4044
*/
4145
template <FbWeightType Type>
4246
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
43-
const float *b_weights, bool top_sorted = true,
44-
std::size_t npath = 100);
47+
const float *b_weights, float beam = kFloatInfinity,
48+
bool top_sorted = true, std::size_t npath = 100);
4549

4650
/*
4751
Gets a random path from an Fsa `a`, returns true if we get one path
@@ -69,10 +73,10 @@ bool RandomPathWithoutEpsilonArc(const Fsa &a, Fsa *b,
6973
@param [in] a_weights Arc weights of `a`
7074
@param [in] b The other FSA to be intersected. Must satisfy
7175
ArcSorted(b) and IsEpsilonFree(b). It is usually a path
72-
generated from `RandomNonEpsilonPath`
76+
generated from `RandomNonEpsilonPath`
7377
@param [out] c The composed FSA will be output to here.
7478
@param [out] c_weights Arc weights of output FSA `c`, which are the corresponding
75-
arc weights in `a`
79+
arc weights in `a`
7680
@param [out] arc_map_a If non-NULL, at exit will be a vector of
7781
size c->arcs.size(), saying for each arc in
7882
`c` what the source arc in `a` was, `-1` represents

k2/csrc/fsa_equivalent_test.cc

+7
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ TEST(FsaEquivalent, IsWfsaRandEquivalent) {
149149
c_weights.data());
150150
EXPECT_FALSE(status);
151151
}
152+
153+
// check equivalence with beam
154+
{
155+
bool status = IsRandEquivalent<kMaxWeight>(a, a_weights.data(), b,
156+
b_weights.data(), 3);
157+
EXPECT_FALSE(status);
158+
}
152159
}
153160

154161
TEST(FsaEquivalent, RandomPathFail) {

k2/csrc/weights.h

+2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
namespace k2 {
2121

22+
constexpr float kFloatInfinity = std::numeric_limits<float>::infinity();
2223
constexpr float kFloatNegativeInfinity =
2324
-std::numeric_limits<float>::infinity();
25+
constexpr double kDoubleInfinity = std::numeric_limits<double>::infinity();
2426
constexpr double kDoubleNegativeInfinity =
2527
-std::numeric_limits<double>::infinity();
2628

0 commit comments

Comments
 (0)