Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement RmEpsilonPruneMax #40

Merged
merged 3 commits into from
May 26, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add derivs_out
qindazhu committed May 20, 2020
commit 0983f9dca13c5817a18ce03e04438b9fb0fefdbe
2 changes: 2 additions & 0 deletions k2/csrc/determinize.h
Original file line number Diff line number Diff line change
@@ -609,6 +609,8 @@ int32_t DetState<TracebackState>::ProcessArcs(
derivs_per_arc->push_back(std::move(deriv_info));
if (is_new_state)
queue->push(std::unique_ptr<DetState<TracebackState>>(det_state));
else
delete det_state;
} else {
delete det_state;
}
105 changes: 105 additions & 0 deletions k2/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
@@ -275,6 +275,111 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
return is_acyclic;
}

void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<std::vector<int32_t>> *arc_derivs) {
CHECK_EQ(a.weight_type, kMaxWeight);
CHECK_GT(beam, 0);
CHECK_NOTNULL(b);
CHECK_NOTNULL(arc_derivs);
b->arc_indexes.clear();
b->arcs.clear();
arc_derivs->clear();

const auto &fsa = a.fsa;
if (IsEmpty(fsa)) return;
int32_t num_states_a = fsa.NumStates();
int32_t final_state = fsa.FinalState();
const auto &arcs_a = fsa.arcs;
const float *arc_weights_a = a.arc_weights;

// identify all states that should be kept
std::vector<char> non_eps_in(num_states_a, 0);
non_eps_in[0] = 1;
using WeightsPair = std::pair<double, std::vector<int32_t>>;
// (label, dest_state) -> `sum` of weights along all paths from current state
// to `dest_state` with label == `label`(or plus numbers of epsilon),
// `vector<int32_t>` in `WeightsPair` records all arc-indexes in `a` that
// contributes the current arc in `b`.
using ArcMap =
std::unordered_map<std::pair<int32_t, int32_t>, WeightsPair, PairHash>;
std::vector<ArcMap> arcs_b(num_states_a);
for (int32_t i = static_cast<int32_t>(arcs_a.size()) - 1; i >= 0; --i) {
const auto &arc = arcs_a[i];
const auto src_state = arc.src_state;
const auto dest_state = arc.dest_state;
const auto label = arc.label;
DCHECK_GE(dest_state, src_state);

double arc_weight = arc_weights_a[i];
if (label != kEpsilon) {
non_eps_in[dest_state] = 1;
WeightsPair weights_pair =
std::make_pair(arc_weight, std::vector<int32_t>{i});
auto insert_result = arcs_b[src_state].emplace(
std::make_pair(label, dest_state), weights_pair);
if (!insert_result.second) {
auto &old_weights_pair = insert_result.first->second;
// compare `arc_weights`
if (weights_pair.first > old_weights_pair.first) {
std::swap(old_weights_pair, weights_pair);
}
}
} else {
// remove epsilon arcs
for (const auto &item : arcs_b[dest_state]) {
auto weights_pair = item.second; // copy intended
// `times` the arc weights along path
weights_pair.first += arc_weight;
weights_pair.second.push_back(i);
auto insert_result =
arcs_b[src_state].emplace(item.first, weights_pair);
if (!insert_result.second) {
auto &old_weights_pair = insert_result.first->second;
// compare `arc_weights`
if (weights_pair.first > old_weights_pair.first) {
std::swap(old_weights_pair, weights_pair);
}
}
}
}
}

// remap state id
std::vector<int32_t> state_map_a2b(num_states_a, -1);
int32_t num_states_b = 0;
for (int32_t i = 0; i != num_states_a; ++i) {
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++;
}

// prune and output `b`
const double *forward_state_weights = a.ForwardStateWeights();
const double *backward_state_weights = a.BackwardStateWeights();
const double best_weight = forward_state_weights[final_state] - beam;
b->arc_indexes.reserve(num_states_b + 1);
int32_t arc_num_b = 0;
for (int32_t s = 0; s < num_states_a; ++s) {
if (non_eps_in[s] == 1) {
b->arc_indexes.push_back(arc_num_b);
for (const auto &arcs : arcs_b[s]) {
int32_t dest_state = arcs.first.second;
double weight = arcs.second.first;
if (forward_state_weights[s] + weight +
backward_state_weights[dest_state] >
best_weight) {
b->arcs.emplace_back(state_map_a2b[s], state_map_a2b[dest_state],
arcs.first.first);
auto curr_arc_deriv = std::move(arcs.second.second);
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end());
arc_derivs->emplace_back(curr_arc_deriv);
++arc_num_b;
}
}
}
}
// duplicate of final state
b->arc_indexes.push_back(b->arc_indexes.back());
}

bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
std::vector<int32_t> *arc_map_a /*= nullptr*/,
std::vector<int32_t> *arc_map_b /*= nullptr*/) {
68 changes: 5 additions & 63 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
@@ -91,71 +91,13 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);
Just make this very large if you don't want pruning.
@param [out] b The output FSA; will be epsilon-free, and the states
will be in the same order that they were in in `a`.
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
the arc-indexes in `a`, in order, that contributed
to that arc (e.g. its cost would be a sum of their costs).

Notes on algorithm (please rework all this when it's complete, i.e. just
make sure the code is clear and remove this).

The states in the output FSA will correspond to the subset of states in the
input FSA which are within `beam` of the best path and which have at least
one non-epsilon arc entering them, plus the start state. (Note: this
automatically includes the final state, assuming `a` has at least one
successful path; if it does not, the output will be empty).

If we ever need the associated state map from calling code, we'll add an
extra output argument to this function.

The basic algorithm is to (1) identify the kept states, (2) from each kept
input-state ki, we'll iterate over all states that are reachable via zero
or more epsilons from this state and process the non-epsilon outgoing arcs
from those states, which will become the arcs in the output. We'll also
store a back-pointer array that will allow us to figure out the best path
back to ki, in order to produce the output `arc_map`. Assume we have
arrays

local_forward_weights (float) and local_backpointers (int) indexed by
state-id, and that the local_forward_weights are initialized with
-infinity's each time we process a new ki. (we have to figure out how to do
this efficiently).


Processing input-state ki:
local_forward_state_weights[ki] = forward_state_weights[ki] // from
WfsaWithFbWeights.
// Caution:
we should probably use
// double
here; these kinds of algorithms
// are
extremely sensitive to roundoff for
// very
long FSAs. local_backpointers[ki] = -1 // will terminate a sequence..
queue.push_back(ki)
while (!queue.empty()) {
ji = queue.front() // we have to be a bit careful about order here,
to make sure
// we always process states when they already
have the
// best cost they are going to get. If
// FSA was top-sorted at the start, which we
assume, we could perhaps
// process them in numerical order, e.g. using a
heap. queue.pop_front() for each arc leaving state ji: next_weight =
local_forward_state_weights[ji] + arc_weights[this_arc_index] if next_weight
+ backward_state_weights[arc_dest_state] < best_path_weight - beam: if arc
label is epsilon: if next_weight < local_forward_state_weight[next_state]:
local_forward_state_weight[next_state] = next_weight
local_backpointers[next_state] = ji
else:
add an arc to the output FSA, and create the appropriate
arc_map entry by following backpointers (hopefully you can
figure out the details). Note: the output FSA's weights can
be computed later on, by calling code, using the info in arc_map.
@param [out] arc_derivs Indexed by arc in `b`, this is the sequence of
arcs in `a` that this arc in `b` corresponds to; the
weight of the arc in b will equal the sum of those input
arcs' weights
*/
void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<std::vector<int32_t>> *arc_map);
std::vector<std::vector<int32_t>> *arc_derivs);

/*
Version of RmEpsilonsPrunedMax that doesn't support pruning; see its
48 changes: 48 additions & 0 deletions k2/csrc/fsa_algo_test.cc
Original file line number Diff line number Diff line change
@@ -283,6 +283,54 @@ TEST(FsaAlgo, Connect) {
}
}

class RmEpsilonTest : public ::testing::Test {
protected:
RmEpsilonTest() {
std::vector<Arc> arcs = {
{0, 4, 1}, {0, 1, 1}, {1, 2, 0}, {1, 3, 0}, {1, 4, 0},
{2, 7, 0}, {3, 7, 0}, {4, 6, 1}, {4, 6, 0}, {4, 8, 1},
{4, 9, -1}, {5, 9, -1}, {6, 9, -1}, {7, 9, -1}, {8, 9, -1},
};
fsa_ = new Fsa(std::move(arcs), 9);
num_states_ = fsa_->NumStates();

auto num_arcs = fsa_->arcs.size();
arc_weights_ = new float[num_arcs];
std::vector<float> weights = {1, 1, 2, 3, 2, 4, 5, 2, 3, 3, 2, 4, 3, 5, 6};
std::copy_n(weights.begin(), num_arcs, arc_weights_);

max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight);
log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight);
}

~RmEpsilonTest() {
delete fsa_;
delete[] arc_weights_;
delete max_wfsa_;
delete log_wfsa_;
}

WfsaWithFbWeights *max_wfsa_;
WfsaWithFbWeights *log_wfsa_;
Fsa *fsa_;
int32_t num_states_;
float *arc_weights_;
};

TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) {
Fsa b;
std::vector<std::vector<int32_t>> arc_derivs_b;
RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b);

IsEpsilonFree(b);
ASSERT_EQ(b.arcs.size(), 10);
ASSERT_EQ(b.arc_indexes.size(), 7);
ASSERT_EQ(arc_derivs_b.size(), 10);

// TODO(haowen): check the equivalence after implementing RandEquivalent for
// WFSA
}

TEST(FsaAlgo, Intersect) {
// empty fsa
{