Skip to content

Commit 1a43d1e

Browse files
authored
Support getting word IDs for CTC HLG decoding. (#978)
1 parent 69347ff commit 1a43d1e

13 files changed

+60
-13
lines changed

sherpa-onnx/csrc/offline-ctc-decoder.h

+8
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,16 @@ struct OfflineCtcDecoderResult {
1515
/// The decoded token IDs
1616
std::vector<int64_t> tokens;
1717

18+
/// The decoded word IDs
19+
/// Note: tokens.size() is usually not equal to words.size()
20+
/// words is empty for greedy search decoding.
21+
/// it is not empty when an HLG graph or an HLG graph is used.
22+
std::vector<int32_t> words;
23+
1824
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
1925
/// Note: The index is after subsampling
26+
///
27+
/// tokens.size() == timestamps.size()
2028
std::vector<int32_t> timestamps;
2129
};
2230

sherpa-onnx/csrc/offline-ctc-fst-decoder.cc

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
108108
// -1 here since the input labels are incremented during graph
109109
// construction
110110
r.tokens.push_back(arc.ilabel - 1);
111+
if (arc.olabel != 0) {
112+
r.words.push_back(arc.olabel);
113+
}
111114

112115
r.timestamps.push_back(t);
113116
prev = arc.ilabel;

sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc

-4
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ OfflineParaformerGreedySearchDecoder::Decode(
6464

6565
if (timestamps.size() == results[i].tokens.size()) {
6666
results[i].timestamps = std::move(timestamps);
67-
} else {
68-
SHERPA_ONNX_LOGE("time stamp for batch: %d, %d vs %d", i,
69-
static_cast<int32_t>(results[i].tokens.size()),
70-
static_cast<int32_t>(timestamps.size()));
7167
}
7268
}
7369
}

sherpa-onnx/csrc/offline-recognizer-ctc-impl.h

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
6565
r.timestamps.push_back(time);
6666
}
6767

68+
r.words = std::move(src.words);
69+
6870
return r;
6971
}
7072

sherpa-onnx/csrc/offline-stream.cc

+14
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,20 @@ std::string OfflineRecognitionResult::AsJsonString() const {
339339
}
340340
sep = ", ";
341341
}
342+
os << "], ";
343+
344+
sep = "";
345+
346+
os << "\""
347+
<< "words"
348+
<< "\""
349+
<< ": ";
350+
os << "[";
351+
for (int32_t w : words) {
352+
os << sep << w;
353+
sep = ", ";
354+
}
355+
342356
os << "]";
343357
os << "}";
344358

sherpa-onnx/csrc/offline-stream.h

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct OfflineRecognitionResult {
3030
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
3131
std::vector<float> timestamps;
3232

33+
std::vector<int32_t> words;
34+
3335
std::string AsJsonString() const;
3436
};
3537

sherpa-onnx/csrc/online-ctc-decoder.h

+8
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,16 @@ struct OnlineCtcDecoderResult {
2222
/// The decoded token IDs
2323
std::vector<int64_t> tokens;
2424

25+
/// The decoded word IDs
26+
/// Note: tokens.size() is usually not equal to words.size()
27+
/// words is empty for greedy search decoding.
28+
/// it is not empty when an HLG graph or an HLG graph is used.
29+
std::vector<int32_t> words;
30+
2531
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
2632
/// Note: The index is after subsampling
33+
///
34+
/// tokens.size() == timestamps.size()
2735
std::vector<int32_t> timestamps;
2836

2937
int32_t num_trailing_blanks = 0;

sherpa-onnx/csrc/online-ctc-fst-decoder.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
5151
bool ok = decoder->GetBestPath(&fst_out);
5252
if (ok) {
5353
std::vector<int32_t> isymbols_out;
54-
std::vector<int32_t> osymbols_out_unused;
55-
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
56-
&osymbols_out_unused, nullptr);
54+
std::vector<int32_t> osymbols_out;
55+
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, &osymbols_out,
56+
nullptr);
5757
std::vector<int64_t> tokens;
5858
tokens.reserve(isymbols_out.size());
5959

