@@ -169,8 +169,12 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {
169
169
170
170
template <FbWeightType Type>
171
171
bool IsRandEquivalent (const Fsa &a, const float *a_weights, const Fsa &b,
172
- const float *b_weights, bool top_sorted /* =true*/ ,
172
+ const float *b_weights, float beam /* =kFloatInfinity*/ ,
173
+ float delta /* =1e-6*/ , bool top_sorted /* =true*/ ,
173
174
std::size_t npath /* = 100*/ ) {
175
+ CHECK_GT (beam, 0 );
176
+ CHECK_NOTNULL (a_weights);
177
+ CHECK_NOTNULL (b_weights);
174
178
Fsa connected_a, connected_b, valid_a, valid_b;
175
179
std::vector<int32_t > connected_a_arc_map, connected_b_arc_map,
176
180
valid_a_arc_map, valid_b_arc_map;
@@ -199,10 +203,25 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
199
203
(*(labels_difference.begin ())) != kEpsilon ))
200
204
return false ;
201
205
206
+ double loglike_cutoff_a, loglike_cutoff_b;
207
+ if (beam != kFloatInfinity ) {
208
+ loglike_cutoff_a =
209
+ ShortestDistance<Type>(valid_a, valid_a_weights.data ()) - beam;
210
+ loglike_cutoff_b =
211
+ ShortestDistance<Type>(valid_b, valid_b_weights.data ()) - beam;
212
+ if (Type == kMaxWeight &&
213
+ !DoubleApproxEqual (loglike_cutoff_a, loglike_cutoff_b))
214
+ return false ;
215
+ } else {
216
+ loglike_cutoff_a = kDoubleNegativeInfinity ;
217
+ loglike_cutoff_b = kDoubleNegativeInfinity ;
218
+ }
219
+
202
220
std::random_device rd;
203
221
std::mt19937 gen (rd ());
204
222
std::bernoulli_distribution coin (0.5 );
205
- for (auto i = 0 ; i != npath; ++i) {
223
+ std::size_t n = 0 ;
224
+ while (n < npath) {
206
225
const auto &fsa = coin (gen) ? valid_a : valid_b;
207
226
Fsa path, valid_path;
208
227
RandomPathWithoutEpsilonArc (fsa, &path); // path is already connected
@@ -220,22 +239,28 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
220
239
// find out that we don't need that version, we will remove flag
221
240
// `top_sorted` and add requirements as comments in the header file.
222
241
CHECK (top_sorted);
223
- double sum_a =
242
+ double cost_a =
224
243
ShortestDistance<Type>(a_compose_path, a_compose_weights.data ());
225
- double sum_b =
244
+ double cost_b =
226
245
ShortestDistance<Type>(b_compose_path, b_compose_weights.data ());
227
- if (!DoubleApproxEqual (sum_a, sum_b)) return false ;
246
+ if (cost_a < loglike_cutoff_a && cost_b < loglike_cutoff_b) {
247
+ continue ;
248
+ } else {
249
+ if (!DoubleApproxEqual (cost_a, cost_b, delta)) return false ;
250
+ ++n;
251
+ }
228
252
}
229
253
return true ;
230
254
}
231
255
232
256
// explicit instantiation here
233
257
template bool IsRandEquivalent<kMaxWeight >(const Fsa &a, const float *a_weights,
234
258
const Fsa &b, const float *b_weights,
259
+ float beam, float delta,
235
260
bool top_sorted, std::size_t npath);
236
261
template bool IsRandEquivalent<kLogSumWeight >(
237
262
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
238
- bool top_sorted, std::size_t npath);
263
+ float beam, float delta, bool top_sorted, std::size_t npath);
239
264
240
265
bool RandomPath (const Fsa &a, Fsa *b,
241
266
std::vector<int32_t > *state_map /* =nullptr*/ ) {
0 commit comments