Skip to content

Commit 5ca755c

Browse files
authored
Merge pull request #8 from danpovey/rm_epsilons_fix
Rework RmEpsilons
2 parents b1f7702 + edc8d42 commit 5ca755c

File tree

4 files changed

+104
-40
lines changed

4 files changed

+104
-40
lines changed

k2/csrc/fsa.h

-10
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,6 @@ struct DenseFsa {
112112
DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride);
113113
};
114114

115-
/*
116-
this general-purpose structure conceptually the same as
117-
std::vector<std::vector>; elements of `ranges` are (begin, end) indexes into
118-
`values`.
119-
*/
120-
struct VecOfVec {
121-
std::vector<Range> ranges;
122-
std::vector<std::pair<Label, StateId>> values;
123-
};
124-
125115
struct Fst {
126116
Fsa core;
127117
std::vector<int32_t> aux_label;

k2/csrc/fsa_algo.h

+81-25
Original file line numberDiff line numberDiff line change
@@ -53,36 +53,82 @@ void ConnectCore(const Fsa &fsa, std::vector<int32_t> *state_map);
5353
void Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);
5454

5555
/**
56-
Output an Fsa that is equivalent to the input but which has no epsilons.
57-
58-
@param [in] a The input FSA
56+
Output an Fsa that is equivalent to the input (in the tropical semiring,
57+
which here means taking the max of the weights along paths) but which has no
58+
epsilons. The input needs to have associated weights, because they will be
59+
used to choose the best among alternative epsilon paths between states.
60+
61+
@param [in] a The input, with weights and forward-backward weights
62+
as required by this computation. For now we assume
63+
that `a` is topologically sorted, as required by
64+
the current constructor of WfsaWithFbWeights.
65+
@param [in] beam beam > 0 that affects pruning; this algorithm will
66+
keep paths that are within `beam` of the best path.
67+
Just make this very large if you don't want pruning.
5968
@param [out] b The output FSA; will be epsilon-free, and the states
6069
will be in the same order that they were in in `a`.
6170
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
62-
the arc-indexes in `a` that contributed to that arc
63-
(e.g. its cost would be a sum of their costs).
64-
TODO(Dan): make it a VecOfVec, maybe?
71+
the arc-indexes in `a`, in order, that contributed
72+
to that arc (e.g. its cost would be a sum of their costs).
73+
74+
Notes on algorithm (please rework all this when it's complete, i.e. just
75+
make sure the code is clear and remove this).
76+
77+
The states in the output FSA will correspond to the subset of states in the
78+
input FSA which are within `beam` of the best path and which have at least
79+
one non-epsilon arc entering them, plus the start state. (Note: this
80+
automatically includes the final state, assuming `a` has at least one
81+
successful path; if it does not, the output will be empty).
82+
83+
If we ever need the associated state map from calling code, we'll add an
84+
extra output argument to this function.
85+
86+
The basic algorithm is to (1) identify the kept states, (2) from each kept
87+
input-state ki, we'll iterate over all states that are reachable via zero or more
88+
epsilons from this state and process the non-epsilon outgoing arcs from
89+
those states, which will become the arcs in the output. We'll also store a
90+
back-pointer array that will allow us to figure out the best path back to ki,
91+
in order to produce the output `arc_map`. Assume we have arrays
92+
93+
local_forward_weights (float) and local_backpointers (int) indexed by
94+
state-id, and that the local_forward_weights are initialized with
95+
-infinity's each time we process a new ki. (we have to figure out how to do this
96+
efficiently).
97+
98+
99+
Processing input-state ki:
100+
local_forward_state_weights[ki] = forward_state_weights[ki] // from WfsaWithFbWeights.
101+
// Caution: we should probably use
102+
// double here; these kinds of algorithms
103+
// are extremely sensitive to roundoff for
104+
// very long FSAs.
105+
local_backpointers[ki] = -1 // will terminate a sequence..
106+
queue.push_back(ki)
107+
while (!queue.empty()) {
108+
ji = queue.front() // we have to be a bit careful about order here, to make sure
109+
// we always process states when they already have the
110+
// best cost they are going to get. If
111+
// FSA was top-sorted at the start, which we assume, we could perhaps
112+
// process them in numerical order, e.g. using a heap.
113+
queue.pop_front()
114+
for each arc leaving state ji:
115+
next_weight = local_forward_state_weights[ji] + arc_weights[this_arc_index]
116+
if next_weight + backward_state_weights[arc_dest_state] < best_path_weight - beam:
117+
if arc label is epsilon:
118+
if next_weight < local_forward_state_weight[next_state]:
119+
local_forward_state_weight[next_state] = next_weight
120+
local_backpointers[next_state] = ji
121+
else:
122+
add an arc to the output FSA, and create the appropriate
123+
arc_map entry by following backpointers (hopefully you can figure out the
124+
details). Note: the output FSA's weights can be computed later on,
125+
by calling code, using the info in arc_map.
65126
*/
66-
void RmEpsilons(const Fsa &a, Fsa *b,
67-
std::vector<std::vector> *arc_map = nullptr);
68-
69-
/**
70-
Pruned version of RmEpsilons, which also uses a pruning beam.
71-
72-
Output an Fsa that is equivalent to the input but which has no epsilons.
127+
void RmEpsilonsPruned(const WfsaWithFbWeights &a,
128+
float beam,
129+
Fsa *b,
130+
std::vector<std::vector> *arc_map);
73131

74-
@param [in] a The input FSA
75-
@param [out] b The output FSA; will be epsilon-free, and the states
76-
will be in the same order that they were in in `a`.
77-
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
78-
the arc-indexes in `a` that contributed to that arc
79-
(e.g. its cost would be a sum of their costs).
80-
TODO(Dan): make it a VecOfVec, maybe?
81-
*/
82-
void RmEpsilonsPruned(const Fsa &a, const float *a_state_forward_costs,
83-
const float *a_state_backward_costs,
84-
const float *a_arc_costs, float cutoff, Fsa *b,
85-
std::vector<std::vector> *arc_map = nullptr);
86132

