5
5
6
6
// See ../../LICENSE for clarification regarding multiple authors
7
7
8
+ #include < algorithm>
8
9
#include < utility>
9
10
#include < vector>
10
- #include < algorithm>
11
11
12
12
#include " k2/csrc/fsa_algo.h"
13
13
14
14
namespace k2 {
15
15
16
+ using std::pair;
17
+ using std::priority_queue;
16
18
using std::shared_ptr;
17
19
using std::vector;
18
- using std::priority_queue;
19
- using std::pair
20
-
21
20
22
21
struct MaxTracebackState {
23
22
// Element of a path from the start state to some state in an FSA
@@ -29,114 +28,100 @@ struct MaxTracebackState {
29
28
// the dest-state, or -1 if the path is empty (only
30
29
// possible if this element belongs to the start-state).
31
30
32
- int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
33
- // (copied here for convenience), or 0 if arc_index == -1.
34
-
35
- MaxTracebackState (std::shared_ptr<MaxTracebackState> prev,
36
- int32_t arc_index, int32_t symbol):
37
- prev (prev), arc_index(arc_index), symbol(symbol) { }
31
+ int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
32
+ // (copied here for convenience), or 0 if arc_index == -1.
38
33
34
+ MaxTracebackState (std::shared_ptr<MaxTracebackState> prev, int32_t arc_index,
35
+ int32_t symbol)
36
+ : prev(prev), arc_index(arc_index), symbol(symbol) {}
39
37
};
40
38
41
-
42
39
// Forward declaration; the class-key matches the `struct` definition below.
struct LogSumTracebackState;
43
40
44
41
// This struct is used inside LogSumTracebackState; it represents an
45
42
// arc that traces back to a previous LogSumTracebackState.
46
43
// A LogSumTracebackState represents a weighted collection of paths
47
44
// terminating in a specific state.
48
45
struct LogSumTracebackLink {
49
-
50
46
int32_t arc_index; // Index of most recent arc in path from start-state to
51
47
// the dest-state, or -1 if the path is empty (only
52
48
// possible if this element belongs to the start-state).
53
49
54
- int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
55
- // (copied here for convenience), or 0 if arc_index == -1.
50
+ int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
51
+ // (copied here for convenience), or 0 if arc_index == -1.
56
52
57
- double prob; // The probability mass associated with this incoming
58
- // arc in the LogSumTracebackState to which this belongs.
53
+ double prob; // The probability mass associated with this incoming
54
+ // arc in the LogSumTracebackState to which this belongs.
59
55
60
56
std::shared_ptr<LogSumTracebackState> prev_state;
61
57
};
62
58
63
59
struct LogSumTracebackState {
64
- // LogSumTracebackState can be thought of as as a weighted set of paths from the
65
- // start state to a particular state. (It will be limited to the subset of
66
- // paths that have a specific symbol sequence).
60
+ // LogSumTracebackState can be thought of as as a weighted set of paths from
61
+ // the start state to a particular state. (It will be limited to the subset
62
+ // of paths that have a specific symbol sequence).
67
63
68
64
// `prev_elements` is, conceptually, a list of pairs (incoming arc-index,
69
65
// traceback link); we will keep it free of duplicates of the same incoming
70
66
// arc.
71
67
vector<LogSumTracebackLink> prev_elements;
72
68
73
-
74
69
int32_t arc_index; // Index of most recent arc in path from start-state to
75
70
// the dest-state, or -1 if the path is empty (only
76
71
// possible if this element belongs to the start-state).
77
72
78
- int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
79
- // (copied here for convenience), or 0 if arc_index == -1.
80
-
81
- MaxTracebackState (std::shared_ptr<MaxTracebackState> prev,
82
- int32_t arc_index, int32_t symbol):
83
- prev (prev), arc_index(arc_index), symbol(symbol) { }
73
+ int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
74
+ // (copied here for convenience), or 0 if arc_index == -1.
84
75
76
+ MaxTracebackState (std::shared_ptr<MaxTracebackState> prev, int32_t arc_index,
77
+ int32_t symbol)
78
+ : prev(prev), arc_index(arc_index), symbol(symbol) {}
85
79
};
86
80
87
-
88
81
struct DetStateElement {
89
-
90
- double weight; // Weight from reference state to this state, along
91
- // the path taken by following the 'prev' links
92
- // (the path would have `seq_len` arcs in it).
93
- // Note: by "this state" we mean the destination-state of
94
- // the arc at `arc_index`.
95
- // Interpret this with caution, because the
96
- // base state, and the length of the sequence arcs from the
97
- // base state to here, are known only in the DetState
98
- // that owns this DetStateElement.
82
+ double weight; // Weight from reference state to this state, along
83
+ // the path taken by following the 'prev' links
84
+ // (the path would have `seq_len` arcs in it).
85
+ // Note: by "this state" we mean the destination-state of
86
+ // the arc at `arc_index`.
87
+ // Interpret this with caution, because the
88
+ // base state, and the length of the sequence arcs from the
89
+ // base state to here, are known only in the DetState
90
+ // that owns this DetStateElement.
99
91
100
92
std::shared_ptr<PathElement> path;
101
- // The path from the start state to here (actually we will
102
- // only follow back `seq_len` links. Will be nullptr if
103
- // seq_len == 0 and this belongs to the initial determinized
104
- // state.
105
-
106
- DetStateElement &&Advance(float arc_weight, int32_t arc_index, int32_t arc_symbol) {
107
- return DetStateElement (weight + arc_weight,
108
- std::make_shared<PathElement>(path, arc_index, arc_symbol));
93
+ // The path from the start state to here (actually we will
94
+ // only follow back `seq_len` links. Will be nullptr if
95
+ // seq_len == 0 and this belongs to the initial determinized
96
+ // state.
97
+
98
+ DetStateElement &&Advance(float arc_weight, int32_t arc_index,
99
+ int32_t arc_symbol) {
100
+ return DetStateElement (
101
+ weight + arc_weight,
102
+ std::make_shared<PathElement>(path, arc_index, arc_symbol));
109
103
}
110
104
111
- DetStateElement (double weight, std::shared_ptr<PathElement> &&path):
112
- weight (weight), path(path) { }
113
-
105
+ DetStateElement (double weight, std::shared_ptr<PathElement> &&path)
106
+ : weight(weight), path(path) {}
114
107
};
115
108
116
109
class DetState ;
117
110
118
-
119
111
struct DetStateCompare {
120
112
// Comparator for priority queue. Less-than operator that compares
121
113
// forward_backward_weight for best-first processing.
122
- bool operator ()(const shared_ptr<DetState> &a,
123
- const shared_ptr<DetState> &b);
114
+ bool operator ()(const shared_ptr<DetState> &a, const shared_ptr<DetState> &b);
124
115
};
125
116
126
-
127
-
128
117
class Determinizer {
129
118
public:
130
119
private:
131
-
132
- using DetStatePriorityQueue = priority_queue<shared_ptr<DetState>,
133
- vector<shared_ptr<DetState> >,
134
- DetStateCompare>;
135
-
136
-
120
+ using DetStatePriorityQueue =
121
+ priority_queue<shared_ptr<DetState>, vector<shared_ptr<DetState>>,
122
+ DetStateCompare>;
137
123
};
138
124
139
-
140
125
/*
141
126
Conceptually a determinized state in weighted FSA determinization would
142
127
normally
@@ -171,7 +156,6 @@ class DetState {
171
156
// those paths. When Normalize() is called we may advance
172
157
int32_t base_state;
173
158
174
-
175
159
// seq_len is the length of symbol sequence that we follow from state
176
160
// `base_state`. The sequence of symbols can be found by tracing back one of
177
161
// the DetStateElements in the doubly linked list (it doesn't matter which you
@@ -180,7 +164,6 @@ class DetState {
180
164
181
165
bool normalized{false };
182
166
183
-
184
167
std::list<DetStateElement> elements;
185
168
186
169
// This is the weight on the best path that includes this determinized state.
@@ -190,14 +173,12 @@ class DetState {
190
173
// the backward-weight of the state associated with that DetStateElement).
191
174
double forward_backward_weight;
192
175
193
-
194
176
/*
195
- Process arcs leaving this determinized state, possibly creating new determinized
196
- states in the process.
197
- @param [in] wfsa_in The input FSA that we are determinizing, along
198
- with forward-backward weights.
199
- The input FSA should normally be epsilon-free as
200
- epsilons are treated as a normal symbol; and require
177
+ Process arcs leaving this determinized state, possibly creating new
178
+ determinized states in the process.
179
+ @param [in] wfsa_in The input FSA that we are determinizing,
180
+ along with forward-backward weights. The input FSA should normally be
181
+ epsilon-free as epsilons are treated as a normal symbol; and require
201
182
wfsa_in.weight_tpe == kMaxWeight, for
202
183
now (might later create a version of this code
203
184
that works
@@ -210,13 +191,10 @@ class DetState {
210
191
211
192
212
193
*/
213
- void ProcessArcs (const WfsaWithFbWeights &wfsa_in,
214
- Fsa *wfsa_out,
215
- float prune_cutoff,
216
- DetStateMap *state_map,
194
+ void ProcessArcs (const WfsaWithFbWeights &wfsa_in, Fsa *wfsa_out,
195
+ float prune_cutoff, DetStateMap *state_map,
217
196
DetStatePriorityQueue *queue);
218
197
219
-
220
198
/*
221
199
Normalizes this DetState and sets forward_backward_weight.
222
200
@@ -259,18 +237,17 @@ class DetState {
259
237
plus the backward weight in the input FSA of the state that corresponds
260
238
to it).
261
239
*/
262
- void Normalize (const Fsa &input_fsa,
263
- const float *input_fsa_weights,
264
- float *removed_weight,
265
- std::vector<int32_t > *leftover_arcs) {
240
+ void Normalize (const Fsa &input_fsa, const float *input_fsa_weights,
241
+ float *removed_weight, std::vector<int32_t > *leftover_arcs) {
266
242
#ifndef NDEBUG
267
243
CheckElementOrder ();
268
244
#endif
269
245
RemoveDuplicatesOfStates (input_fsa);
270
- RemoveCommonPrefix (input_fsa, input_fsa_weights, removed_weight, leftover_arcs);
246
+ RemoveCommonPrefix (input_fsa, input_fsa_weights, removed_weight,
247
+ leftover_arcs);
271
248
}
272
- private:
273
249
250
+ private:
274
251
/*
275
252
Called from Normalize(), this function removes duplicates in
276
253
`elements`: that is, if two elements represent paths that terminate at
@@ -287,8 +264,7 @@ class DetState {
287
264
weights in `elements`, and set `input_arcs` to the sequence of arcs that
288
265
were removed from
289
266
*/
290
- RemoveCommonPrefix (const Fsa &input_fsa,
291
- const float *input_fsa_weights,
267
+ RemoveCommonPrefix (const Fsa &input_fsa, const float *input_fsa_weights,
292
268
std::vector<int32_t > *input_arcs);
293
269
/*
294
270
This function just does some checking on the `elements` list that
@@ -298,29 +274,24 @@ class DetState {
298
274
they are all the same.
299
275
*/
300
276
void CheckElementOrder () const ;
301
-
302
277
};
303
278
304
279
// Returns true when `a` has a strictly smaller forward_backward_weight
// than `b`.
bool DetStateCompare::operator()(const shared_ptr<DetState> &a,
                                 const shared_ptr<DetState> &b) {
  const auto &lhs = *a;
  const auto &rhs = *b;
  return lhs.forward_backward_weight < rhs.forward_backward_weight;
}
308
283
309
-
310
-
311
284
void DetState::RemoveDuplicatesOfStates (const Fsa &input_fsa) {
312
-
313
285
/*
314
286
`state_to_elem` maps from int32_t state-id to the DetStateElement
315
287
associated with it (there can be only one, we choose the one with
316
288
the best weight).
317
289
*/
318
- std::unordered_map<int32_t , typename std::list<DetStateElement>::iterator> state_to_elem;
319
-
320
-
290
+ std::unordered_map<int32_t , typename std::list<DetStateElement>::iterator>
291
+ state_to_elem;
321
292
322
293
for (auto iter = elements.begin (); iter != elements.end (); ++iter) {
323
- int32_t state = input_fsa.arcs [elem.arc_index ].nextstate ;
294
+ int32_t state = input_fsa.arcs [elem.arc_index ].nextstate ;
324
295
auto p = state_to_elem.insert ({state, elem});
325
296
bool inserted = p.second ;
326
297
if (!inserted) {
@@ -339,11 +310,9 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
339
310
const float *input_fsa_weights,
340
311
float *removed_weight_out,
341
312
std::vector<int32_t > *input_arcs) {
342
-
343
313
CHECK_GE (seq_len, 0 );
344
314
int32_t len;
345
- auto first_path = elements.front ().path ,
346
- last_path = elements.back ().path ;
315
+ auto first_path = elements.front ().path , last_path = elements.back ().path ;
347
316
348
317
for (len = 1 ; len < seq_len; ++len) {
349
318
first_path = first_path->prev ;
@@ -359,8 +328,7 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
359
328
/* We reach a common state after traversing fewer than `seq_len` arcs,
360
329
so we can remove a shared prefix. */
361
330
double removed_weight = 0.0 ;
362
- int32_t new_seq_len = len,
363
- removed_seq_len = seq_len - len;
331
+ int32_t new_seq_len = len, removed_seq_len = seq_len - len;
364
332
input_arcs->resize (removed_seq_len);
365
333
// Advance base_state
366
334
int32_t new_base_state = input_fsa.arcs [first_path->arc_index ].src_state ;
@@ -375,7 +343,7 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
375
343
fsa.arcs [first_path->arc_index ].dest_state == this ->base_state );
376
344
this ->base_state = new_base_state;
377
345
if (removed_weight != 0 ) {
378
- for (DetStateElement &det_state_elem: elements) {
346
+ for (DetStateElement &det_state_elem : elements) {
379
347
det_state_elem.weight -= removed_weight;
380
348
}
381
349
}
@@ -393,8 +361,8 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const {
393
361
// leaving each state in the FSA are sorted first on label and then on
394
362
// dest_state.
395
363
if (seq_len == 0 ) {
396
- CHECK (elements.size () == 1 );
397
- CHECK (elements.front ().weight == 0.0 );
364
+ CHECK_EQ (elements.size (), 1 );
365
+ CHECK_EQ (elements.front ().weight , 0.0 );
398
366
}
399
367
400
368
std::vector<int32> prev_seq;
@@ -413,14 +381,12 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const {
413
381
}
414
382
}
415
383
416
-
417
384
/*
418
385
This class maps from determinized states (DetState) to integer state-id
419
386
in the determinized output.
420
387
*/
421
388
class DetStateMap {
422
389
public:
423
-
424
390
/*
425
391
Outputs the output state-id corresponding to a specific DetState structure.
426
392
This does not store any pointers to the DetState or its contents, so
@@ -454,7 +420,8 @@ class DetStateMap {
454
420
private:
455
421
int32_t cur_output_state_{0 };
456
422
std::unordered_map<std::pair<uint64_t , uint64_t >, int32_t ,
457
- DetStateVectorHasher> map_;
423
+ DetStateVectorHasher>
424
+ map_;
458
425
459
426
/* Turns DetState into a compact form of 128 bits. Technically there
460
427
could be collisions, which would be fatal for the algorithm, but this
@@ -493,7 +460,7 @@ class DetStateMap {
493
460
};
494
461
495
462
void DeterminizeMax (const WfsaWithFbWeights &a, float beam, Fsa *b,
496
- std::vector<std::vector<int32_t > > *arc_map) {
463
+ std::vector<std::vector<int32_t >> *arc_map) {
497
464
// TODO(dpovey): use glog stuff.
498
465
assert (IsValid (a) && IsEpsilonFree (a) && IsTopSortedAndAcyclic (a));
499
466
if (a.arc_indexes .empty ()) {
0 commit comments