Skip to content

Commit 3d4f212

Browse files
committed
refactoring JSON export of OnlineRecognizerResult, extending pybind11 API of OnlineRecognizerResult
1 parent 92eff88 commit 3d4f212

4 files changed

+55
-49
lines changed

sherpa-onnx/csrc/online-recognizer.cc

+36-44
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,48 @@
1818

1919
namespace sherpa_onnx {
2020

21-
std::string OnlineRecognizerResult::AsJsonString() const {
22-
std::ostringstream os;
23-
os << "{";
24-
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
25-
os << "\"segment\":" << segment << ", ";
26-
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
27-
<< ", ";
28-
29-
os << "\"text\""
30-
<< ": ";
31-
os << "\"" << text << "\""
32-
<< ", ";
33-
34-
os << "\""
35-
<< "timestamps"
36-
<< "\""
37-
<< ": ";
38-
os << "[";
39-
21+
/// Helper for `OnlineRecognizerResult::AsJsonString()`.
///
/// Serializes `vec` as a JSON-style array of the form "[ a, b, c ]".
/// Floating-point items are printed in fixed notation with `precision`
/// digits after the decimal point; integral items are unaffected by it.
///
/// @param vec        Items to serialize; each is written with `operator<<`.
/// @param precision  Number of decimals used for floating-point items.
/// @return The serialized array. Returned by value: `oss.str()` yields a
///         temporary, so returning a reference to it would dangle.
template <typename T>
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
  std::ostringstream oss;
  oss << std::fixed << std::setprecision(precision);
  oss << "[ ";
  const char *sep = "";
  for (const auto &item : vec) {
    oss << sep << item;
    sep = ", ";
  }
  oss << " ]";
  return oss.str();
}

/// Helper for `OnlineRecognizerResult::AsJsonString()`.
///
/// Explicit specialization for T = std::string: every item is emitted as a
/// double-quoted string. `std::quoted` additionally backslash-escapes any
/// embedded `"` or `\` so such items cannot terminate the JSON string early.
/// The precision argument is meaningless for strings and is ignored.
template <>  // explicit specialization for T = std::string
std::string VecToString<std::string>(const std::vector<std::string> &vec,
                                     int32_t /*precision*/) {
  std::ostringstream oss;
  oss << "[ ";
  const char *sep = "";
  for (const auto &item : vec) {
    oss << sep << std::quoted(item);
    sep = ", ";
  }
  oss << " ]";
  return oss.str();
}
7049

50+
/// Serializes this result as a single-line JSON object.
///
/// Keys are written in this order: "text", "tokens", "timestamps",
/// "ys_probs", "lm_probs", "context_scores", "segment", "start_time",
/// "is_final". Timestamps and start_time use 2 decimals, the score vectors
/// use 6.
///
/// @return The JSON string.
std::string OnlineRecognizerResult::AsJsonString() const {
  std::ostringstream os;
  os << "{ ";
  // std::quoted adds the surrounding quotes and backslash-escapes any
  // embedded `"` or `\`, so such text cannot break the JSON string.
  os << "\"text\": " << std::quoted(text) << ", ";
  os << "\"tokens\": " << VecToString(tokens) << ", ";
  os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
  os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
  os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
  // Key spelled "context_scores" to match the member name and the documented
  // JSON layout (it was previously misspelled "constext_scores").
  os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
  os << "\"segment\": " << segment << ", ";
  os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
     << ", ";
  os << "\"is_final\": " << (is_final ? "true" : "false");
  os << "}";
  return os.str();
}
7365

sherpa-onnx/csrc/online-recognizer.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ struct OnlineRecognizerResult {
4040
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
4141
std::vector<float> timestamps;
4242

43-
std::vector<float> ys_probs;
44-
std::vector<float> lm_probs;
45-
std::vector<float> context_scores;
43+
std::vector<float> ys_probs; //< log-prob scores from ASR model
44+
std::vector<float> lm_probs; //< log-prob scores from language model
45+
std::vector<float> context_scores; //< log-domain scores from "hot-phrase" contextual boosting
4646

4747
/// ID of this segment
4848
/// When an endpoint is detected, it is incremented
@@ -62,6 +62,9 @@ struct OnlineRecognizerResult {
6262
* "text": "The recognition result",
6363
* "tokens": [x, x, x],
6464
* "timestamps": [x, x, x],
65+
* "ys_probs": [x, x, x],
66+
* "lm_probs": [x, x, x],
67+
* "context_scores": [x, x, x],
6568
* "segment": x,
6669
* "start_time": x,
6770
* "is_final": true|false

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
198198
new_hyp.ys_probs.push_back(y_prob);
199199

200200
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
201-
new_hyp.lm_probs.push_back(lm_probs);
201+
new_hyp.lm_probs.push_back(lm_prob);
202202

203203
new_hyp.context_scores.push_back(context_score);
204204
}

sherpa-onnx/python/csrc/online-recognizer.cc

+12-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,18 @@ static void PybindOnlineRecognizerResult(py::module *m) {
3434
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
3535
.def_property_readonly(
3636
"context_scores",
37-
[](PyClass &self) -> std::vector<float> { return self.context_scores; });
37+
[](PyClass &self) -> std::vector<float> { return self.context_scores; })
38+
.def_property_readonly(
39+
"segment",
40+
[](PyClass &self) -> int32_t { return self.segment; })
41+
.def_property_readonly(
42+
"start_time",
43+
[](PyClass &self) -> float { return self.start_time; })
44+
.def_property_readonly(
45+
"is_final",
46+
[](PyClass &self) -> bool { return self.is_final; })
47+
.def("as_json_string", &PyClass::AsJsonString,
48+
py::call_guard<py::gil_scoped_release>());
3849
}
3950

4051
static void PybindOnlineRecognizerConfig(py::module *m) {

0 commit comments

Comments
 (0)