-
Notifications
You must be signed in to change notification settings - Fork 222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement RmEpsilonPrunedLogSum #47
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,10 @@ | |
#include <functional> | ||
#include <limits> | ||
#include <map> | ||
#include <memory> | ||
#include <numeric> | ||
#include <queue> | ||
#include <set> | ||
#include <stack> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
@@ -50,6 +52,41 @@ inline int32_t InsertIntersectionState( | |
return result.first->second; | ||
} | ||
|
||
static void TraceBackRmEpsilonLogSum( | ||
std::map<int32_t, k2::LogSumTracebackState *> *curr_states, | ||
const float *arc_weights_in, | ||
std::vector<std::pair<int32_t, float>> *deriv_out) { | ||
CHECK_EQ(curr_states->size(), 1); | ||
deriv_out->clear(); | ||
|
||
// as the input fsa is top-sorted, we traverse states in a reverse order so we | ||
// can process them when they already have correct backward_prob (all leaving | ||
// arcs have been processed). | ||
k2::LogSumTracebackState *state_ptr = curr_states->rbegin()->second; | ||
state_ptr->backward_prob = -state_ptr->forward_prob; | ||
while (!state_ptr->prev_elements.empty()) { | ||
double backward_prob = state_ptr->backward_prob; | ||
for (const auto &link : state_ptr->prev_elements) { | ||
float arc_log_posterior = link.forward_prob + backward_prob; | ||
deriv_out->emplace_back(link.arc_index, expf(arc_log_posterior)); | ||
k2::LogSumTracebackState *prev_state = link.prev_state.get(); | ||
double new_backward_prob = backward_prob + arc_weights_in[link.arc_index]; | ||
auto result = curr_states->emplace(prev_state->state_id, prev_state); | ||
if (result.second) { | ||
prev_state->backward_prob = new_backward_prob; | ||
} else { | ||
prev_state->backward_prob = | ||
k2::LogAdd(new_backward_prob, prev_state->backward_prob); | ||
} | ||
} | ||
curr_states->erase(--curr_states->end()); | ||
CHECK(!curr_states->empty()); | ||
state_ptr = curr_states->rbegin()->second; | ||
} | ||
// we have reached the state from which we are trying to remove epsilon arcs. | ||
CHECK_EQ(curr_states->size(), 1); | ||
} | ||
|
||
} // namespace | ||
|
||
namespace k2 { | ||
|
@@ -323,7 +360,6 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, | |
int32_t curr_state_b = state_map_a2b[i]; | ||
// as the input FSA is top-sorted, we use a map here so we can process | ||
// states when they already have the best cost they are going to get | ||
// stores states that have been queued | ||
std::map<int32_t, double> | ||
local_forward_weights; // state -> local_forward_state_weights of this | ||
// state | ||
|
@@ -383,6 +419,111 @@ void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b, | |
b->arc_indexes.push_back(b->arc_indexes.back()); | ||
} | ||
|
||
void RmEpsilonsPrunedLogSum( | ||
const WfsaWithFbWeights &a, float beam, Fsa *b, | ||
std::vector<float> *b_arc_weights, | ||
qindazhu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::vector<std::vector<std::pair<int32_t, float>>> *arc_derivs) { | ||
CHECK_GT(beam, 0); | ||
CHECK_NOTNULL(b); | ||
CHECK_NOTNULL(b_arc_weights); | ||
CHECK_NOTNULL(arc_derivs); | ||
b->arc_indexes.clear(); | ||
b->arcs.clear(); | ||
b_arc_weights->clear(); | ||
arc_derivs->clear(); | ||
|
||
const auto &fsa = a.fsa; | ||
if (IsEmpty(fsa)) return; | ||
int32_t num_states_a = fsa.NumStates(); | ||
int32_t final_state = fsa.FinalState(); | ||
const auto &arcs_a = fsa.arcs; | ||
const float *arc_weights_a = a.arc_weights; | ||
|
||
// identify all states that should be kept | ||
std::vector<char> non_eps_in(num_states_a, 0); | ||
non_eps_in[0] = 1; | ||
for (const auto &arc : arcs_a) { | ||
// We suppose the input fsa `a` is top-sorted, but only check this in DEBUG | ||
// time. | ||
DCHECK_GE(arc.dest_state, arc.src_state); | ||
if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1; | ||
} | ||
|
||
// remap state id | ||
std::vector<int32_t> state_map_a2b(num_states_a, -1); | ||
int32_t num_states_b = 0; | ||
for (int32_t i = 0; i != num_states_a; ++i) { | ||
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++; | ||
} | ||
b->arc_indexes.reserve(num_states_b + 1); | ||
int32_t arc_num_b = 0; | ||
|
||
const double *forward_state_weights = a.ForwardStateWeights(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @qindazhu I notice that your code tends to be a bit unstructured and hard to read. Compare with my code in determinize.cc... see that I mostly have shorter functions and I spend a lot of time documenting things. Just think about it for the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, actually I plan to refactor the code after implemented RmEpsMax/LogSum(non pruned version, so there are 4 similar function). Of course, will notice the documentation issue. Thanks! |
||
const double *backward_state_weights = a.BackwardStateWeights(); | ||
const double best_weight = forward_state_weights[final_state] - beam; | ||
for (int32_t i = 0; i != num_states_a; ++i) { | ||
if (non_eps_in[i] != 1) continue; | ||
b->arc_indexes.push_back(arc_num_b); | ||
int32_t curr_state_b = state_map_a2b[i]; | ||
// as the input FSA is top-sorted, we use a set here so we can process | ||
// states when they already have costs over all paths they are going to get | ||
std::set<int32_t> qstates; | ||
std::unordered_map<int32_t, std::shared_ptr<LogSumTracebackState>> | ||
traceback_states; // state -> LogSumTracebackState of this state | ||
std::shared_ptr<LogSumTracebackState> start_state( | ||
new LogSumTracebackState(i, forward_state_weights[i])); | ||
double start_forward_weights = start_state->forward_prob; | ||
traceback_states.emplace(i, start_state); | ||
qstates.insert(i); | ||
while (!qstates.empty()) { | ||
int32_t state = *(qstates.begin()); | ||
qstates.erase(qstates.begin()); | ||
|
||
const auto &curr_traceback_state = traceback_states[state]; | ||
double curr_forward_weights = curr_traceback_state->forward_prob; | ||
int32_t arc_end = fsa.arc_indexes[state + 1]; | ||
for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end; | ||
++arc_index) { | ||
int32_t next_state = arcs_a[arc_index].dest_state; | ||
int32_t label = arcs_a[arc_index].label; | ||
float curr_arc_weight = arc_weights_a[arc_index]; | ||
double next_weight = curr_forward_weights + curr_arc_weight; | ||
if (next_weight + backward_state_weights[next_state] >= best_weight) { | ||
if (label == kEpsilon) { | ||
auto result = traceback_states.emplace(next_state, nullptr); | ||
if (result.second) { | ||
result.first->second = std::make_shared<LogSumTracebackState>( | ||
next_state, curr_traceback_state, arc_index, curr_arc_weight); | ||
qstates.insert(next_state); | ||
} else { | ||
result.first->second->Accept(curr_traceback_state, arc_index, | ||
curr_arc_weight); | ||
} | ||
} else { | ||
b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state], | ||
label); | ||
b_arc_weights->push_back(curr_forward_weights + curr_arc_weight - | ||
start_forward_weights); | ||
|
||
std::vector<std::pair<int32_t, float>> curr_arc_deriv; | ||
std::map<int32_t, LogSumTracebackState *> curr_states; | ||
curr_states.emplace(state, curr_traceback_state.get()); | ||
TraceBackRmEpsilonLogSum(&curr_states, arc_weights_a, | ||
&curr_arc_deriv); | ||
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end()); | ||
// push derivs info of current arc | ||
curr_arc_deriv.emplace_back(arc_index, 1); | ||
arc_derivs->emplace_back(std::move(curr_arc_deriv)); | ||
++arc_num_b; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
// duplicate of final state | ||
b->arc_indexes.push_back(b->arc_indexes.back()); | ||
} | ||
|
||
bool Intersect(const Fsa &a, const Fsa &b, Fsa *c, | ||
std::vector<int32_t> *arc_map_a /*= nullptr*/, | ||
std::vector<int32_t> *arc_map_b /*= nullptr*/) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Compare this with how carefully I document functions like this in determinize.cc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, will add in another PR.