87133
/*
88134
Compute the intersection of two FSAs; this is the equivalent of composition
@@ -160,6 +206,16 @@ void IntersectPruned2(const Fsa &a, const float *a_cost, const Fsa &b,
160206
void RandomPath(const Fsa &a, const float *a_cost, Fsa *b,
161207
std::vector<int32_t> *state_map = nullptr);
162208

209+
210+
211+
/**
212+
213+
*/
214+
void Determinize(const Fsa &a, Fsa *b,
215+
std::vector<std::vector<StateId> > *state_map);
216+
217+
218+
163219
} // namespace k2
164220

165221
#endif // K2_CSRC_FSA_ALGO_H_

k2/csrc/fsa_util.h

+16-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,23 @@ namespace k2 {
1515
Computes lists of arcs entering each state (needed for algorithms that
1616
traverse the Fsa in reverse order).
1717
18-
Requires that `fsa` be valid and top-sorted, i.e.
19-
CheckProperties(fsa, KTopSorted) == true.
18+
Requires that `fsa` be valid and top-sorted, i.e. CheckProperties(fsa,
19+
KTopSorted) == true.
20+
21+
@param [out] arc_index A list of arc indexes.
22+
For states 0 < s < fsa.NumStates(),
23+
the elements arc_index[i] for end_index[s-1] <= i < end_index[s]
24+
contain the arc-indexes in fsa.arcs for arcs that
25+
enter state s.
26+
@param [out] end_index For each state, the `end` index in `arc_index`
27+
where we can find arcs entering this state, i.e.
28+
one past the index of the last element in `arc_index`
29+
that points to an arc entering this state.
2030
*/
21-
void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs);
31+
void GetEnteringArcs(const Fsa &fsa,
32+
std::vector<int32_t> *arc_index,
33+
std::vector<int32_t> *end_index);
34+
2235

2336
} // namespace k2
2437

k2/csrc/weights.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,13 @@ enum { kMaxWeight, kLogSumWeight } FbWeightType;
6161
struct WfsaWithFbWeights {
6262
const Fsa *fsa;
6363
const float *arc_weights;
64-
const float *forward_state_weights;
65-
const float *backward_state_weights;
64+
// forward_state_weights are the sum of weights along the best path from the
65+
// start-state to each state. We use double because for long FSAs roundoff
66+
// effects can cause nasty errors in pruning.
67+
const double *forward_state_weights;
68+
// backward_state_weights are the sum of weights along the best path
69+
// from each state to the final state.
70+
const double *backward_state_weights;
6671

6772
/*
6873
Constructor.

0 commit comments

Comments
 (0)