@@ -83,6 +83,7 @@ static void DecodeOne(const float *log_probs, int32_t num_rows,
8383
}
8484

8585
result->tokens = std::move(tokens);
86+
result->words = std::move(osymbols_out);
8687
result->timestamps = std::move(timestamps);
8788
// no need to set frame_offset
8889
}

sherpa-onnx/csrc/online-recognizer-ctc-impl.h

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
5959
}
6060

6161
r.segment = segment;
62+
r.words = std::move(src.words);
6263
r.start_time = frames_since_start * frame_shift_ms / 1000.;
6364

6465
return r;

sherpa-onnx/csrc/online-recognizer.cc

+11-6
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ namespace sherpa_onnx {
2222
template <typename T>
2323
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
2424
std::ostringstream oss;
25-
oss << std::fixed << std::setprecision(precision);
26-
oss << "[ ";
25+
if (precision != 0) {
26+
oss << std::fixed << std::setprecision(precision);
27+
}
28+
oss << "[";
2729
std::string sep = "";
2830
for (const auto &item : vec) {
2931
oss << sep << item;
3032
sep = ", ";
3133
}
32-
oss << " ]";
34+
oss << "]";
3335
return oss.str();
3436
}
3537

@@ -38,26 +40,29 @@ template <> // explicit specialization for T = std::string
3840
std::string VecToString<std::string>(const std::vector<std::string> &vec,
3941
int32_t) { // ignore 2nd arg
4042
std::ostringstream oss;
41-
oss << "[ ";
43+
oss << "[";
4244
std::string sep = "";
4345
for (const auto &item : vec) {
4446
oss << sep << "\"" << item << "\"";
4547
sep = ", ";
4648
}
47-
oss << " ]";
49+
oss << "]";
4850
return oss.str();
4951
}
5052

5153
std::string OnlineRecognizerResult::AsJsonString() const {
5254
std::ostringstream os;
5355
os << "{ ";
54-
os << "\"text\": " << "\"" << text << "\"" << ", ";
56+
os << "\"text\": "
57+
<< "\"" << text << "\""
58+
<< ", ";
5559
os << "\"tokens\": " << VecToString(tokens) << ", ";
5660
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
5761
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
5862
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
5963
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
6064
os << "\"segment\": " << segment << ", ";
65+
os << "\"words\": " << VecToString(words, 0) << ", ";
6166
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
6267
<< ", ";
6368
os << "\"is_final\": " << (is_final ? "true" : "false");

sherpa-onnx/csrc/online-recognizer.h

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ struct OnlineRecognizerResult {
4747
/// log-domain scores from "hot-phrase" contextual boosting
4848
std::vector<float> context_scores;
4949

50+
std::vector<int32_t> words;
51+
5052
/// ID of this segment
5153
/// When an endpoint is detected, it is incremented
5254
int32_t segment = 0;

sherpa-onnx/python/csrc/offline-stream.cc

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
3434
})
3535
.def_property_readonly("tokens",
3636
[](const PyClass &self) { return self.tokens; })
37+
.def_property_readonly("words",
38+
[](const PyClass &self) { return self.words; })
3739
.def_property_readonly(
3840
"timestamps", [](const PyClass &self) { return self.timestamps; });
3941
}

sherpa-onnx/python/csrc/online-recognizer.cc

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
4040
})
4141
.def_property_readonly(
4242
"segment", [](PyClass &self) -> int32_t { return self.segment; })
43+
.def_property_readonly(
44+
"words",
45+
[](PyClass &self) -> std::vector<int32_t> { return self.words; })
4346
.def_property_readonly(
4447
"is_final", [](PyClass &self) -> bool { return self.is_final; })
4548
.def("__str__", &PyClass::AsJsonString,

0 commit comments

Comments
 (0)