diff --git a/k2/csrc/determinize.h b/k2/csrc/determinize.h index a2bbf678c..4a15e235d 100644 --- a/k2/csrc/determinize.h +++ b/k2/csrc/determinize.h @@ -609,6 +609,8 @@ int32_t DetState::ProcessArcs( derivs_per_arc->push_back(std::move(deriv_info)); if (is_new_state) queue->push(std::unique_ptr>(det_state)); + else + delete det_state; } else { delete det_state; } diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index f58963538..138e0b208 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -8,11 +8,14 @@ #include "k2/csrc/fsa_algo.h" #include +#include #include +#include #include #include #include #include +#include #include #include "glog/logging.h" @@ -275,6 +278,111 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map /*=nullptr*/) { return is_acyclic; } +void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector> *arc_derivs) { + CHECK_EQ(a.weight_type, kMaxWeight); + CHECK_GT(beam, 0); + CHECK_NOTNULL(b); + CHECK_NOTNULL(arc_derivs); + b->arc_indexes.clear(); + b->arcs.clear(); + arc_derivs->clear(); + + const auto &fsa = a.fsa; + if (IsEmpty(fsa)) return; + int32_t num_states_a = fsa.NumStates(); + int32_t final_state = fsa.FinalState(); + const auto &arcs_a = fsa.arcs; + const float *arc_weights_a = a.arc_weights; + + // identify all states that should be kept + std::vector non_eps_in(num_states_a, 0); + non_eps_in[0] = 1; + for (const auto &arc : arcs_a) { + // We suppose the input fsa `a` is top-sorted, but only check this in DEBUG + // time. + DCHECK_GE(arc.dest_state, arc.src_state); + if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1; + } + + // remap state id + std::vector state_map_a2b(num_states_a, -1); + int32_t num_states_b = 0; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++; + } + b->arc_indexes.reserve(num_states_b + 1); + int32_t arc_num_b = 0; + + const double *forward_state_weights = a.ForwardStateWeights(); + const double *backward_state_weights = a.BackwardStateWeights(); + const double best_weight = forward_state_weights[final_state] - beam; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] != 1) continue; + b->arc_indexes.push_back(arc_num_b); + int32_t curr_state_b = state_map_a2b[i]; + // as the input FSA is top-sorted, we use a map here so we can process + // states when they already have the best cost they are going to get + // stores states that have been queued + std::map + local_forward_weights; // state -> local_forward_state_weights of this + // state + // state -> (src_state, arc_index) entering this state which contributes to + // `local_forward_weights` of this state. + std::unordered_map> + local_backward_arcs; + local_forward_weights.emplace(i, forward_state_weights[i]); + // `-1` means we have traced back to current state `i` + local_backward_arcs.emplace(i, std::make_pair(i, -1)); + while (!local_forward_weights.empty()) { + std::pair curr_local_forward_weights = + *(local_forward_weights.begin()); + local_forward_weights.erase(local_forward_weights.begin()); + int32_t state = curr_local_forward_weights.first; + + int32_t arc_end = fsa.arc_indexes[state + 1]; + for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end; + ++arc_index) { + int32_t next_state = arcs_a[arc_index].dest_state; + int32_t label = arcs_a[arc_index].label; + double next_weight = + curr_local_forward_weights.second + arc_weights_a[arc_index]; + if (next_weight + backward_state_weights[next_state] >= best_weight) { + if (label == kEpsilon) { + auto result = + local_forward_weights.emplace(next_state, next_weight); + if (result.second) { + local_backward_arcs[next_state] = + std::make_pair(state, arc_index); + } else { + if (next_weight > result.first->second) { + result.first->second = next_weight; + local_backward_arcs[next_state] = + std::make_pair(state, arc_index); + } + } + } else { + b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state], + label); + std::vector curr_arc_deriv; + std::pair curr_backward_arc{state, arc_index}; + auto *backward_arc = &curr_backward_arc; + while (backward_arc->second != -1) { + curr_arc_deriv.push_back(backward_arc->second); + backward_arc = &(local_backward_arcs[backward_arc->first]); + } + std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); + arc_derivs->emplace_back(std::move(curr_arc_deriv)); + ++arc_num_b; + } + } + } + } + } + // duplicate of final state + b->arc_indexes.push_back(b->arc_indexes.back()); +} + bool Intersect(const Fsa &a, const Fsa &b, Fsa *c, std::vector *arc_map_a /*= nullptr*/, std::vector *arc_map_b /*= nullptr*/) { diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index f7f777242..fc4a56c06 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -90,72 +90,14 @@ bool Connect(const Fsa &a, Fsa *b, std::vector *arc_map = nullptr); keep paths that are within `beam` of the best path. Just make this very large if you don't want pruning. @param [out] b The output FSA; will be epsilon-free, and the states - will be in the same order that they were in in `a`. - @param [out] arc_map If non-NULL: for each arc in `b`, a list of - the arc-indexes in `a`, in order, that contributed - to that arc (e.g. its cost would be a sum of their costs). - - Notes on algorithm (please rework all this when it's complete, i.e. just - make sure the code is clear and remove this). - - The states in the output FSA will correspond to the subset of states in the - input FSA which are within `beam` of the best path and which have at least - one non-epsilon arc entering them, plus the start state. (Note: this - automatically includes the final state, assuming `a` has at least one - successful path; if it does not, the output will be empty). - - If we ever need the associated state map from calling code, we'll add an - extra output argument to this function. - - The basic algorithm is to (1) identify the kept states, (2) from each kept - input-state ki, we'll iterate over all states that are reachable via zero - or more epsilons from this state and process the non-epsilon outgoing arcs - from those states, which will become the arcs in the output. We'll also - store a back-pointer array that will allow us to figure out the best path - back to ki, in order to produce the output `arc_map`. Assume we have - arrays - - local_forward_weights (float) and local_backpointers (int) indexed by - state-id, and that the local_forward_weights are initialized with - -infinity's each time we process a new ki. (we have to figure out how to do - this efficiently). - - - Processing input-state ki: - local_forward_state_weights[ki] = forward_state_weights[ki] // from - WfsaWithFbWeights. - // Caution: - we should probably use - // double - here; these kinds of algorithms - // are - extremely sensitive to roundoff for - // very - long FSAs. local_backpointers[ki] = -1 // will terminate a sequence.. - queue.push_back(ki) - while (!queue.empty()) { - ji = queue.front() // we have to be a bit careful about order here, - to make sure - // we always process states when they already - have the - // best cost they are going to get. If - // FSA was top-sorted at the start, which we - assume, we could perhaps - // process them in numerical order, e.g. using a - heap. queue.pop_front() for each arc leaving state ji: next_weight = - local_forward_state_weights[ji] + arc_weights[this_arc_index] if next_weight - + backward_state_weights[arc_dest_state] < best_path_weight - beam: if arc - label is epsilon: if next_weight < local_forward_state_weight[next_state]: - local_forward_state_weight[next_state] = next_weight - local_backpointers[next_state] = ji - else: - add an arc to the output FSA, and create the appropriate - arc_map entry by following backpointers (hopefully you can - figure out the details). Note: the output FSA's weights can - be computed later on, by calling code, using the info in arc_map. + will be in the same order that they were in `a`. + @param [out] arc_derivs Indexed by arc in `b`, this is the sequence of + arcs in `a` that this arc in `b` corresponds to; the + weight of the arc in b will equal the sum of those input + arcs' weights */ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, - std::vector> *arc_map); + std::vector> *arc_derivs); /* Version of RmEpsilonsPrunedMax that doesn't support pruning; see its diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index ff1e2fbc0..9426220d2 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -283,6 +283,54 @@ TEST(FsaAlgo, Connect) { } } +class RmEpsilonTest : public ::testing::Test { + protected: + RmEpsilonTest() { + std::vector arcs = { + {0, 4, 1}, {0, 1, 1}, {1, 2, 0}, {1, 3, 0}, {1, 4, 0}, + {2, 7, 0}, {3, 7, 0}, {4, 6, 1}, {4, 6, 0}, {4, 8, 1}, + {4, 9, -1}, {5, 9, -1}, {6, 9, -1}, {7, 9, -1}, {8, 9, -1}, + }; + fsa_ = new Fsa(std::move(arcs), 9); + num_states_ = fsa_->NumStates(); + + auto num_arcs = fsa_->arcs.size(); + arc_weights_ = new float[num_arcs]; + std::vector weights = {1, 1, 2, 3, 2, 4, 5, 2, 3, 3, 2, 4, 3, 5, 6}; + std::copy_n(weights.begin(), num_arcs, arc_weights_); + + max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight); + log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight); + } + + ~RmEpsilonTest() { + delete fsa_; + delete[] arc_weights_; + delete max_wfsa_; + delete log_wfsa_; + } + + WfsaWithFbWeights *max_wfsa_; + WfsaWithFbWeights *log_wfsa_; + Fsa *fsa_; + int32_t num_states_; + float *arc_weights_; +}; + +TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) { + Fsa b; + std::vector> arc_derivs_b; + RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b); + + EXPECT_TRUE(IsEpsilonFree(b)); + ASSERT_EQ(b.arcs.size(), 11); + ASSERT_EQ(b.arc_indexes.size(), 7); + ASSERT_EQ(arc_derivs_b.size(), 11); + + // TODO(haowen): check the equivalence after implementing RandEquivalent for + // WFSA +} + TEST(FsaAlgo, Intersect) { // empty fsa {