From ff6a419a01931d926166216c271862f4c4b88603 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 25 Apr 2020 18:45:51 +0800 Subject: [PATCH] fix GetEnteringArcs. --- k2/csrc/fsa_renderer.cc | 2 +- k2/csrc/fsa_util.cc | 35 ++++++++++------------ k2/csrc/fsa_util.h | 22 +++++++------- k2/csrc/fsa_util_test.cc | 64 ++++++++++++++++------------------------ 4 files changed, 53 insertions(+), 70 deletions(-) diff --git a/k2/csrc/fsa_renderer.cc b/k2/csrc/fsa_renderer.cc index 5edc41b77..3ba9b52cd 100644 --- a/k2/csrc/fsa_renderer.cc +++ b/k2/csrc/fsa_renderer.cc @@ -36,7 +36,7 @@ using k2::StateId; std::string ProcessState(const Fsa &fsa, int32_t state) { std::ostringstream os; os << " " << state << " [label = \"" << state - << "\", shape = circle, style = bold, fontsize=14]" + << "\", shape = circle, style = bold, fontsize = 14]" << "\n"; int32_t begin = fsa.leaving_arcs[state].begin; diff --git a/k2/csrc/fsa_util.cc b/k2/csrc/fsa_util.cc index 0caa327dd..6a851f9b5 100644 --- a/k2/csrc/fsa_util.cc +++ b/k2/csrc/fsa_util.cc @@ -11,34 +11,29 @@ namespace k2 { -void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs) { +void GetEnteringArcs(const Fsa &fsa, std::vector *arc_index, + std::vector *end_index) { // CHECK(CheckProperties(fsa, KTopSorted)); int32_t num_states = fsa.NumStates(); - std::vector>> vec(num_states); + std::vector> vec(num_states); int32_t num_arcs = 0; + int32_t k = 0; for (const auto &arc : fsa.arcs) { - auto src_state = arc.src_state; auto dest_state = arc.dest_state; - auto label = arc.label; - vec[dest_state].emplace_back(label, src_state); + vec[dest_state].push_back(k); ++num_arcs; + ++k; } - - auto &ranges = entering_arcs->ranges; - auto &values = entering_arcs->values; - ranges.clear(); - values.clear(); - ranges.reserve(num_states); - values.reserve(num_arcs); - - int32_t start = 0; - int32_t end = 0; - for (const auto &label_state : vec) { - values.insert(values.end(), label_state.begin(), label_state.end()); - start = end; - end += static_cast(label_state.size()); - ranges.push_back({start, end}); + arc_index->clear(); + end_index->clear(); + arc_index->reserve(num_arcs); + end_index->reserve(num_states); + + for (const auto &indices : vec) { + arc_index->insert(arc_index->end(), indices.begin(), indices.end()); + auto end = static_cast(arc_index->size()); + end_index->push_back(end); } } diff --git a/k2/csrc/fsa_util.h b/k2/csrc/fsa_util.h index 98d8a0bc9..879535e8b 100644 --- a/k2/csrc/fsa_util.h +++ b/k2/csrc/fsa_util.h @@ -4,11 +4,13 @@ // See ../../LICENSE for clarification regarding multiple authors -#include "k2/csrc/fsa.h" - #ifndef K2_CSRC_FSA_UTIL_H_ #define K2_CSRC_FSA_UTIL_H_ +#include + +#include "k2/csrc/fsa.h" + namespace k2 { /* @@ -20,19 +22,19 @@ namespace k2 { @param [out] arc_index A list of arc indexes. For states 0 < s < fsa.NumStates(), - the elements arc_index[i] for end_index[s-1] <= i < end_index[s] - contain the arc-indexes in fsa.arcs for arcs that - enter state s. + the elements arc_index[i] for + end_index[s-1] <= i < end_index[s] contain the + arc-indexes in fsa.arcs for arcs that enter + state s. @param [out] end_index For each state, the `end` index in `arc_index` where we can find arcs entering this state, i.e. - one past the index of the last element in `arc_index` - that points to an arc entering this state. + one past the index of the last element in + `arc_index` that points to an arc entering + this state. */ -void GetEnteringArcs(const Fsa &fsa, - std::vector *arc_index, +void GetEnteringArcs(const Fsa &fsa, std::vector *arc_index, std::vector *end_index); - } // namespace k2 #endif // K2_CSRC_FSA_UTIL_H_ diff --git a/k2/csrc/fsa_util_test.cc b/k2/csrc/fsa_util_test.cc index d5075fbe7..6eb14455c 100644 --- a/k2/csrc/fsa_util_test.cc +++ b/k2/csrc/fsa_util_test.cc @@ -15,7 +15,11 @@ namespace k2 { TEST(FsaUtil, GetEnteringArcs) { std::vector arcs = { - {0, 1, 2}, {0, 2, 1}, {1, 2, 0}, {1, 3, 5}, {2, 3, 6}, + {0, 1, 2}, // 0 + {0, 2, 1}, // 1 + {1, 2, 0}, // 2 + {1, 3, 5}, // 3 + {2, 3, 6}, // 4 }; std::vector leaving_arcs = { {0, 2}, {2, 4}, {4, 5}, {0, 0}, // the last state has no leaving arcs @@ -25,44 +29,26 @@ TEST(FsaUtil, GetEnteringArcs) { fsa.leaving_arcs = std::move(leaving_arcs); fsa.arcs = std::move(arcs); - VecOfVec entering_arcs; - GetEnteringArcs(fsa, &entering_arcs); - - const auto &ranges = entering_arcs.ranges; - const auto &values = entering_arcs.values; - EXPECT_EQ(ranges.size(), 4u); // there are 4 states - EXPECT_EQ(values.size(), 5u); // there are 5 arcs - - // state 0, no entering arcs - EXPECT_EQ(ranges[0].begin, ranges[0].end); - - // state 1 has one entering arc from state 0 with label 2 - EXPECT_EQ(ranges[1].begin, 0); - EXPECT_EQ(ranges[1].end, 1); - EXPECT_EQ(values[0].first, 2); // label is 2 - EXPECT_EQ(values[0].second, 0); // state is 0 - - // state 2 has two entering arcs - // the first one: from state 0 with label 1 - // the second one: from state 1 with label 0 - EXPECT_EQ(ranges[2].begin, 1); - EXPECT_EQ(ranges[2].end, 3); - EXPECT_EQ(values[1].first, 1); // label is 1 - EXPECT_EQ(values[1].second, 0); // state is 0 - - EXPECT_EQ(values[2].first, 0); // label is 0 - EXPECT_EQ(values[2].second, 1); // state is 1 - - // state 3 has two entering arcs - // the first one: from state 1 with label 5 - // the second one: from state 2 with label 6 - EXPECT_EQ(ranges[3].begin, 3); - EXPECT_EQ(ranges[3].end, 5); - EXPECT_EQ(values[3].first, 5); // label is 5 - EXPECT_EQ(values[3].second, 1); // state is 1 - - EXPECT_EQ(values[4].first, 6); // label is 6 - EXPECT_EQ(values[4].second, 2); // state is 2 + std::vector arc_index(10); // an arbitray number + std::vector end_index(20); + + GetEnteringArcs(fsa, &arc_index, &end_index); + + EXPECT_EQ(end_index.size(), 4u); // there are 4 states + EXPECT_EQ(arc_index.size(), 5u); // there are 5 arcs + + EXPECT_EQ(end_index[0], 0); // state 0 has no entering arcs + + EXPECT_EQ(end_index[1], 1); // state 1 has one entering arc + EXPECT_EQ(arc_index[0], 0); // arc index 0 from state 0 + + EXPECT_EQ(end_index[2], 3); // state 2 has two entering arcs + EXPECT_EQ(arc_index[1], 1); // arc index 1 from state 0 + EXPECT_EQ(arc_index[2], 2); // arc index 2 from state 1 + + EXPECT_EQ(end_index[3], 5); // state 3 has two entering arcs + EXPECT_EQ(arc_index[3], 3); // arc index 3 from state 1 + EXPECT_EQ(arc_index[4], 4); // arc index 4 from state 2 } } // namespace k2