Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add beam in RandEquivalent #46

Merged
merged 2 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions k2/csrc/fsa_algo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "k2/csrc/fsa_equivalent.h"
#include "k2/csrc/fsa_renderer.h"
#include "k2/csrc/fsa_util.h"
#include "k2/csrc/properties.h"
Expand Down Expand Up @@ -593,22 +594,23 @@ TEST_F(DeterminizeTest, DeterminizePrunedMax) {
Fsa b;
std::vector<float> b_arc_weights;
std::vector<std::vector<int32_t>> arc_derivs;
DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
DeterminizePrunedMax(*max_wfsa_, 10.0, 100, &b, &b_arc_weights, &arc_derivs);

EXPECT_TRUE(IsDeterministic(b));

// TODO(haowen) as the type of `label_to_state` is `unordered_map` (instead of
// `map`), the output `state_id` and `arc_id` may differ under different STL
// implementations, we need to check the equivalence automatically
EXPECT_TRUE(IsRandEquivalent<kMaxWeight>(
max_wfsa_->fsa, max_wfsa_->arc_weights, b, b_arc_weights.data(), 10.0));
}

TEST_F(DeterminizeTest, DeterminizePrunedLogSum) {
Fsa b;
std::vector<float> b_arc_weights;
std::vector<std::vector<std::pair<int32_t, float>>> arc_derivs;
DeterminizePrunedLogSum(*log_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
DeterminizePrunedLogSum(*log_wfsa_, 10.0, 100, &b, &b_arc_weights,
&arc_derivs);

EXPECT_TRUE(IsDeterministic(b));
EXPECT_TRUE(IsRandEquivalent<kLogSumWeight>(
log_wfsa_->fsa, log_wfsa_->arc_weights, b, b_arc_weights.data(), 10.0));

// TODO(haowen): how to check `arc_derivs_out` here, may return `num_steps` to
// check the sum of `derivs_out` for each output arc?
Expand Down
37 changes: 31 additions & 6 deletions k2/csrc/fsa_equivalent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,12 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath /*=100*/) {

template <FbWeightType Type>
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
const float *b_weights, bool top_sorted /*=true*/,
const float *b_weights, float beam /*=kFloatInfinity*/,
float delta /*=1e-6*/, bool top_sorted /*=true*/,
std::size_t npath /*= 100*/) {
CHECK_GT(beam, 0);
CHECK_NOTNULL(a_weights);
CHECK_NOTNULL(b_weights);
Fsa connected_a, connected_b, valid_a, valid_b;
std::vector<int32_t> connected_a_arc_map, connected_b_arc_map,
valid_a_arc_map, valid_b_arc_map;
Expand Down Expand Up @@ -199,10 +203,25 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
(*(labels_difference.begin())) != kEpsilon))
return false;

double loglike_cutoff_a, loglike_cutoff_b;
if (beam != kFloatInfinity) {
loglike_cutoff_a =
ShortestDistance<Type>(valid_a, valid_a_weights.data()) - beam;
loglike_cutoff_b =
ShortestDistance<Type>(valid_b, valid_b_weights.data()) - beam;
if (Type == kMaxWeight &&
!DoubleApproxEqual(loglike_cutoff_a, loglike_cutoff_b))
return false;
} else {
loglike_cutoff_a = kDoubleNegativeInfinity;
loglike_cutoff_b = kDoubleNegativeInfinity;
}

std::random_device rd;
std::mt19937 gen(rd());
std::bernoulli_distribution coin(0.5);
for (auto i = 0; i != npath; ++i) {
std::size_t n = 0;
while (n < npath) {
const auto &fsa = coin(gen) ? valid_a : valid_b;
Fsa path, valid_path;
RandomPathWithoutEpsilonArc(fsa, &path); // path is already connected
Expand All @@ -220,22 +239,28 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
// find out that we don't need that version, we will remove flag
// `top_sorted` and add requirements as comments in the header file.
CHECK(top_sorted);
double sum_a =
double cost_a =
ShortestDistance<Type>(a_compose_path, a_compose_weights.data());
double sum_b =
double cost_b =
ShortestDistance<Type>(b_compose_path, b_compose_weights.data());
if (!DoubleApproxEqual(sum_a, sum_b)) return false;
if (cost_a < loglike_cutoff_a && cost_b < loglike_cutoff_b) {
continue;
} else {
if (!DoubleApproxEqual(cost_a, cost_b, delta)) return false;
++n;
}
}
return true;
}

// explicit instantiation here
template bool IsRandEquivalent<kMaxWeight>(const Fsa &a, const float *a_weights,
const Fsa &b, const float *b_weights,
float beam, float delta,
bool top_sorted, std::size_t npath);
template bool IsRandEquivalent<kLogSumWeight>(
const Fsa &a, const float *a_weights, const Fsa &b, const float *b_weights,
bool top_sorted, std::size_t npath);
float beam, float delta, bool top_sorted, std::size_t npath);

