From 0983f9dca13c5817a18ce03e04438b9fb0fefdbe Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Wed, 20 May 2020 18:20:46 +0800 Subject: [PATCH 1/3] add derivs_out --- k2/csrc/determinize.h | 2 + k2/csrc/fsa_algo.cc | 105 +++++++++++++++++++++++++++++++++++++++ k2/csrc/fsa_algo.h | 68 ++----------------------- k2/csrc/fsa_algo_test.cc | 48 ++++++++++++++++++ 4 files changed, 160 insertions(+), 63 deletions(-) diff --git a/k2/csrc/determinize.h b/k2/csrc/determinize.h index a2bbf678c..4a15e235d 100644 --- a/k2/csrc/determinize.h +++ b/k2/csrc/determinize.h @@ -609,6 +609,8 @@ int32_t DetState::ProcessArcs( derivs_per_arc->push_back(std::move(deriv_info)); if (is_new_state) queue->push(std::unique_ptr>(det_state)); + else + delete det_state; } else { delete det_state; } diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index f58963538..89a4d6b0d 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -275,6 +275,111 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map /*=nullptr*/) { return is_acyclic; } +void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector> *arc_derivs) { + CHECK_EQ(a.weight_type, kMaxWeight); + CHECK_GT(beam, 0); + CHECK_NOTNULL(b); + CHECK_NOTNULL(arc_derivs); + b->arc_indexes.clear(); + b->arcs.clear(); + arc_derivs->clear(); + + const auto &fsa = a.fsa; + if (IsEmpty(fsa)) return; + int32_t num_states_a = fsa.NumStates(); + int32_t final_state = fsa.FinalState(); + const auto &arcs_a = fsa.arcs; + const float *arc_weights_a = a.arc_weights; + + // identify all states that should be kept + std::vector non_eps_in(num_states_a, 0); + non_eps_in[0] = 1; + using WeightsPair = std::pair>; + // (label, dest_state) -> `sum` of weights along all paths from current state + // to `dest_state` with label == `label`(or plus numbers of epsilon), + // `vector` in `WeightsPair` records all arc-indexes in `a` that + // contributes the current arc in `b`. + using ArcMap = + std::unordered_map, WeightsPair, PairHash>; + std::vector arcs_b(num_states_a); + for (int32_t i = static_cast(arcs_a.size()) - 1; i >= 0; --i) { + const auto &arc = arcs_a[i]; + const auto src_state = arc.src_state; + const auto dest_state = arc.dest_state; + const auto label = arc.label; + DCHECK_GE(dest_state, src_state); + + double arc_weight = arc_weights_a[i]; + if (label != kEpsilon) { + non_eps_in[dest_state] = 1; + WeightsPair weights_pair = + std::make_pair(arc_weight, std::vector{i}); + auto insert_result = arcs_b[src_state].emplace( + std::make_pair(label, dest_state), weights_pair); + if (!insert_result.second) { + auto &old_weights_pair = insert_result.first->second; + // compare `arc_weights` + if (weights_pair.first > old_weights_pair.first) { + std::swap(old_weights_pair, weights_pair); + } + } + } else { + // remove epsilon arcs + for (const auto &item : arcs_b[dest_state]) { + auto weights_pair = item.second; // copy intended + // `times` the arc weights along path + weights_pair.first += arc_weight; + weights_pair.second.push_back(i); + auto insert_result = + arcs_b[src_state].emplace(item.first, weights_pair); + if (!insert_result.second) { + auto &old_weights_pair = insert_result.first->second; + // compare `arc_weights` + if (weights_pair.first > old_weights_pair.first) { + std::swap(old_weights_pair, weights_pair); + } + } + } + } + } + + // remap state id + std::vector state_map_a2b(num_states_a, -1); + int32_t num_states_b = 0; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++; + } + + // prune and output `b` + const double *forward_state_weights = a.ForwardStateWeights(); + const double *backward_state_weights = a.BackwardStateWeights(); + const double best_weight = forward_state_weights[final_state] - beam; + b->arc_indexes.reserve(num_states_b + 1); + int32_t arc_num_b = 0; + for (int32_t s = 0; s < num_states_a; ++s) { + if (non_eps_in[s] == 1) { + b->arc_indexes.push_back(arc_num_b); + for (const auto &arcs : arcs_b[s]) { + int32_t dest_state = arcs.first.second; + double weight = arcs.second.first; + if (forward_state_weights[s] + weight + + backward_state_weights[dest_state] > + best_weight) { + b->arcs.emplace_back(state_map_a2b[s], state_map_a2b[dest_state], + arcs.first.first); + auto curr_arc_deriv = std::move(arcs.second.second); + std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); + arc_derivs->emplace_back(curr_arc_deriv); + ++arc_num_b; + } + } + } + } + // duplicate of final state + b->arc_indexes.push_back(b->arc_indexes.back()); +} + bool Intersect(const Fsa &a, const Fsa &b, Fsa *c, std::vector *arc_map_a /*= nullptr*/, std::vector *arc_map_b /*= nullptr*/) { diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index f7f777242..a5e3191b4 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -91,71 +91,13 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); Just make this very large if you don't want pruning. @param [out] b The output FSA; will be epsilon-free, and the states will be in the same order that they were in in `a`. - @param [out] arc_map If non-NULL: for each arc in `b`, a list of - the arc-indexes in `a`, in order, that contributed - to that arc (e.g. its cost would be a sum of their costs). - - Notes on algorithm (please rework all this when it's complete, i.e. just - make sure the code is clear and remove this). - - The states in the output FSA will correspond to the subset of states in the - input FSA which are within `beam` of the best path and which have at least - one non-epsilon arc entering them, plus the start state. (Note: this - automatically includes the final state, assuming `a` has at least one - successful path; if it does not, the output will be empty). - - If we ever need the associated state map from calling code, we'll add an - extra output argument to this function. - - The basic algorithm is to (1) identify the kept states, (2) from each kept - input-state ki, we'll iterate over all states that are reachable via zero - or more epsilons from this state and process the non-epsilon outgoing arcs - from those states, which will become the arcs in the output. We'll also - store a back-pointer array that will allow us to figure out the best path - back to ki, in order to produce the output `arc_map`. Assume we have - arrays - - local_forward_weights (float) and local_backpointers (int) indexed by - state-id, and that the local_forward_weights are initialized with - -infinity's each time we process a new ki. (we have to figure out how to do - this efficiently). - - - Processing input-state ki: - local_forward_state_weights[ki] = forward_state_weights[ki] // from - WfsaWithFbWeights. - // Caution: - we should probably use - // double - here; these kinds of algorithms - // are - extremely sensitive to roundoff for - // very - long FSAs. local_backpointers[ki] = -1 // will terminate a sequence.. - queue.push_back(ki) - while (!queue.empty()) { - ji = queue.front() // we have to be a bit careful about order here, - to make sure - // we always process states when they already - have the - // best cost they are going to get. If - // FSA was top-sorted at the start, which we - assume, we could perhaps - // process them in numerical order, e.g. using a - heap. queue.pop_front() for each arc leaving state ji: next_weight = - local_forward_state_weights[ji] + arc_weights[this_arc_index] if next_weight - + backward_state_weights[arc_dest_state] < best_path_weight - beam: if arc - label is epsilon: if next_weight < local_forward_state_weight[next_state]: - local_forward_state_weight[next_state] = next_weight - local_backpointers[next_state] = ji - else: - add an arc to the output FSA, and create the appropriate - arc_map entry by following backpointers (hopefully you can - figure out the details). Note: the output FSA's weights can - be computed later on, by calling code, using the info in arc_map. + @param [out] arc_derivs Indexed by arc in `b`, this is the sequence of + arcs in `a` that this arc in `b` corresponds to; the + weight of the arc in b will equal the sum of those input + arcs' weights */ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, - std::vector> *arc_map); + std::vector> *arc_derivs); /* Version of RmEpsilonsPrunedMax that doesn't support pruning; see its diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index ff1e2fbc0..81e865553 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -283,6 +283,54 @@ TEST(FsaAlgo, Connect) { } } +class RmEpsilonTest : public ::testing::Test { + protected: + RmEpsilonTest() { + std::vector arcs = { + {0, 4, 1}, {0, 1, 1}, {1, 2, 0}, {1, 3, 0}, {1, 4, 0}, + {2, 7, 0}, {3, 7, 0}, {4, 6, 1}, {4, 6, 0}, {4, 8, 1}, + {4, 9, -1}, {5, 9, -1}, {6, 9, -1}, {7, 9, -1}, {8, 9, -1}, + }; + fsa_ = new Fsa(std::move(arcs), 9); + num_states_ = fsa_->NumStates(); + + auto num_arcs = fsa_->arcs.size(); + arc_weights_ = new float[num_arcs]; + std::vector weights = {1, 1, 2, 3, 2, 4, 5, 2, 3, 3, 2, 4, 3, 5, 6}; + std::copy_n(weights.begin(), num_arcs, arc_weights_); + + max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight); + log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight); + } + + ~RmEpsilonTest() { + delete fsa_; + delete[] arc_weights_; + delete max_wfsa_; + delete log_wfsa_; + } + + WfsaWithFbWeights *max_wfsa_; + WfsaWithFbWeights *log_wfsa_; + Fsa *fsa_; + int32_t num_states_; + float *arc_weights_; +}; + +TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) { + Fsa b; + std::vector> arc_derivs_b; + RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b); + + IsEpsilonFree(b); + ASSERT_EQ(b.arcs.size(), 10); + ASSERT_EQ(b.arc_indexes.size(), 7); + ASSERT_EQ(arc_derivs_b.size(), 10); + + // TODO(haowen): check the equivalence after implementing RandEquivalent for + // WFSA +} + TEST(FsaAlgo, Intersect) { // empty fsa { From d9d1e0a21e0ee88e4d3eba0678e158ba430cbad2 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Mon, 25 May 2020 20:36:04 +0800 Subject: [PATCH 2/3] implement with traceback-based algorithm --- k2/csrc/fsa_algo.cc | 136 ++++++++++++++++++++------------------- k2/csrc/fsa_algo.h | 2 +- k2/csrc/fsa_algo_test.cc | 6 +- 3 files changed, 75 insertions(+), 69 deletions(-) diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index 89a4d6b0d..6e3598085 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -8,11 +8,13 @@ #include "k2/csrc/fsa_algo.h" #include +#include #include #include #include #include #include +#include #include #include "glog/logging.h" @@ -295,53 +297,11 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, // identify all states that should be kept std::vector non_eps_in(num_states_a, 0); non_eps_in[0] = 1; - using WeightsPair = std::pair>; - // (label, dest_state) -> `sum` of weights along all paths from current state - // to `dest_state` with label == `label`(or plus numbers of epsilon), - // `vector` in `WeightsPair` records all arc-indexes in `a` that - // contributes the current arc in `b`. - using ArcMap = - std::unordered_map, WeightsPair, PairHash>; - std::vector arcs_b(num_states_a); - for (int32_t i = static_cast(arcs_a.size()) - 1; i >= 0; --i) { - const auto &arc = arcs_a[i]; - const auto src_state = arc.src_state; - const auto dest_state = arc.dest_state; - const auto label = arc.label; - DCHECK_GE(dest_state, src_state); - - double arc_weight = arc_weights_a[i]; - if (label != kEpsilon) { - non_eps_in[dest_state] = 1; - WeightsPair weights_pair = - std::make_pair(arc_weight, std::vector{i}); - auto insert_result = arcs_b[src_state].emplace( - std::make_pair(label, dest_state), weights_pair); - if (!insert_result.second) { - auto &old_weights_pair = insert_result.first->second; - // compare `arc_weights` - if (weights_pair.first > old_weights_pair.first) { - std::swap(old_weights_pair, weights_pair); - } - } - } else { - // remove epsilon arcs - for (const auto &item : arcs_b[dest_state]) { - auto weights_pair = item.second; // copy intended - // `times` the arc weights along path - weights_pair.first += arc_weight; - weights_pair.second.push_back(i); - auto insert_result = - arcs_b[src_state].emplace(item.first, weights_pair); - if (!insert_result.second) { - auto &old_weights_pair = insert_result.first->second; - // compare `arc_weights` - if (weights_pair.first > old_weights_pair.first) { - std::swap(old_weights_pair, weights_pair); - } - } - } - } + for (const auto &arc : arcs_a) { + // We suppose the input fsa `a` is top-sorted, but only check this in DEBUG + // time. + DCHECK_GE(arc.dest_state, arc.src_state); + if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1; } // remap state id @@ -350,28 +310,74 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, for (int32_t i = 0; i != num_states_a; ++i) { if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++; } + b->arc_indexes.reserve(num_states_b + 1); + int32_t arc_num_b = 0; - // prune and output `b` const double *forward_state_weights = a.ForwardStateWeights(); const double *backward_state_weights = a.BackwardStateWeights(); const double best_weight = forward_state_weights[final_state] - beam; - b->arc_indexes.reserve(num_states_b + 1); - int32_t arc_num_b = 0; - for (int32_t s = 0; s < num_states_a; ++s) { - if (non_eps_in[s] == 1) { - b->arc_indexes.push_back(arc_num_b); - for (const auto &arcs : arcs_b[s]) { - int32_t dest_state = arcs.first.second; - double weight = arcs.second.first; - if (forward_state_weights[s] + weight + - backward_state_weights[dest_state] > - best_weight) { - b->arcs.emplace_back(state_map_a2b[s], state_map_a2b[dest_state], - arcs.first.first); - auto curr_arc_deriv = std::move(arcs.second.second); - std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); - arc_derivs->emplace_back(curr_arc_deriv); - ++arc_num_b; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] != 1) continue; + b->arc_indexes.push_back(arc_num_b); + int32_t curr_state_b = state_map_a2b[i]; + // as the input FSA is top-sorted, we use a heap here so we can process + // states when they already have the best cost they are going to get + std::priority_queue, std::greater> q; + // stores states that have been queued + std::unordered_set qstates; + // state -> local_forward_state_weights of this state + std::unordered_map local_forward_weights; + // state -> (src_state, arc_index) entering this state which contributes to + // `local_forward_weights` of this state. + std::unordered_map> + local_backward_arcs; + local_forward_weights.emplace(i, forward_state_weights[i]); + // `-1` means we have traced back to current state `i` + local_backward_arcs.emplace(i, std::make_pair(i, -1)); + q.push(i); + qstates.insert(i); + while (!q.empty()) { + int32_t state = q.top(); + q.pop(); + int32_t arc_end = fsa.arc_indexes[state + 1]; + for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end; + ++arc_index) { + int32_t next_state = arcs_a[arc_index].dest_state; + int32_t label = arcs_a[arc_index].label; + double next_weight = + local_forward_weights[state] + arc_weights_a[arc_index]; + if (next_weight + backward_state_weights[next_state] >= best_weight) { + if (label == kEpsilon) { + auto result = + local_forward_weights.emplace(next_state, next_weight); + if (result.second) { + local_backward_arcs[next_state] = + std::make_pair(state, arc_index); + } else { + if (next_weight > result.first->second) { + result.first->second = next_weight; + local_backward_arcs[next_state] = + std::make_pair(state, arc_index); + } + } + if (qstates.find(next_state) == qstates.end()) { + q.push(next_state); + qstates.insert(next_state); + } + } else { + b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state], + label); + std::vector curr_arc_deriv; + std::pair curr_backward_arc{state, arc_index}; + auto *backward_arc = &curr_backward_arc; + while (backward_arc->second != -1) { + curr_arc_deriv.push_back(backward_arc->second); + backward_arc = &(local_backward_arcs[backward_arc->first]); + } + std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); + arc_derivs->emplace_back(std::move(curr_arc_deriv)); + ++arc_num_b; + } } } } diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index a5e3191b4..fc4a56c06 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -90,7 +90,7 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); keep paths that are within `beam` of the best path. Just make this very large if you don't want pruning. @param [out] b The output FSA; will be epsilon-free, and the states - will be in the same order that they were in in `a`. + will be in the same order that they were in `a`. @param [out] arc_derivs Indexed by arc in `b`, this is the sequence of arcs in `a` that this arc in `b` corresponds to; the weight of the arc in b will equal the sum of those input diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index 81e865553..9426220d2 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -322,10 +322,10 @@ TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) { std::vector> arc_derivs_b; RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b); - IsEpsilonFree(b); - ASSERT_EQ(b.arcs.size(), 10); + EXPECT_TRUE(IsEpsilonFree(b)); + ASSERT_EQ(b.arcs.size(), 11); ASSERT_EQ(b.arc_indexes.size(), 7); - ASSERT_EQ(arc_derivs_b.size(), 10); + ASSERT_EQ(arc_derivs_b.size(), 11); // TODO(haowen): check the equivalence after implementing RandEquivalent for // WFSA From fd4f964e6e076104b7cecc7de8ef86acda82dc88 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Tue, 26 May 2020 13:44:27 +0800 Subject: [PATCH 3/3] replace head with map in RmEpsilon --- k2/csrc/fsa_algo.cc | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index 6e3598085..138e0b208 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -320,13 +321,12 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, if (non_eps_in[i] != 1) continue; b->arc_indexes.push_back(arc_num_b); int32_t curr_state_b = state_map_a2b[i]; - // as the input FSA is top-sorted, we use a heap here so we can process + // as the input FSA is top-sorted, we use a map here so we can process // states when they already have the best cost they are going to get - std::priority_queue, std::greater> q; // stores states that have been queued - std::unordered_set qstates; - // state -> local_forward_state_weights of this state - std::unordered_map local_forward_weights; + std::map + local_forward_weights; // state -> local_forward_state_weights of this + // state // state -> (src_state, arc_index) entering this state which contributes to // `local_forward_weights` of this state. std::unordered_map> @@ -334,18 +334,19 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, local_forward_weights.emplace(i, forward_state_weights[i]); // `-1` means we have traced back to current state `i` local_backward_arcs.emplace(i, std::make_pair(i, -1)); - q.push(i); - qstates.insert(i); - while (!q.empty()) { - int32_t state = q.top(); - q.pop(); + while (!local_forward_weights.empty()) { + std::pair curr_local_forward_weights = + *(local_forward_weights.begin()); + local_forward_weights.erase(local_forward_weights.begin()); + int32_t state = curr_local_forward_weights.first; + int32_t arc_end = fsa.arc_indexes[state + 1]; for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end; ++arc_index) { int32_t next_state = arcs_a[arc_index].dest_state; int32_t label = arcs_a[arc_index].label; double next_weight = - local_forward_weights[state] + arc_weights_a[arc_index]; + curr_local_forward_weights.second + arc_weights_a[arc_index]; if (next_weight + backward_state_weights[next_state] >= best_weight) { if (label == kEpsilon) { auto result = @@ -360,10 +361,6 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, std::make_pair(state, arc_index); } } - if (qstates.find(next_state) == qstates.end()) { - q.push(next_state); - qstates.insert(next_state); - } } else { b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state], label);