Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix GetEnteringArcs. #9

Merged
merged 1 commit into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion k2/csrc/fsa_renderer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using k2::StateId;
std::string ProcessState(const Fsa &fsa, int32_t state) {
std::ostringstream os;
os << " " << state << " [label = \"" << state
<< "\", shape = circle, style = bold, fontsize=14]"
<< "\", shape = circle, style = bold, fontsize = 14]"
<< "\n";

int32_t begin = fsa.leaving_arcs[state].begin;
Expand Down
35 changes: 15 additions & 20 deletions k2/csrc/fsa_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,29 @@

namespace k2 {

void GetEnteringArcs(const Fsa &fsa, VecOfVec *entering_arcs) {
void GetEnteringArcs(const Fsa &fsa, std::vector<int32_t> *arc_index,
std::vector<int32_t> *end_index) {
// CHECK(CheckProperties(fsa, KTopSorted));

int32_t num_states = fsa.NumStates();
std::vector<std::vector<std::pair<Label, StateId>>> vec(num_states);
std::vector<std::vector<int32_t>> vec(num_states);
int32_t num_arcs = 0;
int32_t k = 0;
for (const auto &arc : fsa.arcs) {
auto src_state = arc.src_state;
auto dest_state = arc.dest_state;
auto label = arc.label;
vec[dest_state].emplace_back(label, src_state);
vec[dest_state].push_back(k);
++num_arcs;
++k;
}

auto &ranges = entering_arcs->ranges;
auto &values = entering_arcs->values;
ranges.clear();
values.clear();
ranges.reserve(num_states);
values.reserve(num_arcs);

int32_t start = 0;
int32_t end = 0;
for (const auto &label_state : vec) {
values.insert(values.end(), label_state.begin(), label_state.end());
start = end;
end += static_cast<int32_t>(label_state.size());
ranges.push_back({start, end});
arc_index->clear();
end_index->clear();
arc_index->reserve(num_arcs);
end_index->reserve(num_states);

for (const auto &indices : vec) {
arc_index->insert(arc_index->end(), indices.begin(), indices.end());
auto end = static_cast<int32_t>(arc_index->size());
end_index->push_back(end);
}
}

Expand Down
22 changes: 12 additions & 10 deletions k2/csrc/fsa_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa.h"

#ifndef K2_CSRC_FSA_UTIL_H_
#define K2_CSRC_FSA_UTIL_H_

#include <vector>

#include "k2/csrc/fsa.h"

namespace k2 {

/*
Expand All @@ -20,19 +22,19 @@ namespace k2 {

@param [out] arc_index A list of arc indexes.
For states 0 < s < fsa.NumStates(),
the elements arc_index[i] for end_index[s-1] <= i < end_index[s]
contain the arc-indexes in fsa.arcs for arcs that
enter state s.
the elements arc_index[i] for
end_index[s-1] <= i < end_index[s] contain the
arc-indexes in fsa.arcs for arcs that enter
state s.
@param [out] end_index For each state, the `end` index in `arc_index`
where we can find arcs entering this state, i.e.
one past the index of the last element in `arc_index`
that points to an arc entering this state.
one past the index of the last element in
`arc_index` that points to an arc entering
this state.
*/
void GetEnteringArcs(const Fsa &fsa,
std::vector<int32_t> *arc_index,
void GetEnteringArcs(const Fsa &fsa, std::vector<int32_t> *arc_index,
std::vector<int32_t> *end_index);


} // namespace k2

#endif // K2_CSRC_FSA_UTIL_H_
64 changes: 25 additions & 39 deletions k2/csrc/fsa_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ namespace k2 {

TEST(FsaUtil, GetEnteringArcs) {
std::vector<Arc> arcs = {
{0, 1, 2}, {0, 2, 1}, {1, 2, 0}, {1, 3, 5}, {2, 3, 6},
{0, 1, 2}, // 0
{0, 2, 1}, // 1
{1, 2, 0}, // 2
{1, 3, 5}, // 3
{2, 3, 6}, // 4
};
std::vector<Range> leaving_arcs = {
{0, 2}, {2, 4}, {4, 5}, {0, 0}, // the last state has no leaving arcs
Expand All @@ -25,44 +29,26 @@ TEST(FsaUtil, GetEnteringArcs) {
fsa.leaving_arcs = std::move(leaving_arcs);
fsa.arcs = std::move(arcs);

VecOfVec entering_arcs;
GetEnteringArcs(fsa, &entering_arcs);

const auto &ranges = entering_arcs.ranges;
const auto &values = entering_arcs.values;
EXPECT_EQ(ranges.size(), 4u); // there are 4 states
EXPECT_EQ(values.size(), 5u); // there are 5 arcs

// state 0, no entering arcs
EXPECT_EQ(ranges[0].begin, ranges[0].end);

// state 1 has one entering arc from state 0 with label 2
EXPECT_EQ(ranges[1].begin, 0);
EXPECT_EQ(ranges[1].end, 1);
EXPECT_EQ(values[0].first, 2); // label is 2
EXPECT_EQ(values[0].second, 0); // state is 0

// state 2 has two entering arcs
// the first one: from state 0 with label 1
// the second one: from state 1 with label 0
EXPECT_EQ(ranges[2].begin, 1);
EXPECT_EQ(ranges[2].end, 3);
EXPECT_EQ(values[1].first, 1); // label is 1
EXPECT_EQ(values[1].second, 0); // state is 0

EXPECT_EQ(values[2].first, 0); // label is 0
EXPECT_EQ(values[2].second, 1); // state is 1

// state 3 has two entering arcs
// the first one: from state 1 with label 5
// the second one: from state 2 with label 6
EXPECT_EQ(ranges[3].begin, 3);
EXPECT_EQ(ranges[3].end, 5);
EXPECT_EQ(values[3].first, 5); // label is 5
EXPECT_EQ(values[3].second, 1); // state is 1

EXPECT_EQ(values[4].first, 6); // label is 6
EXPECT_EQ(values[4].second, 2); // state is 2
std::vector<int32_t> arc_index(10); // an arbitray number
std::vector<int32_t> end_index(20);

GetEnteringArcs(fsa, &arc_index, &end_index);

EXPECT_EQ(end_index.size(), 4u); // there are 4 states
EXPECT_EQ(arc_index.size(), 5u); // there are 5 arcs

EXPECT_EQ(end_index[0], 0); // state 0 has no entering arcs

EXPECT_EQ(end_index[1], 1); // state 1 has one entering arc
EXPECT_EQ(arc_index[0], 0); // arc index 0 from state 0

EXPECT_EQ(end_index[2], 3); // state 2 has two entering arcs
EXPECT_EQ(arc_index[1], 1); // arc index 1 from state 0
EXPECT_EQ(arc_index[2], 2); // arc index 2 from state 1

EXPECT_EQ(end_index[3], 5); // state 3 has two entering arcs
EXPECT_EQ(arc_index[3], 3); // arc index 3 from state 1
EXPECT_EQ(arc_index[4], 4); // arc index 4 from state 2
}

} // namespace k2