Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement RmEpsilonPrunedLogSum #47

Merged
merged 1 commit into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions k2/csrc/determinize.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,8 @@ struct LogSumTracebackState {
double backward_prob; // Used temporarily in algorithms as a backward prob.
// Undefined most of the time.

// This constructor is to be used only for the start-state (of both the
// input FSA and the determinized FSA).
LogSumTracebackState() : state_id(0), forward_prob(0.0) {}
// Constructs a traceback state from an explicit `state_id` and
// `forward_prob`.  Both arguments default to 0, so calling it with no
// arguments yields the same member values as the zero-argument constructor
// above.  `backward_prob` is not initialized by this constructor.  Marked
// `explicit` so that a bare int32_t is never implicitly converted to a
// LogSumTracebackState.
// NOTE(review): in this diff view the zero-argument constructor above looks
// like the removed line and the two lines below like its replacement --
// confirm against the merged header.  If both were kept, a call with no
// arguments would be ambiguous.
explicit LogSumTracebackState(int32_t state_id = 0, double forward_prob = 0.0)
: state_id(state_id), forward_prob(forward_prob) {}

/*
@param [in] state_id State in input FSA
Expand Down
143 changes: 142 additions & 1 deletion k2/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -50,6 +52,41 @@ inline int32_t InsertIntersectionState(
return result.first->second;
}

// Traces back from a single traceback state along the `prev_elements` links
// of LogSumTracebackState, producing one (arc_index, value) pair per link
// visited, where value = expf(link.forward_prob + backward_prob of the state
// the link was read from).  While walking back, it also fills in
// `backward_prob` on every state it reaches.
//
//   @param [in,out] curr_states  Map from state_id to traceback state, used
//                       as a work queue ordered by state_id.  Must contain
//                       exactly one element on entry (the state to trace back
//                       from).  On return it again contains exactly one
//                       element: the state at which the traceback stopped,
//                       i.e. one whose `prev_elements` is empty.
//   @param [in] arc_weights_in  Weights indexed by `link.arc_index`; the
//                       caller in this file passes the arc weights of the
//                       input FSA.
//   @param [out] deriv_out  Cleared on entry, then filled with the
//                       (arc_index, value) pairs described above, in the order
//                       the links are visited (highest state_id first).
static void TraceBackRmEpsilonLogSum(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compare this with how carefully I document functions like this in determinize.cc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, will add in another PR.

std::map<int32_t, k2::LogSumTracebackState *> *curr_states,
const float *arc_weights_in,
std::vector<std::pair<int32_t, float>> *deriv_out) {
CHECK_EQ(curr_states->size(), 1);
deriv_out->clear();

// as the input fsa is top-sorted, we traverse states in a reverse order so we
// can process them when they already have correct backward_prob (all leaving
// arcs have been processed).
// std::map iterates in key order, so rbegin() is the entry with the largest
// state_id currently queued.
k2::LogSumTracebackState *state_ptr = curr_states->rbegin()->second;
// Initialize so that forward_prob + backward_prob == 0 for the state we
// start the traceback from.
state_ptr->backward_prob = -state_ptr->forward_prob;
while (!state_ptr->prev_elements.empty()) {
double backward_prob = state_ptr->backward_prob;
for (const auto &link : state_ptr->prev_elements) {
float arc_log_posterior = link.forward_prob + backward_prob;
deriv_out->emplace_back(link.arc_index, expf(arc_log_posterior));
k2::LogSumTracebackState *prev_state = link.prev_state.get();
double new_backward_prob = backward_prob + arc_weights_in[link.arc_index];
// emplace() only inserts if this predecessor is not queued yet;
// result.second tells us which case we are in.
auto result = curr_states->emplace(prev_state->state_id, prev_state);
if (result.second) {
// first time we reach this predecessor: set its backward_prob
prev_state->backward_prob = new_backward_prob;
} else {
// already queued via another link: combine the two contributions
prev_state->backward_prob =
k2::LogAdd(new_backward_prob, prev_state->backward_prob);
}
}
// Drop the largest-key entry.  This is the state just processed as long as
// every predecessor has a smaller state_id (relies on the top-sorted
// assumption stated above).
curr_states->erase(--curr_states->end());
CHECK(!curr_states->empty());
state_ptr = curr_states->rbegin()->second;
}
// we have reached the state from which we are trying to remove epsilon arcs.
CHECK_EQ(curr_states->size(), 1);
}

} // namespace

namespace k2 {
Expand Down Expand Up @@ -323,7 +360,6 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
int32_t curr_state_b = state_map_a2b[i];
// as the input FSA is top-sorted, we use a map here so we can process
// states when they already have the best cost they are going to get
// stores states that have been queued
std::map<int32_t, double>
local_forward_weights; // state -> local_forward_state_weights of this
// state
Expand Down Expand Up @@ -383,6 +419,111 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
b->arc_indexes.push_back(b->arc_indexes.back());
}

// Pruned epsilon removal that also produces per-arc weights and derivative
// lists for the output.
//
// Outline of what the code below does:
//  - All outputs are cleared first; if `a.fsa` is empty they stay empty.
//  - State 0 and every state with at least one non-epsilon entering arc are
//    kept and renumbered consecutively for `b`.
//  - For each kept state i, the search follows epsilon arcs forward from i
//    (visiting states in increasing state-id order) and emits one arc in `b`
//    for every non-epsilon arc leaving a reached state.  An arc is only
//    followed or emitted if (forward weight up to the end of the arc) +
//    (backward weight of its destination) is at least
//    (forward weight of the final state) - beam.
//
//   @param [in] a     Input WFSA plus forward/backward state weights.  `a.fsa`
//                     is assumed to be top-sorted; this is only checked with
//                     DCHECK.
//   @param [in] beam  Pruning beam, must be > 0.
//   @param [out] b    Output FSA; must be non-NULL.  Only arcs whose label is
//                     not kEpsilon are ever added to it.
//   @param [out] b_arc_weights  Must be non-NULL.  Receives one weight per
//                     arc added to `b`, in the same order.
//   @param [out] arc_derivs  Must be non-NULL.  Receives one list per arc
//                     added to `b`, in the same order; each list holds
//                     (arc-index in `a.fsa`, deriv) pairs.
void RmEpsilonsPrunedLogSum(
const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<float> *b_arc_weights,
std::vector<std::vector<std::pair<int32_t, float>>> *arc_derivs) {
CHECK_GT(beam, 0);
CHECK_NOTNULL(b);
CHECK_NOTNULL(b_arc_weights);
CHECK_NOTNULL(arc_derivs);
b->arc_indexes.clear();
b->arcs.clear();
b_arc_weights->clear();
arc_derivs->clear();

const auto &fsa = a.fsa;
if (IsEmpty(fsa)) return;
int32_t num_states_a = fsa.NumStates();
int32_t final_state = fsa.FinalState();
const auto &arcs_a = fsa.arcs;
const float *arc_weights_a = a.arc_weights;

// identify all states that should be kept
// non_eps_in[s] == 1 iff s is state 0 or has a non-epsilon entering arc
std::vector<char> non_eps_in(num_states_a, 0);
non_eps_in[0] = 1;
for (const auto &arc : arcs_a) {
// We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
// time.
DCHECK_GE(arc.dest_state, arc.src_state);
if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1;
}

// remap state id
// state_map_a2b[s] is the id of s in `b`, or -1 if s is not kept
std::vector<int32_t> state_map_a2b(num_states_a, -1);
int32_t num_states_b = 0;
for (int32_t i = 0; i != num_states_a; ++i) {
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++;
}
b->arc_indexes.reserve(num_states_b + 1);
int32_t arc_num_b = 0;

const double *forward_state_weights = a.ForwardStateWeights();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qindazhu I notice that your code tends to be a bit unstructured and hard to read. Compare with my code in determinize.cc... see that I mostly have shorter functions and I spend a lot of time documenting things. Just think about it for the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, actually I plan to refactor the code after implemented RmEpsMax/LogSum(non pruned version, so there are 4 similar function). Of course, will notice the documentation issue. Thanks!

const double *backward_state_weights = a.BackwardStateWeights();
// Despite its name this is the pruning threshold: the forward weight of the
// final state minus the beam.
const double best_weight = forward_state_weights[final_state] - beam;
for (int32_t i = 0; i != num_states_a; ++i) {
if (non_eps_in[i] != 1) continue;
// arcs leaving this kept state start at index arc_num_b in b->arcs
b->arc_indexes.push_back(arc_num_b);
int32_t curr_state_b = state_map_a2b[i];
// as the input FSA is top-sorted, we use a set here so we can process
// states when they already have costs over all paths they are going to get
std::set<int32_t> qstates;
std::unordered_map<int32_t, std::shared_ptr<LogSumTracebackState>>
traceback_states; // state -> LogSumTracebackState of this state
// The search from state i starts with i's global forward weight; it is
// subtracted again when an output arc weight is computed below.
std::shared_ptr<LogSumTracebackState> start_state(
new LogSumTracebackState(i, forward_state_weights[i]));
double start_forward_weights = start_state->forward_prob;
traceback_states.emplace(i, start_state);
qstates.insert(i);
while (!qstates.empty()) {
// pop the queued state with the smallest id
int32_t state = *(qstates.begin());
qstates.erase(qstates.begin());

const auto &curr_traceback_state = traceback_states[state];
double curr_forward_weights = curr_traceback_state->forward_prob;
int32_t arc_end = fsa.arc_indexes[state + 1];
for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end;
++arc_index) {
int32_t next_state = arcs_a[arc_index].dest_state;
int32_t label = arcs_a[arc_index].label;
float curr_arc_weight = arc_weights_a[arc_index];
double next_weight = curr_forward_weights + curr_arc_weight;
// pruning: skip arcs that fall below the threshold
if (next_weight + backward_state_weights[next_state] >= best_weight) {
if (label == kEpsilon) {
// Epsilon arc: extend the search to next_state.  emplace() inserts a
// null placeholder only if next_state has not been reached before.
auto result = traceback_states.emplace(next_state, nullptr);
if (result.second) {
// first visit: create its traceback state and queue it
// NOTE(review): this four-argument constructor is not shown here;
// presumably it records the incoming link and derives forward_prob
// from it -- confirm in determinize.h.
result.first->second = std::make_shared<LogSumTracebackState>(
next_state, curr_traceback_state, arc_index, curr_arc_weight);
qstates.insert(next_state);
} else {
// reached again through another epsilon arc: add this incoming link
// to the existing traceback state
result.first->second->Accept(curr_traceback_state, arc_index,
curr_arc_weight);
}
} else {
// Non-epsilon arc: emit an arc in `b` from kept state i directly to
// the (kept) destination, bypassing the epsilon path in between.
b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state],
label);
// weight relative to state i: accumulated forward weight up to
// `state`, plus this arc, minus the forward weight we started from
b_arc_weights->push_back(curr_forward_weights + curr_arc_weight -
start_forward_weights);

// Collect derivatives w.r.t. the input arcs on the epsilon path(s) from
// i to `state` by tracing back from `state`.
std::vector<std::pair<int32_t, float>> curr_arc_deriv;
std::map<int32_t, LogSumTracebackState *> curr_states;
curr_states.emplace(state, curr_traceback_state.get());
TraceBackRmEpsilonLogSum(&curr_states, arc_weights_a,
&curr_arc_deriv);
// the traceback appends entries walking backward; flip the list
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end());
// push derivs info of current arc
curr_arc_deriv.emplace_back(arc_index, 1);
arc_derivs->emplace_back(std::move(curr_arc_deriv));
++arc_num_b;
}
}
}
}
}
// duplicate of final state
// (arc_indexes is non-empty here: state 0 is always kept, so the loop above
// pushed at least one entry)
b->arc_indexes.push_back(b->arc_indexes.back());
}

bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
std::vector<int32_t> *arc_map_a /*= nullptr*/,
std::vector<int32_t> *arc_map_b /*= nullptr*/) {
Expand Down
3 changes: 3 additions & 0 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b,
the difference will affect pruning slightly.
@param [in] beam Beam for pruning, must be > 0.
@param [out] b The output FSA
@param [out] b_arc_weights Weights per arc of b.
@param [out] arc_derivs Indexed by arc-index in b, it is an list of
(input-arc, deriv), where 0 < deriv <= 1, where the
lists are ordered by input-arc (unlike
Expand All @@ -129,13 +130,15 @@ void RmEpsilonsMax(const Fsa &a, float *a_weights, Fsa *b,
*/
void RmEpsilonsPrunedLogSum(
const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<float> *b_arc_weights,
std::vector<std::vector<std::pair<int32_t, float>>> *arc_derivs);

/*
Version of RmEpsilonsPrunedLogSum that doesn't support pruning; see its
documentation.
*/
void RmEpsilonsLogSum(const Fsa &a, float *a_weights, Fsa *b,
std::vector<float> *b_arc_weights,
std::vector<std::vector<int32_t>> *arc_map);

/*
Expand Down
18 changes: 17 additions & 1 deletion k2/csrc/fsa_algo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class RmEpsilonTest : public ::testing::Test {
TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) {
Fsa b;
std::vector<std::vector<int32_t>> arc_derivs_b;
RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b);
RmEpsilonsPrunedMax(*max_wfsa_, 8.0, &b, &arc_derivs_b);

EXPECT_TRUE(IsEpsilonFree(b));
ASSERT_EQ(b.arcs.size(), 11);
Expand All @@ -331,6 +331,22 @@ TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) {
// WFSA
}

// Size/shape test for RmEpsilonsPrunedLogSum run on the fixture's
// `log_wfsa_` with a beam of 8: the result must contain no epsilon arcs and
// the three outputs must have the expected, mutually consistent sizes.
TEST_F(RmEpsilonTest, RmEpsilonsPrunedLogSum) {
  Fsa pruned_fsa;
  std::vector<float> pruned_arc_weights;
  std::vector<std::vector<std::pair<int32_t, float>>> pruned_arc_derivs;
  RmEpsilonsPrunedLogSum(*log_wfsa_, 8.0, &pruned_fsa, &pruned_arc_weights,
                         &pruned_arc_derivs);

  EXPECT_TRUE(IsEpsilonFree(pruned_fsa));
  ASSERT_EQ(pruned_fsa.arcs.size(), 11);
  ASSERT_EQ(pruned_fsa.arc_indexes.size(), 7);
  // one weight and one derivative list per output arc
  ASSERT_EQ(pruned_arc_weights.size(), 11);
  ASSERT_EQ(pruned_arc_derivs.size(), 11);

  // TODO(haowen): check the equivalence after implementing RandEquivalent for
  // RmEpsilonPrunedLogSum
}

TEST(FsaAlgo, Intersect) {
// empty fsa
{
Expand Down