Skip to content

Commit dbba25b

Browse files
committed
implement ComputeForward/BackwardMaxWeights
1 parent c3ccd74 commit dbba25b

8 files changed

+273
-131
lines changed

k2/csrc/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_library(fsa
55
fsa_renderer.cc
66
fsa_util.cc
77
properties.cc
8+
weights.cc
89
)
910

1011
target_include_directories(fsa PUBLIC ${CMAKE_SOURCE_DIR})
@@ -38,6 +39,7 @@ set(fsa_tests
3839
fsa_renderer_test
3940
fsa_util_test
4041
properties_test
42+
weights_test
4143
)
4244

4345
foreach(name IN LISTS fsa_tests)

k2/csrc/determinize.cc

+69-102
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55

66
// See ../../LICENSE for clarification regarding multiple authors
77

8+
#include <algorithm>
89
#include <utility>
910
#include <vector>
10-
#include <algorithm>
1111

1212
#include "k2/csrc/fsa_algo.h"
1313

1414
namespace k2 {
1515

16+
using std::pair;
17+
using std::priority_queue;
1618
using std::shared_ptr;
1719
using std::vector;
18-
using std::priority_queue;
19-
using std::pair
20-
2120

2221
struct MaxTracebackState {
2322
// Element of a path from the start state to some state in an FSA
@@ -29,114 +28,100 @@ struct MaxTracebackState {
2928
// the dest-state, or -1 if the path is empty (only
3029
// possible if this element belongs to the start-state).
3130

32-
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
33-
// (copied here for convenience), or 0 if arc_index == -1.
34-
35-
MaxTracebackState(std::shared_ptr<MaxTracebackState> prev,
36-
int32_t arc_index, int32_t symbol):
37-
prev(prev), arc_index(arc_index), symbol(symbol) { }
31+
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
32+
// (copied here for convenience), or 0 if arc_index == -1.
3833

34+
MaxTracebackState(std::shared_ptr<MaxTracebackState> prev, int32_t arc_index,
35+
int32_t symbol)
36+
: prev(prev), arc_index(arc_index), symbol(symbol) {}
3937
};
4038

41-
4239
class LogSumTracebackState;
4340

4441
// This struct is used inside LogSumTracebackState; it represents an
4542
// arc that traces back to a previous LogSumTracebackState.
4643
// A LogSumTracebackState represents a weighted collection of paths
4744
// terminating in a specific state.
4845
struct LogSumTracebackLink {
49-
5046
int32_t arc_index; // Index of most recent arc in path from start-state to
5147
// the dest-state, or -1 if the path is empty (only
5248
// possible if this element belongs to the start-state).
5349

54-
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
55-
// (copied here for convenience), or 0 if arc_index == -1.
50+
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
51+
// (copied here for convenience), or 0 if arc_index == -1.
5652

57-
double prob; // The probability mass associated with this incoming
58-
// arc in the LogSumTracebackState to which this belongs.
53+
double prob; // The probability mass associated with this incoming
54+
// arc in the LogSumTracebackState to which this belongs.
5955

6056
std::shared_ptr<LogSumTracebackState> prev_state;
6157
};
6258

6359
struct LogSumTracebackState {
64-
// LogSumTracebackState can be thought of as a weighted set of paths from the
65-
// start state to a particular state. (It will be limited to the subset of
66-
// paths that have a specific symbol sequence).
60+
// LogSumTracebackState can be thought of as a weighted set of paths from
61+
// the start state to a particular state. (It will be limited to the subset
62+
// of paths that have a specific symbol sequence).
6763

6864
// `prev_elements` is, conceptually, a list of pairs (incoming arc-index,
6965
// traceback link); we will keep it free of duplicates of the same incoming
7066
// arc.
7167
vector<LogSumTracebackLink> prev_elements;
7268

73-
7469
int32_t arc_index; // Index of most recent arc in path from start-state to
7570
// the dest-state, or -1 if the path is empty (only
7671
// possible if this element belongs to the start-state).
7772

78-
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
79-
// (copied here for convenience), or 0 if arc_index == -1.
80-
81-
MaxTracebackState(std::shared_ptr<MaxTracebackState> prev,
82-
int32_t arc_index, int32_t symbol):
83-
prev(prev), arc_index(arc_index), symbol(symbol) { }
73+
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA
74+
// (copied here for convenience), or 0 if arc_index == -1.
8475

76+
MaxTracebackState(std::shared_ptr<MaxTracebackState> prev, int32_t arc_index,
77+
int32_t symbol)
78+
: prev(prev), arc_index(arc_index), symbol(symbol) {}
8579
};
8680

87-
8881
struct DetStateElement {
89-
90-
double weight; // Weight from reference state to this state, along
91-
// the path taken by following the 'prev' links
92-
// (the path would have `seq_len` arcs in it).
93-
// Note: by "this state" we mean the destination-state of
94-
// the arc at `arc_index`.
95-
// Interpret this with caution, because the
96-
// base state, and the length of the sequence arcs from the
97-
// base state to here, are known only in the DetState
98-
// that owns this DetStateElement.
82+
double weight; // Weight from reference state to this state, along
83+
// the path taken by following the 'prev' links
84+
// (the path would have `seq_len` arcs in it).
85+
// Note: by "this state" we mean the destination-state of
86+
// the arc at `arc_index`.
87+
// Interpret this with caution, because the
88+
// base state, and the length of the sequence arcs from the
89+
// base state to here, are known only in the DetState
90+
// that owns this DetStateElement.
9991

10092
std::shared_ptr<PathElement> path;
101-
// The path from the start state to here (actually we will
102-
// only follow back `seq_len` links). Will be nullptr if
103-
// seq_len == 0 and this belongs to the initial determinized
104-
// state.
105-
106-
DetStateElement &&Advance(float arc_weight, int32_t arc_index, int32_t arc_symbol) {
107-
return DetStateElement(weight + arc_weight,
108-
std::make_shared<PathElement>(path, arc_index, arc_symbol));
93+
// The path from the start state to here (actually we will
94+
// only follow back `seq_len` links). Will be nullptr if
95+
// seq_len == 0 and this belongs to the initial determinized
96+
// state.
97+
98+
DetStateElement &&Advance(float arc_weight, int32_t arc_index,
99+
int32_t arc_symbol) {
100+
return DetStateElement(
101+
weight + arc_weight,
102+
std::make_shared<PathElement>(path, arc_index, arc_symbol));
109103
}
110104

111-
DetStateElement(double weight, std::shared_ptr<PathElement> &&path):
112-
weight(weight), path(path) { }
113-
105+
DetStateElement(double weight, std::shared_ptr<PathElement> &&path)
106+
: weight(weight), path(path) {}
114107
};
115108

116109
class DetState;
117110

118-
119111
struct DetStateCompare {
120112
// Comparator for priority queue. Less-than operator that compares
121113
// forward_backward_weight for best-first processing.
122-
bool operator()(const shared_ptr<DetState> &a,
123-
const shared_ptr<DetState> &b);
114+
bool operator()(const shared_ptr<DetState> &a, const shared_ptr<DetState> &b);
124115
};
125116

126-
127-
128117
class Determinizer {
129118
public:
130119
private:
131-
132-
using DetStatePriorityQueue = priority_queue<shared_ptr<DetState>,
133-
vector<shared_ptr<DetState> >,
134-
DetStateCompare>;
135-
136-
120+
using DetStatePriorityQueue =
121+
priority_queue<shared_ptr<DetState>, vector<shared_ptr<DetState>>,
122+
DetStateCompare>;
137123
};
138124

139-
140125
/*
141126
Conceptually a determinized state in weighted FSA determinization would
142127
normally
@@ -171,7 +156,6 @@ class DetState {
171156
// those paths. When Normalize() is called we may advance
172157
int32_t base_state;
173158

174-
175159
// seq_len is the length of symbol sequence that we follow from state
176160
// `base_state`. The sequence of symbols can be found by tracing back one of
177161
// the DetStateElements in the doubly linked list (it doesn't matter which you
@@ -180,7 +164,6 @@ class DetState {
180164

181165
bool normalized{false};
182166

183-
184167
std::list<DetStateElement> elements;
185168

186169
// This is the weight on the best path that includes this determinized state.
@@ -190,14 +173,12 @@ class DetState {
190173
// the backward-weight of the state associated with that DetStateElement).
191174
double forward_backward_weight;
192175

193-
194176
/*
195-
Process arcs leaving this determinized state, possibly creating new determinized
196-
states in the process.
197-
@param [in] wfsa_in The input FSA that we are determinizing, along
198-
with forward-backward weights.
199-
The input FSA should normally be epsilon-free as
200-
epsilons are treated as a normal symbol; and require
177+
Process arcs leaving this determinized state, possibly creating new
178+
determinized states in the process.
179+
@param [in] wfsa_in The input FSA that we are determinizing,
180+
along with forward-backward weights. The input FSA should normally be
181+
epsilon-free as epsilons are treated as a normal symbol; and require
201182
wfsa_in.weight_type == kMaxWeight, for
202183
now (might later create a version of this code
203184
that works
@@ -210,13 +191,10 @@ class DetState {
210191
211192
212193
*/
213-
void ProcessArcs(const WfsaWithFbWeights &wfsa_in,
214-
Fsa *wfsa_out,
215-
float prune_cutoff,
216-
DetStateMap *state_map,
194+
void ProcessArcs(const WfsaWithFbWeights &wfsa_in, Fsa *wfsa_out,
195+
float prune_cutoff, DetStateMap *state_map,
217196
DetStatePriorityQueue *queue);
218197

219-
220198
/*
221199
Normalizes this DetState and sets forward_backward_weight.
222200
@@ -259,18 +237,17 @@ class DetState {
259237
plus the backward weight in the input FSA of the state that corresponds
260238
to it).
261239
*/
262-
void Normalize(const Fsa &input_fsa,
263-
const float *input_fsa_weights,
264-
float *removed_weight,
265-
std::vector<int32_t> *leftover_arcs) {
240+
void Normalize(const Fsa &input_fsa, const float *input_fsa_weights,
241+
float *removed_weight, std::vector<int32_t> *leftover_arcs) {
266242
#ifndef NDEBUG
267243
CheckElementOrder();
268244
#endif
269245
RemoveDuplicatesOfStates(input_fsa);
270-
RemoveCommonPrefix(input_fsa, input_fsa_weights, removed_weight, leftover_arcs);
246+
RemoveCommonPrefix(input_fsa, input_fsa_weights, removed_weight,
247+
leftover_arcs);
271248
}
272-
private:
273249

250+
private:
274251
/*
275252
Called from Normalize(), this function removes duplicates in
276253
`elements`: that is, if two elements represent paths that terminate at
@@ -287,8 +264,7 @@ class DetState {
287264
weights in `elements`, and set `input_arcs` to the sequence of arcs that
288265
were removed from
289266
*/
290-
RemoveCommonPrefix(const Fsa &input_fsa,
291-
const float *input_fsa_weights,
267+
RemoveCommonPrefix(const Fsa &input_fsa, const float *input_fsa_weights,
292268
std::vector<int32_t> *input_arcs);
293269
/*
294270
This function just does some checking on the `elements` list that
@@ -298,29 +274,24 @@ class DetState {
298274
they are all the same.
299275
*/
300276
void CheckElementOrder() const;
301-
302277
};
303278

304279
bool DetStateCompare::operator()(const shared_ptr<DetState> &a,
305280
const shared_ptr<DetState> &b) {
306281
return a->forward_backward_weight < b->forward_backward_weight;
307282
}
308283

309-
310-
311284
void DetState::RemoveDuplicatesOfStates(const Fsa &input_fsa) {
312-
313285
/*
314286
`state_to_elem` maps from int32_t state-id to the DetStateElement
315287
associated with it (there can be only one, we choose the one with
316288
the best weight).
317289
*/
318-
std::unordered_map<int32_t, typename std::list<DetStateElement>::iterator> state_to_elem;
319-
320-
290+
std::unordered_map<int32_t, typename std::list<DetStateElement>::iterator>
291+
state_to_elem;
321292

322293
for (auto iter = elements.begin(); iter != elements.end(); ++iter) {
323-
int32_t state = input_fsa.arcs[elem.arc_index].nextstate;
294+
int32_t state = input_fsa.arcs[elem.arc_index].nextstate;
324295
auto p = state_to_elem.insert({state, elem});
325296
bool inserted = p.second;
326297
if (!inserted) {
@@ -339,11 +310,9 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
339310
const float *input_fsa_weights,
340311
float *removed_weight_out,
341312
std::vector<int32_t> *input_arcs) {
342-
343313
CHECK_GE(seq_len, 0);
344314
int32_t len;
345-
auto first_path = elements.front().path,
346-
last_path = elements.back().path;
315+
auto first_path = elements.front().path, last_path = elements.back().path;
347316

348317
for (len = 1; len < seq_len; ++len) {
349318
first_path = first_path->prev;
@@ -359,8 +328,7 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
359328
/* We reach a common state after traversing fewer than `seq_len` arcs,
360329
so we can remove a shared prefix. */
361330
double removed_weight = 0.0;
362-
int32_t new_seq_len = len,
363-
removed_seq_len = seq_len - len;
331+
int32_t new_seq_len = len, removed_seq_len = seq_len - len;
364332
input_arcs->resize(removed_seq_len);
365333
// Advance base_state
366334
int32_t new_base_state = input_fsa.arcs[first_path->arc_index].src_state;
@@ -375,7 +343,7 @@ void DetState::RemoveCommonPrefix(const Fsa &input_fsa,
375343
fsa.arcs[first_path->arc_index].dest_state == this->base_state);
376344
this->base_state = new_base_state;
377345
if (removed_weight != 0) {
378-
for (DetStateElement &det_state_elem: elements) {
346+
for (DetStateElement &det_state_elem : elements) {
379347
det_state_elem.weight -= removed_weight;
380348
}
381349
}
@@ -393,8 +361,8 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const {
393361
// leaving each state in the FSA are sorted first on label and then on
394362
// dest_state.
395363
if (seq_len == 0) {
396-
CHECK(elements.size() == 1);
397-
CHECK(elements.front().weight == 0.0);
364+
CHECK_EQ(elements.size(), 1);
365+
CHECK_EQ(elements.front().weight, 0.0);
398366
}
399367

400368
std::vector<int32> prev_seq;
@@ -413,14 +381,12 @@ void DetState::CheckElementOrder(const Fsa &input_fsa) const {
413381
}
414382
}
415383

416-
417384
/*
418385
This class maps from determinized states (DetState) to integer state-id
419386
in the determinized output.
420387
*/
421388
class DetStateMap {
422389
public:
423-
424390
/*
425391
Outputs the output state-id corresponding to a specific DetState structure.
426392
This does not store any pointers to the DetState or its contents, so
@@ -454,7 +420,8 @@ class DetStateMap {
454420
private:
455421
int32_t cur_output_state_{0};
456422
std::unordered_map<std::pair<uint64_t, uint64_t>, int32_t,
457-
DetStateVectorHasher> map_;
423+
DetStateVectorHasher>
424+
map_;
458425

459426
/* Turns DetState into a compact form of 128 bits. Technically there
460427
could be collisions, which would be fatal for the algorithm, but this
@@ -493,7 +460,7 @@ class DetStateMap {
493460
};
494461

495462
void DeterminizeMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
496-
std::vector<std::vector<int32_t> > *arc_map) {
463+
std::vector<std::vector<int32_t>> *arc_map) {
497464
// TODO(dpovey): use glog stuff.
498465
assert(IsValid(a) && IsEpsilonFree(a) && IsTopSortedAndAcyclic(a));
499466
if (a.arc_indexes.empty()) {

0 commit comments

Comments
 (0)