Skip to content

Commit 1476834

Browse files
committed
replace FSA with Array2 in RmEpsilons
1 parent 096fbbe commit 1476834

14 files changed

+642
-480
lines changed

k2/csrc/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_library(fsa
88
fsa_renderer.cc
99
fsa_util.cc
1010
properties.cc
11+
rmepsilon.cc
1112
util.cc
1213
weights.cc
1314
)
@@ -46,6 +47,7 @@ set(fsa_tests
4647
fsa_test
4748
fsa_util_test
4849
properties_test
50+
rmepsilon_test
4951
weights_test
5052
)
5153

k2/csrc/array.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ struct Array2 {
9494
using PtrT = Ptr;
9595
using ValueType = typename std::iterator_traits<Ptr>::value_type;
9696

97-
Array2() : size1(0), indexes(&size1), size2(0), data(nullptr) {}
98-
Array2(IndexT size1, IndexT *indexes, IndexT size2, PtrT data)
99-
: size1(size1), indexes(indexes), size2(size2), data(data) {}
100-
void Init(IndexT size1, IndexT *indexes, IndexT size2, PtrT data) {
97+
Array2() : size1(0), size2(0), indexes(&size1), data(nullptr) {}
98+
Array2(IndexT size1, IndexT size2, IndexT *indexes, PtrT data)
99+
: size1(size1), size2(size2), indexes(indexes), data(data) {}
100+
void Init(IndexT size1, IndexT size2, IndexT *indexes, PtrT data) {
101101
this->size1 = size1;
102-
this->indexes = indexes;
103102
this->size2 = size2;
103+
this->indexes = indexes;
104104
this->data = data;
105105
}
106106

k2/csrc/aux_labels.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ void FstInverter::GetOutput(Fsa *fsa_out, AuxLabels *labels_out) {
245245

246246
std::vector<int32_t> arc_map;
247247
ReorderArcs(arcs, fsa_out, &arc_map);
248-
AuxLabels labels_tmp(labels_out->size1, start_pos.data(), labels_out->size2,
248+
AuxLabels labels_tmp(labels_out->size1, labels_out->size2, start_pos.data(),
249249
labels.data());
250250
AuxLabels1Mapper aux_mapper(labels_tmp, arc_map);
251251
// don't need to call `GetSizes` here as `labels_out` has been initialized

k2/csrc/aux_labels_test.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ TEST(AuxLabels, InvertFst) {
134134
std::vector<int32_t> start_pos = {0, 1, 3, 6, 7};
135135
std::vector<int32_t> labels = {1, 2, 3, 4, 5, 6, 7};
136136
AuxLabels labels_in(static_cast<int32_t>(start_pos.size()) - 1,
137-
start_pos.data(), static_cast<int32_t>(labels.size()),
137+
static_cast<int32_t>(labels.size()), start_pos.data(),
138138
labels.data());
139139

140140
FstInverter fst_inverter(fsa_in, labels_in);
@@ -162,7 +162,7 @@ TEST(AuxLabels, InvertFst) {
162162
EXPECT_EQ(start_pos.size(), fsa_in.size2 + 1);
163163
std::vector<int32_t> labels = {1, 2, 3, 5, 6, 7, -1, -1, -1};
164164
AuxLabels labels_in(static_cast<int32_t>(start_pos.size()) - 1,
165-
start_pos.data(), static_cast<int32_t>(labels.size()),
165+
static_cast<int32_t>(labels.size()), start_pos.data(),
166166
labels.data());
167167

168168
FstInverter fst_inverter(fsa_in, labels_in);
@@ -212,7 +212,7 @@ TEST(AuxLabels, InvertFst) {
212212
EXPECT_EQ(start_pos.size(), fsa_in.size2 + 1);
213213
std::vector<int32_t> labels = {1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1};
214214
AuxLabels labels_in(static_cast<int32_t>(start_pos.size()) - 1,
215-
start_pos.data(), static_cast<int32_t>(labels.size()),
215+
static_cast<int32_t>(labels.size()), start_pos.data(),
216216
labels.data());
217217

218218
FstInverter fst_inverter(fsa_in, labels_in);

k2/csrc/determinize.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,11 @@ struct MaxTracebackState {
172172
// sequence of symbols we took to get here)
173173

174174
// This constructor is for the start-state (state zero) of the input FSA.
175-
MaxTracebackState()
176-
: state_id(0), arc_id(-1), prev_state(nullptr), forward_prob(0.0) {}
175+
explicit MaxTracebackState(int32_t state_id = 0, double forward_prob = 0.0)
176+
: state_id(state_id),
177+
arc_id(-1),
178+
prev_state(nullptr),
179+
forward_prob(forward_prob) {}
177180

178181
/**
179182
@param [in] state_id State in input FSA that this corresponds to

k2/csrc/fsa_algo.cc

-281
Original file line numberDiff line numberDiff line change
@@ -52,78 +52,6 @@ inline int32_t InsertIntersectionState(
5252
return result.first->second;
5353
}
5454

55-
/**
56-
A TraceBack() function used in RmEpsilonsPrunedLogSum. It finds derivative
57-
information for all arcs in a sub-graph. Generally, in
58-
RmEpsilonsPrunedLogSum, we actually get a sub-graph when we find a
59-
non-epsilon arc starting from a particular state `s` (from which we are
60-
trying to remove epsilon arcs). All leaving arcs of all states in this
61-
sub-graph are epsilon arcs except the last one. Then, from the last state, we
62-
need to trace back to state `s` to find the derivative information for all
63-
epsilon arcs in this graph.
64-
@param [in] curr_states (This is consumed destructively, i.e. don't
65-
expect it to contain the same set on exit).
66-
A set of states, stored as a std::map that mapping
67-
state_id in input FSA to the corresponding
68-
LogSumTracebackState we created for this state;
69-
we'll iteratively trace back this set one element
70-
(processing all entering arcs) at a time. At entry
71-
it must have size() == 1 which contains the last
72-
state mentioned above; it will also have size() == 1
73-
at exit which contains the state `s` above.
74-
@param [in] arc_weights_in Weights on the arcs of the input FSA
75-
@param [out] deriv_out Some derivative information at the output
76-
will be written to here, which tells us how the weight
77-
of the non-epsilon arc we created from the above
78-
sub-graph varies as a function of the weights on the
79-
arcs of the input FSA; it's a list
80-
(input_arc_id, deriv) where, mathematically,
81-
0 < deriv <= 1 (but we might still get exact zeros
82-
due to limitations of floating point representation).
83-
*/
84-
static void TraceBackRmEpsilonsLogSum(
85-
std::map<int32_t, k2::LogSumTracebackState *> *curr_states,
86-
const float *arc_weights_in,
87-
std::vector<std::pair<int32_t, float>> *deriv_out) {
88-
CHECK_EQ(curr_states->size(), 1);
89-
deriv_out->clear();
90-
91-
// as the input fsa is top-sorted, we traverse states in a reverse order so we
92-
// can process them when they already have correct backward_prob (all leaving
93-
// arcs have been processed).
94-
k2::LogSumTracebackState *state_ptr = curr_states->rbegin()->second;
95-
// In the standard forward-backward algorithm for HMMs this backward_prob
96-
// would, mathematically, be 0.0, but if we set it to the negative of the
97-
// forward prob we can avoid having to subtract the total log-prob
98-
// when we compute posterior/occupation probabilities for arcs.
99-
state_ptr->backward_prob = -state_ptr->forward_prob;
100-
while (!state_ptr->prev_elements.empty()) {
101-
double backward_prob = state_ptr->backward_prob;
102-
for (const auto &link : state_ptr->prev_elements) {
103-
auto arc_log_posterior =
104-
static_cast<float>(link.forward_prob + backward_prob);
105-
deriv_out->emplace_back(link.arc_index, expf(arc_log_posterior));
106-
k2::LogSumTracebackState *prev_state = link.prev_state.get();
107-
double new_backward_prob = backward_prob + arc_weights_in[link.arc_index];
108-
auto result = curr_states->emplace(prev_state->state_id, prev_state);
109-
if (result.second) {
110-
prev_state->backward_prob = new_backward_prob;
111-
} else {
112-
prev_state->backward_prob =
113-
k2::LogAdd(new_backward_prob, prev_state->backward_prob);
114-
}
115-
}
116-
// we have processed all entering arcs of state curr_states->rbegin(),
117-
// we'll remove it now. As std::map.erase() does not support passing a
118-
// reverse iterator, we here pass --end();
119-
curr_states->erase(--curr_states->end());
120-
CHECK(!curr_states->empty());
121-
state_ptr = curr_states->rbegin()->second;
122-
}
123-
// we have reached the state from which we are trying to remove epsilon arcs.
124-
CHECK_EQ(curr_states->size(), 1);
125-
}
126-
12755
} // namespace
12856

12957
namespace k2 {
@@ -350,215 +278,6 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
350278
return is_acyclic;
351279
}
352280

353-
void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
354-
std::vector<std::vector<int32_t>> *arc_derivs) {
355-
CHECK_EQ(a.weight_type, kMaxWeight);
356-
CHECK_GT(beam, 0);
357-
CHECK_NOTNULL(b);
358-
CHECK_NOTNULL(arc_derivs);
359-
b->arc_indexes.clear();
360-
b->arcs.clear();
361-
arc_derivs->clear();
362-
363-
const auto &fsa = a.fsa;
364-
if (IsEmpty(fsa)) return;
365-
int32_t num_states_a = fsa.NumStates();
366-
int32_t final_state = fsa.FinalState();
367-
const auto &arcs_a = fsa.data;
368-
const float *arc_weights_a = a.arc_weights;
369-
370-
// identify all states that should be kept
371-
std::vector<char> non_eps_in(num_states_a, 0);
372-
non_eps_in[0] = 1;
373-
for (const auto &arc : fsa) {
374-
// We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
375-
// time.
376-
DCHECK_GE(arc.dest_state, arc.src_state);
377-
if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1;
378-
}
379-
380-
// remap state id
381-
std::vector<int32_t> state_map_a2b(num_states_a, -1);
382-
int32_t num_states_b = 0;
383-
for (int32_t i = 0; i != num_states_a; ++i) {
384-
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++;
385-
}
386-
b->arc_indexes.reserve(num_states_b + 1);
387-
int32_t arc_num_b = 0;
388-
389-
const double *forward_state_weights = a.ForwardStateWeights();
390-
const double *backward_state_weights = a.BackwardStateWeights();
391-
const double best_weight = forward_state_weights[final_state] - beam;
392-
for (int32_t i = 0; i != num_states_a; ++i) {
393-
if (non_eps_in[i] != 1) continue;
394-
b->arc_indexes.push_back(arc_num_b);
395-
int32_t curr_state_b = state_map_a2b[i];
396-
// as the input FSA is top-sorted, we use a map here so we can process
397-
// states when they already have the best cost they are going to get
398-
std::map<int32_t, double>
399-
local_forward_weights; // state -> local_forward_state_weights of this
400-
// state
401-
// state -> (src_state, arc_index) entering this state which contributes to
402-
// `local_forward_weights` of this state.
403-
std::unordered_map<int32_t, std::pair<int32_t, int32_t>>
404-
local_backward_arcs;
405-
local_forward_weights.emplace(i, forward_state_weights[i]);
406-
// `-1` means we have traced back to current state `i`
407-
local_backward_arcs.emplace(i, std::make_pair(i, -1));
408-
while (!local_forward_weights.empty()) {
409-
std::pair<int32_t, double> curr_local_forward_weights =
410-
*(local_forward_weights.begin());
411-
local_forward_weights.erase(local_forward_weights.begin());
412-
int32_t state = curr_local_forward_weights.first;
413-
414-
int32_t arc_end = fsa.indexes[state + 1];
415-
for (int32_t arc_index = fsa.indexes[state]; arc_index != arc_end;
416-
++arc_index) {
417-
int32_t next_state = arcs_a[arc_index].dest_state;
418-
int32_t label = arcs_a[arc_index].label;
419-
double next_weight =
420-
curr_local_forward_weights.second + arc_weights_a[arc_index];
421-
if (next_weight + backward_state_weights[next_state] >= best_weight) {
422-
if (label == kEpsilon) {
423-
auto result =
424-
local_forward_weights.emplace(next_state, next_weight);
425-
if (result.second) {
426-
local_backward_arcs[next_state] =
427-
std::make_pair(state, arc_index);
428-
} else {
429-
if (next_weight > result.first->second) {
430-
result.first->second = next_weight;
431-
local_backward_arcs[next_state] =
432-
std::make_pair(state, arc_index);
433-
}
434-
}
435-
} else {
436-
b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state],
437-
label);
438-
std::vector<int32_t> curr_arc_deriv;
439-
std::pair<int32_t, int32_t> curr_backward_arc{state, arc_index};
440-
auto *backward_arc = &curr_backward_arc;
441-
while (backward_arc->second != -1) {
442-
curr_arc_deriv.push_back(backward_arc->second);
443-
backward_arc = &(local_backward_arcs[backward_arc->first]);
444-
}
445-
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end());
446-
arc_derivs->emplace_back(std::move(curr_arc_deriv));
447-
++arc_num_b;
448-
}
449-
}
450-
}
451-
}
452-
}
453-
// duplicate of final state
454-
b->arc_indexes.push_back(b->arc_indexes.back());
455-
}
456-
457-
void RmEpsilonsPrunedLogSum(
458-
const WfsaWithFbWeights &a, float beam, Fsa *b,
459-
std::vector<float> *b_arc_weights,
460-
std::vector<std::vector<std::pair<int32_t, float>>> *arc_derivs) {
461-
CHECK_GT(beam, 0);
462-
CHECK_NOTNULL(b);
463-
CHECK_NOTNULL(b_arc_weights);
464-
CHECK_NOTNULL(arc_derivs);
465-
b->arc_indexes.clear();
466-
b->arcs.clear();
467-
b_arc_weights->clear();
468-
arc_derivs->clear();
469-
470-
const auto &fsa = a.fsa;
471-
if (IsEmpty(fsa)) return;
472-
int32_t num_states_a = fsa.NumStates();
473-
int32_t final_state = fsa.FinalState();
474-
const auto &arcs_a = fsa.data;
475-
const float *arc_weights_a = a.arc_weights;
476-
477-
// identify all states that should be kept
478-
std::vector<char> non_eps_in(num_states_a, 0);
479-
non_eps_in[0] = 1;
480-
for (const auto &arc : fsa) {
481-
// We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
482-
// time.
483-
DCHECK_GE(arc.dest_state, arc.src_state);
484-
if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1;
485-
}
486-
487-
// remap state id
488-
std::vector<int32_t> state_map_a2b(num_states_a, -1);
489-
int32_t num_states_b = 0;
490-
for (int32_t i = 0; i != num_states_a; ++i) {
491-
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++;
492-
}
493-
b->arc_indexes.reserve(num_states_b + 1);
494-
int32_t arc_num_b = 0;
495-
496-
const double *forward_state_weights = a.ForwardStateWeights();
497-
const double *backward_state_weights = a.BackwardStateWeights();
498-
const double best_weight = forward_state_weights[final_state] - beam;
499-
for (int32_t i = 0; i != num_states_a; ++i) {
500-
if (non_eps_in[i] != 1) continue;
501-
b->arc_indexes.push_back(arc_num_b);
502-
int32_t curr_state_b = state_map_a2b[i];
503-
// as the input FSA is top-sorted, we use a set here so we can process
504-
// states when they already have costs over all paths they are going to get
505-
std::set<int32_t> qstates;
506-
std::unordered_map<int32_t, std::shared_ptr<LogSumTracebackState>>
507-
traceback_states; // state -> LogSumTracebackState of this state
508-
std::shared_ptr<LogSumTracebackState> start_state(
509-
new LogSumTracebackState(i, forward_state_weights[i]));
510-
double start_forward_weights = start_state->forward_prob;
511-
traceback_states.emplace(i, start_state);
512-
qstates.insert(i);
513-
while (!qstates.empty()) {
514-
int32_t state = *(qstates.begin());
515-
qstates.erase(qstates.begin());
516-
517-
const auto &curr_traceback_state = traceback_states[state];
518-
double curr_forward_weights = curr_traceback_state->forward_prob;
519-
int32_t arc_end = fsa.indexes[state + 1];
520-
for (int32_t arc_index = fsa.indexes[state]; arc_index != arc_end;
521-
++arc_index) {
522-
int32_t next_state = arcs_a[arc_index].dest_state;
523-
int32_t label = arcs_a[arc_index].label;
524-
float curr_arc_weight = arc_weights_a[arc_index];
525-
double next_weight = curr_forward_weights + curr_arc_weight;
526-
if (next_weight + backward_state_weights[next_state] >= best_weight) {
527-
if (label == kEpsilon) {
528-
auto result = traceback_states.emplace(next_state, nullptr);
529-
if (result.second) {
530-
result.first->second = std::make_shared<LogSumTracebackState>(
531-
next_state, curr_traceback_state, arc_index, curr_arc_weight);
532-
qstates.insert(next_state);
533-
} else {
534-
result.first->second->Accept(curr_traceback_state, arc_index,
535-
curr_arc_weight);
536-
}
537-
} else {
538-
b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state],
539-
label);
540-
b_arc_weights->push_back(curr_forward_weights + curr_arc_weight -
541-
start_forward_weights);
542-
543-
std::vector<std::pair<int32_t, float>> curr_arc_deriv;
544-
std::map<int32_t, LogSumTracebackState *> curr_states;
545-
curr_states.emplace(state, curr_traceback_state.get());
546-
TraceBackRmEpsilonsLogSum(&curr_states, arc_weights_a,
547-
&curr_arc_deriv);
548-
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end());
549-
// push derivs info of current arc
550-
curr_arc_deriv.emplace_back(arc_index, 1);
551-
arc_derivs->emplace_back(std::move(curr_arc_deriv));
552-
++arc_num_b;
553-
}
554-
}
555-
}
556-
}
557-
}
558-
// duplicate of final state
559-
b->arc_indexes.push_back(b->arc_indexes.back());
560-
}
561-
562281
bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
563282
std::vector<int32_t> *arc_map_a /*= nullptr*/,
564283
std::vector<int32_t> *arc_map_b /*= nullptr*/) {

0 commit comments

Comments
 (0)