@@ -52,78 +52,6 @@ inline int32_t InsertIntersectionState(
52
52
return result.first ->second ;
53
53
}
54
54
55
- /* *
56
- A TraceBack() function used in RmEpsilonsPrunedLogSum. It finds derivative
57
- information for all arcs in a sub-graph. Generally, in
58
- RmEpsilonsPrunedLogSum, we actually get a sub-graph when we find a
59
- non-epsilon arc starting from a particular state `s` (from which we are
60
- trying to remove epsilon arcs). All leaving arcs of all states in this
61
- sub-graph are epsilon arcs except the last one. Then, from the last state, we
62
- need to trace back to state `s` to find the derivative information for all
63
- epsilon arcs in this graph.
64
- @param [in] curr_states (This is consumed destructively, i.e. don't
65
- expect it to contain the same set on exit).
66
- A set of states, stored as a std::map that mapping
67
- state_id in input FSA to the corresponding
68
- LogSumTracebackState we created for this state;
69
- we'll iteratively trace back this set one element
70
- (processing all entering arcs) at a time. At entry
71
- it must have size() == 1 which contains the last
72
- state mentioned above; it will also have size() == 1
73
- at exit which contains the state `s` above.
74
- @param [in] arc_weights_in Weights on the arcs of the input FSA
75
- @param [out] deriv_out Some derivative information at the output
76
- will be written to here, which tells us how the weight
77
- of the non-epsilon arc we created from the above
78
- sub-graph varies as a function of the weights on the
79
- arcs of the input FSA; it's a list
80
- (input_arc_id, deriv) where, mathematically,
81
- 0 < deriv <= 1 (but we might still get exact zeros
82
- due to limitations of floating point representation).
83
- */
84
- static void TraceBackRmEpsilonsLogSum (
85
- std::map<int32_t , k2::LogSumTracebackState *> *curr_states,
86
- const float *arc_weights_in,
87
- std::vector<std::pair<int32_t , float >> *deriv_out) {
88
- CHECK_EQ (curr_states->size (), 1 );
89
- deriv_out->clear ();
90
-
91
- // as the input fsa is top-sorted, we traverse states in a reverse order so we
92
- // can process them when they already have correct backward_prob (all leaving
93
- // arcs have been processed).
94
- k2::LogSumTracebackState *state_ptr = curr_states->rbegin ()->second ;
95
- // In the standard forward-backward algorithm for HMMs this backward_prob
96
- // would, mathematically, be 0.0, but if we set it to the negative of the
97
- // forward prob we can avoid having to subtract the total log-prob
98
- // when we compute posterior/occupation probabilities for arcs.
99
- state_ptr->backward_prob = -state_ptr->forward_prob ;
100
- while (!state_ptr->prev_elements .empty ()) {
101
- double backward_prob = state_ptr->backward_prob ;
102
- for (const auto &link : state_ptr->prev_elements ) {
103
- auto arc_log_posterior =
104
- static_cast <float >(link .forward_prob + backward_prob);
105
- deriv_out->emplace_back (link .arc_index , expf (arc_log_posterior));
106
- k2::LogSumTracebackState *prev_state = link .prev_state .get ();
107
- double new_backward_prob = backward_prob + arc_weights_in[link .arc_index ];
108
- auto result = curr_states->emplace (prev_state->state_id , prev_state);
109
- if (result.second ) {
110
- prev_state->backward_prob = new_backward_prob;
111
- } else {
112
- prev_state->backward_prob =
113
- k2::LogAdd (new_backward_prob, prev_state->backward_prob );
114
- }
115
- }
116
- // we have processed all entering arcs of state curr_states->rbegin(),
117
- // we'll remove it now. As std::map.erase() does not support passing a
118
- // reverse iterator, we here pass --end();
119
- curr_states->erase (--curr_states->end ());
120
- CHECK (!curr_states->empty ());
121
- state_ptr = curr_states->rbegin ()->second ;
122
- }
123
- // we have reached the state from which we are trying to remove epsilon arcs.
124
- CHECK_EQ (curr_states->size (), 1 );
125
- }
126
-
127
55
} // namespace
128
56
129
57
namespace k2 {
@@ -350,215 +278,6 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
350
278
return is_acyclic;
351
279
}
352
280
353
- void RmEpsilonsPrunedMax (const WfsaWithFbWeights &a, float beam, Fsa *b,
354
- std::vector<std::vector<int32_t >> *arc_derivs) {
355
- CHECK_EQ (a.weight_type , kMaxWeight );
356
- CHECK_GT (beam, 0 );
357
- CHECK_NOTNULL (b);
358
- CHECK_NOTNULL (arc_derivs);
359
- b->arc_indexes .clear ();
360
- b->arcs .clear ();
361
- arc_derivs->clear ();
362
-
363
- const auto &fsa = a.fsa ;
364
- if (IsEmpty (fsa)) return ;
365
- int32_t num_states_a = fsa.NumStates ();
366
- int32_t final_state = fsa.FinalState ();
367
- const auto &arcs_a = fsa.data ;
368
- const float *arc_weights_a = a.arc_weights ;
369
-
370
- // identify all states that should be kept
371
- std::vector<char > non_eps_in (num_states_a, 0 );
372
- non_eps_in[0 ] = 1 ;
373
- for (const auto &arc : fsa) {
374
- // We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
375
- // time.
376
- DCHECK_GE (arc.dest_state , arc.src_state );
377
- if (arc.label != kEpsilon ) non_eps_in[arc.dest_state ] = 1 ;
378
- }
379
-
380
- // remap state id
381
- std::vector<int32_t > state_map_a2b (num_states_a, -1 );
382
- int32_t num_states_b = 0 ;
383
- for (int32_t i = 0 ; i != num_states_a; ++i) {
384
- if (non_eps_in[i] == 1 ) state_map_a2b[i] = num_states_b++;
385
- }
386
- b->arc_indexes .reserve (num_states_b + 1 );
387
- int32_t arc_num_b = 0 ;
388
-
389
- const double *forward_state_weights = a.ForwardStateWeights ();
390
- const double *backward_state_weights = a.BackwardStateWeights ();
391
- const double best_weight = forward_state_weights[final_state] - beam;
392
- for (int32_t i = 0 ; i != num_states_a; ++i) {
393
- if (non_eps_in[i] != 1 ) continue ;
394
- b->arc_indexes .push_back (arc_num_b);
395
- int32_t curr_state_b = state_map_a2b[i];
396
- // as the input FSA is top-sorted, we use a map here so we can process
397
- // states when they already have the best cost they are going to get
398
- std::map<int32_t , double >
399
- local_forward_weights; // state -> local_forward_state_weights of this
400
- // state
401
- // state -> (src_state, arc_index) entering this state which contributes to
402
- // `local_forward_weights` of this state.
403
- std::unordered_map<int32_t , std::pair<int32_t , int32_t >>
404
- local_backward_arcs;
405
- local_forward_weights.emplace (i, forward_state_weights[i]);
406
- // `-1` means we have traced back to current state `i`
407
- local_backward_arcs.emplace (i, std::make_pair (i, -1 ));
408
- while (!local_forward_weights.empty ()) {
409
- std::pair<int32_t , double > curr_local_forward_weights =
410
- *(local_forward_weights.begin ());
411
- local_forward_weights.erase (local_forward_weights.begin ());
412
- int32_t state = curr_local_forward_weights.first ;
413
-
414
- int32_t arc_end = fsa.indexes [state + 1 ];
415
- for (int32_t arc_index = fsa.indexes [state]; arc_index != arc_end;
416
- ++arc_index) {
417
- int32_t next_state = arcs_a[arc_index].dest_state ;
418
- int32_t label = arcs_a[arc_index].label ;
419
- double next_weight =
420
- curr_local_forward_weights.second + arc_weights_a[arc_index];
421
- if (next_weight + backward_state_weights[next_state] >= best_weight) {
422
- if (label == kEpsilon ) {
423
- auto result =
424
- local_forward_weights.emplace (next_state, next_weight);
425
- if (result.second ) {
426
- local_backward_arcs[next_state] =
427
- std::make_pair (state, arc_index);
428
- } else {
429
- if (next_weight > result.first ->second ) {
430
- result.first ->second = next_weight;
431
- local_backward_arcs[next_state] =
432
- std::make_pair (state, arc_index);
433
- }
434
- }
435
- } else {
436
- b->arcs .emplace_back (curr_state_b, state_map_a2b[next_state],
437
- label);
438
- std::vector<int32_t > curr_arc_deriv;
439
- std::pair<int32_t , int32_t > curr_backward_arc{state, arc_index};
440
- auto *backward_arc = &curr_backward_arc;
441
- while (backward_arc->second != -1 ) {
442
- curr_arc_deriv.push_back (backward_arc->second );
443
- backward_arc = &(local_backward_arcs[backward_arc->first ]);
444
- }
445
- std::reverse (curr_arc_deriv.begin (), curr_arc_deriv.end ());
446
- arc_derivs->emplace_back (std::move (curr_arc_deriv));
447
- ++arc_num_b;
448
- }
449
- }
450
- }
451
- }
452
- }
453
- // duplicate of final state
454
- b->arc_indexes .push_back (b->arc_indexes .back ());
455
- }
456
-
457
- void RmEpsilonsPrunedLogSum (
458
- const WfsaWithFbWeights &a, float beam, Fsa *b,
459
- std::vector<float > *b_arc_weights,
460
- std::vector<std::vector<std::pair<int32_t , float >>> *arc_derivs) {
461
- CHECK_GT (beam, 0 );
462
- CHECK_NOTNULL (b);
463
- CHECK_NOTNULL (b_arc_weights);
464
- CHECK_NOTNULL (arc_derivs);
465
- b->arc_indexes .clear ();
466
- b->arcs .clear ();
467
- b_arc_weights->clear ();
468
- arc_derivs->clear ();
469
-
470
- const auto &fsa = a.fsa ;
471
- if (IsEmpty (fsa)) return ;
472
- int32_t num_states_a = fsa.NumStates ();
473
- int32_t final_state = fsa.FinalState ();
474
- const auto &arcs_a = fsa.data ;
475
- const float *arc_weights_a = a.arc_weights ;
476
-
477
- // identify all states that should be kept
478
- std::vector<char > non_eps_in (num_states_a, 0 );
479
- non_eps_in[0 ] = 1 ;
480
- for (const auto &arc : fsa) {
481
- // We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
482
- // time.
483
- DCHECK_GE (arc.dest_state , arc.src_state );
484
- if (arc.label != kEpsilon ) non_eps_in[arc.dest_state ] = 1 ;
485
- }
486
-
487
- // remap state id
488
- std::vector<int32_t > state_map_a2b (num_states_a, -1 );
489
- int32_t num_states_b = 0 ;
490
- for (int32_t i = 0 ; i != num_states_a; ++i) {
491
- if (non_eps_in[i] == 1 ) state_map_a2b[i] = num_states_b++;
492
- }
493
- b->arc_indexes .reserve (num_states_b + 1 );
494
- int32_t arc_num_b = 0 ;
495
-
496
- const double *forward_state_weights = a.ForwardStateWeights ();
497
- const double *backward_state_weights = a.BackwardStateWeights ();
498
- const double best_weight = forward_state_weights[final_state] - beam;
499
- for (int32_t i = 0 ; i != num_states_a; ++i) {
500
- if (non_eps_in[i] != 1 ) continue ;
501
- b->arc_indexes .push_back (arc_num_b);
502
- int32_t curr_state_b = state_map_a2b[i];
503
- // as the input FSA is top-sorted, we use a set here so we can process
504
- // states when they already have costs over all paths they are going to get
505
- std::set<int32_t > qstates;
506
- std::unordered_map<int32_t , std::shared_ptr<LogSumTracebackState>>
507
- traceback_states; // state -> LogSumTracebackState of this state
508
- std::shared_ptr<LogSumTracebackState> start_state (
509
- new LogSumTracebackState (i, forward_state_weights[i]));
510
- double start_forward_weights = start_state->forward_prob ;
511
- traceback_states.emplace (i, start_state);
512
- qstates.insert (i);
513
- while (!qstates.empty ()) {
514
- int32_t state = *(qstates.begin ());
515
- qstates.erase (qstates.begin ());
516
-
517
- const auto &curr_traceback_state = traceback_states[state];
518
- double curr_forward_weights = curr_traceback_state->forward_prob ;
519
- int32_t arc_end = fsa.indexes [state + 1 ];
520
- for (int32_t arc_index = fsa.indexes [state]; arc_index != arc_end;
521
- ++arc_index) {
522
- int32_t next_state = arcs_a[arc_index].dest_state ;
523
- int32_t label = arcs_a[arc_index].label ;
524
- float curr_arc_weight = arc_weights_a[arc_index];
525
- double next_weight = curr_forward_weights + curr_arc_weight;
526
- if (next_weight + backward_state_weights[next_state] >= best_weight) {
527
- if (label == kEpsilon ) {
528
- auto result = traceback_states.emplace (next_state, nullptr );
529
- if (result.second ) {
530
- result.first ->second = std::make_shared<LogSumTracebackState>(
531
- next_state, curr_traceback_state, arc_index, curr_arc_weight);
532
- qstates.insert (next_state);
533
- } else {
534
- result.first ->second ->Accept (curr_traceback_state, arc_index,
535
- curr_arc_weight);
536
- }
537
- } else {
538
- b->arcs .emplace_back (curr_state_b, state_map_a2b[next_state],
539
- label);
540
- b_arc_weights->push_back (curr_forward_weights + curr_arc_weight -
541
- start_forward_weights);
542
-
543
- std::vector<std::pair<int32_t , float >> curr_arc_deriv;
544
- std::map<int32_t , LogSumTracebackState *> curr_states;
545
- curr_states.emplace (state, curr_traceback_state.get ());
546
- TraceBackRmEpsilonsLogSum (&curr_states, arc_weights_a,
547
- &curr_arc_deriv);
548
- std::reverse (curr_arc_deriv.begin (), curr_arc_deriv.end ());
549
- // push derivs info of current arc
550
- curr_arc_deriv.emplace_back (arc_index, 1 );
551
- arc_derivs->emplace_back (std::move (curr_arc_deriv));
552
- ++arc_num_b;
553
- }
554
- }
555
- }
556
- }
557
- }
558
- // duplicate of final state
559
- b->arc_indexes .push_back (b->arc_indexes .back ());
560
- }
561
-
562
281
bool Intersect (const Fsa &a, const Fsa &b, Fsa *c,
563
282
std::vector<int32_t > *arc_map_a /* = nullptr*/ ,
564
283
std::vector<int32_t > *arc_map_b /* = nullptr*/ ) {
0 commit comments