diff --git a/k2/csrc/determinize.h b/k2/csrc/determinize.h index 4a15e235d..cc133eacf 100644 --- a/k2/csrc/determinize.h +++ b/k2/csrc/determinize.h @@ -267,9 +267,8 @@ struct LogSumTracebackState { double backward_prob; // Used temporarily in algorithms as a backward prob. // Undefined most of the time. - // This constructor is to be used only for the start-state (of both the - // input FSA and the determinized FSA). - LogSumTracebackState() : state_id(0), forward_prob(0.0) {} + explicit LogSumTracebackState(int32_t state_id = 0, double forward_prob = 0.0) + : state_id(state_id), forward_prob(forward_prob) {} /* @param [in] state_id State in input FSA diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index 138e0b208..0b29665f7 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -11,8 +11,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -50,6 +52,41 @@ inline int32_t InsertIntersectionState( return result.first->second; } +static void TraceBackRmEpsilonLogSum( + std::map *curr_states, + const float *arc_weights_in, + std::vector> *deriv_out) { + CHECK_EQ(curr_states->size(), 1); + deriv_out->clear(); + + // as the input fsa is top-sorted, we traverse states in a reverse order so we + // can process them when they already have correct backward_prob (all leaving + // arcs have been processed). 
+ k2::LogSumTracebackState *state_ptr = curr_states->rbegin()->second; + state_ptr->backward_prob = -state_ptr->forward_prob; + while (!state_ptr->prev_elements.empty()) { + double backward_prob = state_ptr->backward_prob; + for (const auto &link : state_ptr->prev_elements) { + float arc_log_posterior = link.forward_prob + backward_prob; + deriv_out->emplace_back(link.arc_index, expf(arc_log_posterior)); + k2::LogSumTracebackState *prev_state = link.prev_state.get(); + double new_backward_prob = backward_prob + arc_weights_in[link.arc_index]; + auto result = curr_states->emplace(prev_state->state_id, prev_state); + if (result.second) { + prev_state->backward_prob = new_backward_prob; + } else { + prev_state->backward_prob = + k2::LogAdd(new_backward_prob, prev_state->backward_prob); + } + } + curr_states->erase(--curr_states->end()); + CHECK(!curr_states->empty()); + state_ptr = curr_states->rbegin()->second; + } + // we have reached the state from which we are trying to remove epsilon arcs. 
+ CHECK_EQ(curr_states->size(), 1); +} + } // namespace namespace k2 { @@ -323,7 +360,6 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, int32_t curr_state_b = state_map_a2b[i]; // as the input FSA is top-sorted, we use a map here so we can process // states when they already have the best cost they are going to get - // stores states that have been queued std::map local_forward_weights; // state -> local_forward_state_weights of this // state @@ -383,6 +419,111 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, b->arc_indexes.push_back(b->arc_indexes.back()); } +void RmEpsilonsPrunedLogSum( + const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector *b_arc_weights, + std::vector>> *arc_derivs) { + CHECK_GT(beam, 0); + CHECK_NOTNULL(b); + CHECK_NOTNULL(b_arc_weights); + CHECK_NOTNULL(arc_derivs); + b->arc_indexes.clear(); + b->arcs.clear(); + b_arc_weights->clear(); + arc_derivs->clear(); + + const auto &fsa = a.fsa; + if (IsEmpty(fsa)) return; + int32_t num_states_a = fsa.NumStates(); + int32_t final_state = fsa.FinalState(); + const auto &arcs_a = fsa.arcs; + const float *arc_weights_a = a.arc_weights; + + // identify all states that should be kept + std::vector non_eps_in(num_states_a, 0); + non_eps_in[0] = 1; + for (const auto &arc : arcs_a) { + // We suppose the input fsa `a` is top-sorted, but only check this in DEBUG + // time. 
+ DCHECK_GE(arc.dest_state, arc.src_state); + if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1; + } + + // remap state id + std::vector state_map_a2b(num_states_a, -1); + int32_t num_states_b = 0; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++; + } + b->arc_indexes.reserve(num_states_b + 1); + int32_t arc_num_b = 0; + + const double *forward_state_weights = a.ForwardStateWeights(); + const double *backward_state_weights = a.BackwardStateWeights(); + const double best_weight = forward_state_weights[final_state] - beam; + for (int32_t i = 0; i != num_states_a; ++i) { + if (non_eps_in[i] != 1) continue; + b->arc_indexes.push_back(arc_num_b); + int32_t curr_state_b = state_map_a2b[i]; + // as the input FSA is top-sorted, we use a set here so we can process + // states when they already have costs over all paths they are going to get + std::set qstates; + std::unordered_map> + traceback_states; // state -> LogSumTracebackState of this state + std::shared_ptr start_state( + new LogSumTracebackState(i, forward_state_weights[i])); + double start_forward_weights = start_state->forward_prob; + traceback_states.emplace(i, start_state); + qstates.insert(i); + while (!qstates.empty()) { + int32_t state = *(qstates.begin()); + qstates.erase(qstates.begin()); + + const auto &curr_traceback_state = traceback_states[state]; + double curr_forward_weights = curr_traceback_state->forward_prob; + int32_t arc_end = fsa.arc_indexes[state + 1]; + for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end; + ++arc_index) { + int32_t next_state = arcs_a[arc_index].dest_state; + int32_t label = arcs_a[arc_index].label; + float curr_arc_weight = arc_weights_a[arc_index]; + double next_weight = curr_forward_weights + curr_arc_weight; + if (next_weight + backward_state_weights[next_state] >= best_weight) { + if (label == kEpsilon) { + auto result = traceback_states.emplace(next_state, nullptr); + if (result.second) 
{ + result.first->second = std::make_shared( + next_state, curr_traceback_state, arc_index, curr_arc_weight); + qstates.insert(next_state); + } else { + result.first->second->Accept(curr_traceback_state, arc_index, + curr_arc_weight); + } + } else { + b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state], + label); + b_arc_weights->push_back(curr_forward_weights + curr_arc_weight - + start_forward_weights); + + std::vector> curr_arc_deriv; + std::map curr_states; + curr_states.emplace(state, curr_traceback_state.get()); + TraceBackRmEpsilonLogSum(&curr_states, arc_weights_a, + &curr_arc_deriv); + std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); + // push derivs info of current arc + curr_arc_deriv.emplace_back(arc_index, 1); + arc_derivs->emplace_back(std::move(curr_arc_deriv)); + ++arc_num_b; + } + } + } + } + } + // duplicate of final state + b->arc_indexes.push_back(b->arc_indexes.back()); +} + bool Intersect(const Fsa &a, const Fsa &b, Fsa *c, std::vector *arc_map_a /*= nullptr*/, std::vector *arc_map_b /*= nullptr*/) { diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index fc4a56c06..c0f26fcc1 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -118,6 +118,7 @@ void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b, the difference will affect pruning slightly. @param [in] beam Beam for pruning, must be > 0. @param [out] b The output FSA + @param [out] b_arc_weights Weights per arc of b. @param [out] arc_derivs Indexed by arc-index in b, it is an list of (input-arc, deriv), where 0 < deriv <= 1, where the lists are ordered by input-arc (unlike @@ -129,6 +130,7 @@ void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b, */ void RmEpsilonsPrunedLogSum( const WfsaWithFbWeights &a, float beam, Fsa *b, + std::vector *b_arc_weights, std::vector>> *arc_derivs); /* @@ -136,6 +138,7 @@ void RmEpsilonsPrunedLogSum( documentation. 
*/ + void RmEpsilonsLogSum(const Fsa &a, float *a_weights, Fsa *b, + std::vector *b_arc_weights, std::vector> *arc_map); /* diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index 9426220d2..b6c953b6d 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -320,7 +320,7 @@ class RmEpsilonTest : public ::testing::Test { TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) { Fsa b; std::vector> arc_derivs_b; - RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b); + RmEpsilonsPrunedMax(*max_wfsa_, 8.0, &b, &arc_derivs_b); EXPECT_TRUE(IsEpsilonFree(b)); ASSERT_EQ(b.arcs.size(), 11); @@ -331,6 +331,22 @@ TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) { // WFSA } +TEST_F(RmEpsilonTest, RmEpsilonsPrunedLogSum) { + Fsa b; + std::vector arc_weights_b; + std::vector>> arc_derivs_b; + RmEpsilonsPrunedLogSum(*log_wfsa_, 8.0, &b, &arc_weights_b, &arc_derivs_b); + + EXPECT_TRUE(IsEpsilonFree(b)); + ASSERT_EQ(b.arcs.size(), 11); + ASSERT_EQ(b.arc_indexes.size(), 7); + ASSERT_EQ(arc_weights_b.size(), 11); + ASSERT_EQ(arc_derivs_b.size(), 11); + + // TODO(haowen): check the equivalence after implementing RandEquivalent for + // RmEpsilonsPrunedLogSum +} + TEST(FsaAlgo, Intersect) { // empty fsa {