@@ -169,8 +169,11 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {
169
169
170
170
template <FbWeightType Type>
171
171
bool IsRandEquivalent (const Fsa &a, const float *a_weights, const Fsa &b,
172
- const float *b_weights, bool top_sorted /* =true*/ ,
173
- std::size_t npath /* = 100*/ ) {
172
+ const float *b_weights, float beam /* =kFloatInfinity*/ ,
173
+ bool top_sorted /* =true*/ , std::size_t npath /* = 100*/ ) {
174
+ CHECK_GT (beam, 0 );
175
+ CHECK_NOTNULL (a_weights);
176
+ CHECK_NOTNULL (b_weights);
174
177
Fsa connected_a, connected_b, valid_a, valid_b;
175
178
std::vector<int32_t > connected_a_arc_map, connected_b_arc_map,
176
179
valid_a_arc_map, valid_b_arc_map;
@@ -199,6 +202,13 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
199
202
(*(labels_difference.begin ())) != kEpsilon ))
200
203
return false ;
201
204
205
+ double loglike_cutoff;
206
+ if (beam != kFloatInfinity )
207
+ loglike_cutoff =
208
+ ShortestDistance<Type>(valid_a, valid_a_weights.data ()) - beam;
209
+ else
210
+ loglike_cutoff = kDoubleNegativeInfinity ;
211
+
202
212
std::random_device rd;
203
213
std::mt19937 gen (rd ());
204
214
std::bernoulli_distribution coin (0.5 );
@@ -222,6 +232,7 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
222
232
CHECK (top_sorted);
223
233
double sum_a =
224
234
ShortestDistance<Type>(a_compose_path, a_compose_weights.data ());
235
+ if (sum_a < loglike_cutoff) sum_a = kDoubleNegativeInfinity ;
225
236
double sum_b =
226
237
ShortestDistance<Type>(b_compose_path, b_compose_weights.data ());
227
238
if (!DoubleApproxEqual (sum_a, sum_b)) return false ;
@@ -232,10 +243,11 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
232
243
// explicit instantiation here
233
244
template bool IsRandEquivalent<kMaxWeight >(const Fsa &a, const float *a_weights,
234
245
const Fsa &b, const float *b_weights,
235
- bool top_sorted, std::size_t npath);
246
+ float beam, bool top_sorted,
247
+ std::size_t npath);
236
248
template bool IsRandEquivalent<kLogSumWeight >(
237
249
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
238
- bool top_sorted, std::size_t npath);
250
+ float beam, bool top_sorted, std::size_t npath);
239
251
240
252
bool RandomPath (const Fsa &a, Fsa *b,
241
253
std::vector<int32_t > *state_map /* =nullptr*/ ) {
0 commit comments