bool RandomPath(const Fsa &a, Fsa *b,
std::vector<int32_t> *state_map /*=nullptr*/) {
Expand Down
33 changes: 24 additions & 9 deletions k2/csrc/fsa_equivalent.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,32 @@ bool IsRandEquivalent(const Fsa &a, const Fsa &b, std::size_t npath = 100);
@param [in] a_weights Arc weights of `a`
@param [in] b The other FSA to be checked the equivalence
@param [in] b_weights Arc weights of `b`
@param [in] top_sorted If both `a` and `b` are topological sorted or not.
We may remove this flag if we finally find out that
input FSAs in all scenarios are top-sorted.
@param [in] beam beam > 0 that affects pruning; the algorithm
will only check paths within `beam` of the
best path(for tropical semiring, it's max
weight over all paths from start state to
final state; for log semiring, it's log-sum probs
over all paths) in `a` or `b`. That is,
any symbol sequence, whose total weights
over all paths are within `beam` of the best
path (either in `a` or `b`), must have
the same weights in `a` and `b`.
There is no any requirement on symbol sequences
whose total weights over paths are outside `beam`.
Just keep `kFloatInfinity` if you don't want pruning.
@param [in] delta Tolerance for path weights to check the equivalence.
If abs(weights_a, weights_b) <= delta, we say the two
paths are equivalent.
@param [in] top_sorted The user may set this to true if both `a` and `b` are
topologically sorted; this makes this function faster.
Otherwise it must be set to false.
@param [in] npath The number of paths will be generated to check the
equivalence of `a` and `b`
*/
template <FbWeightType Type>
bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
const float *b_weights, bool top_sorted = true,
const float *b_weights, float beam = kFloatInfinity,
float delta = 1e-6, bool top_sorted = true,
std::size_t npath = 100);

/*
Expand All @@ -61,18 +78,16 @@ bool RandomPath(const Fsa &a, Fsa *b,
bool RandomPathWithoutEpsilonArc(const Fsa &a, Fsa *b,
std::vector<int32_t> *state_map = nullptr);
/*
Computes the intersection of two FSAs where one FSA has weights on arc. This
function will be called in the version of `IsRandEquivalent` for Wfsa.
Computes the intersection of two FSAs where one FSA has weights on arc.

@param [in] a One of the FSAs to be intersected. Must satisfy
ArcSorted(a)
@param [in] a_weights Arc weights of `a`
@param [in] b The other FSA to be intersected Must satisfy
ArcSorted(b) and IsEpsilonFree(b). It is usually a path
generated from `RandomNonEpsilonPath`
generated from `RandomNonEpsilonPath`
@param [out] c The composed FSA will be output to here.
@param [out] c_weights Arc weights of output FSA `c` which are corresponding
arc weights in `a`
@param [out] c_weights Arc weights of output FSA `c`.
@param [out] arc_map_a If non-NULL, at exit will be a vector of
size c->arcs.size(), saying for each arc in
`c` what the source arc in `a` was, `-1` represents
Expand Down
13 changes: 13 additions & 0 deletions k2/csrc/fsa_equivalent_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ TEST(FsaEquivalent, IsWfsaRandEquivalent) {
c_weights.data());
EXPECT_FALSE(status);
}

// check equivalence with beam
{
bool status = IsRandEquivalent<kMaxWeight>(a, a_weights.data(), b,
b_weights.data(), 4.0);
EXPECT_TRUE(status);
}
// check equivalence with beam
{
bool status = IsRandEquivalent<kMaxWeight>(a, a_weights.data(), c,
c_weights.data(), 6.0);
EXPECT_FALSE(status);
}
}

TEST(FsaEquivalent, RandomPathFail) {
Expand Down
8 changes: 2 additions & 6 deletions k2/csrc/properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ bool IsValid(const Fsa &fsa) {
if (arc.src_state == state) {
++num_arcs;
} else {
// every state contains at least one arc.
if (arc.src_state != state + 1) return false;
// `arc_indexes` and `arcs` in this state are not consistent.
if ((fsa.arc_indexes[state + 1] - fsa.arc_indexes[state]) != num_arcs)
return false;
Expand All @@ -55,7 +53,6 @@ bool IsValid(const Fsa &fsa) {
}
}
// check the last state
if (final_state != state + 1) return false;
if ((fsa.arc_indexes[final_state] - fsa.arc_indexes[state]) != num_arcs)
return false;
return true;
Expand Down Expand Up @@ -106,8 +103,7 @@ bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order /*= nullptr*/) {
if (current_state.arc_begin == current_state.arc_end) {
// we have finished visiting this state
state_status[current_state.state] = kVisited;
if (order != nullptr)
order->push_back(current_state.state);
if (order != nullptr) order->push_back(current_state.state);
stack.pop();
continue;
}
Expand All @@ -119,7 +115,7 @@ bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order /*= nullptr*/) {
// a new discovered node
state_status[next_state] = kVisiting;
stack.push({next_state, fsa.arc_indexes[next_state],
fsa.arc_indexes[next_state + 1]});
fsa.arc_indexes[next_state + 1]});
++current_state.arc_begin;
break;
}
Expand Down
5 changes: 2 additions & 3 deletions k2/csrc/properties.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ enum Properties {
`fsa` is valid if:
1. it is empty, if not, it contains at least two states.
2. only kFinalSymbol arcs enter the final state.
3. every state contains at least one arc except the final state.
4. `arcs_indexes` and `arcs` in this state are not consistent.
3. `arcs_indexes` and `arcs` in this state are not consistent.
TODO(haowen): add more rules?
*/
bool IsValid(const Fsa &fsa);
Expand All @@ -59,7 +58,7 @@ bool HasSelfLoops(const Fsa &fsa);
accessible (i.e. from the start state) are not considered.
The optional argument order, assigns the order in which visiting states is
finished in DFS traversal. State 0 has the largest order (num_states - 1) and
the final state has the smallest order (0).
the final state has the smallest order (0).
*/
bool IsAcyclic(const Fsa &fsa, std::vector<int32_t> *order = nullptr);

Expand Down
Loading