Skip to content

Commit 53d6a4d

Browse files
Track token scores (k2-fsa#571)
* add export of per-token scores (ys, lm, context) - for best path of the modified-beam-search decoding of transducer * refactoring JSON export of OnlineRecognitionResult, extending pybind11 API of OnlineRecognitionResult * export per-token scores also for greedy-search (online-transducer) - export un-scaled lm_probs (modified-beam search, online-transducer) - polishing * fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
1 parent 8289c0f commit 53d6a4d

11 files changed

+152
-46
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
build
22
*.zip
33
*.tgz
4+
*.sw?
45
onnxruntime-*
56
icefall-*
67
run.sh

sherpa-onnx/csrc/hypothesis.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,21 @@ struct Hypothesis {
2929
std::vector<int32_t> timestamps;
3030

3131
// The acoustic probability for each token in ys.
32-
// Only used for keyword spotting task.
32+
// Used for keyword spotting task.
33+
// For transducer modified beam-search and greedy-search,
34+
// this is filled with log_posterior scores.
3335
std::vector<float> ys_probs;
3436

37+
// lm_probs[i] contains the lm score for each token in ys.
38+
// Used only in transducer modified beam-search.
39+
// Elements filled only if LM is used.
40+
std::vector<float> lm_probs;
41+
42+
// context_scores[i] contains the context-graph score for each token in ys.
43+
// Used only in transducer modified beam-search.
44+
// Elements filled only if `ContextGraph` is used.
45+
std::vector<float> context_scores;
46+
3547
// The total score of ys in log space.
3648
// It contains only acoustic scores
3749
double log_prob = 0;

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
6969
r.timestamps.push_back(time);
7070
}
7171

72+
r.ys_probs = std::move(src.ys_probs);
73+
r.lm_probs = std::move(src.lm_probs);
74+
r.context_scores = std::move(src.context_scores);
75+
7276
r.segment = segment;
7377
r.start_time = frames_since_start * frame_shift_ms / 1000.;
7478

sherpa-onnx/csrc/online-recognizer.cc

+38-44
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,50 @@
1818

1919
namespace sherpa_onnx {
2020

21-
std::string OnlineRecognizerResult::AsJsonString() const {
22-
std::ostringstream os;
23-
os << "{";
24-
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
25-
os << "\"segment\":" << segment << ", ";
26-
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
27-
<< ", ";
28-
29-
os << "\"text\""
30-
<< ": ";
31-
os << "\"" << text << "\""
32-
<< ", ";
33-
34-
os << "\""
35-
<< "timestamps"
36-
<< "\""
37-
<< ": ";
38-
os << "[";
39-
21+
/// Helper for `OnlineRecognizerResult::AsJsonString()`
22+
template<typename T>
23+
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
24+
std::ostringstream oss;
25+
oss << std::fixed << std::setprecision(precision);
26+
oss << "[ ";
4027
std::string sep = "";
41-
for (auto t : timestamps) {
42-
os << sep << std::fixed << std::setprecision(2) << t;
28+
for (const auto& item : vec) {
29+
oss << sep << item;
4330
sep = ", ";
4431
}
45-
os << "], ";
46-
47-
os << "\""
48-
<< "tokens"
49-
<< "\""
50-
<< ":";
51-
os << "[";
52-
53-
sep = "";
54-
auto oldFlags = os.flags();
55-
for (const auto &t : tokens) {
56-
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
57-
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
58-
os << sep << "\""
59-
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
60-
<< ">"
61-
<< "\"";
62-
os.flags(oldFlags);
63-
} else {
64-
os << sep << "\"" << t << "\"";
65-
}
32+
oss << " ]";
33+
return oss.str();
34+
}
35+
36+
/// Helper for `OnlineRecognizerResult::AsJsonString()`
37+
template<> // explicit specialization for T = std::string
38+
std::string VecToString<std::string>(const std::vector<std::string>& vec,
39+
int32_t) { // ignore 2nd arg
40+
std::ostringstream oss;
41+
oss << "[ ";
42+
std::string sep = "";
43+
for (const auto& item : vec) {
44+
oss << sep << "\"" << item << "\"";
6645
sep = ", ";
6746
}
68-
os << "]";
69-
os << "}";
47+
oss << " ]";
48+
return oss.str();
49+
}
7050

51+
std::string OnlineRecognizerResult::AsJsonString() const {
52+
std::ostringstream os;
53+
os << "{ ";
54+
os << "\"text\": " << "\"" << text << "\"" << ", ";
55+
os << "\"tokens\": " << VecToString(tokens) << ", ";
56+
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
57+
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
58+
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
59+
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
60+
os << "\"segment\": " << segment << ", ";
61+
os << "\"start_time\": " << std::fixed << std::setprecision(2)
62+
<< start_time << ", ";
63+
os << "\"is_final\": " << (is_final ? "true" : "false");
64+
os << "}";
7165
return os.str();
7266
}
7367

sherpa-onnx/csrc/online-recognizer.h

+9
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ struct OnlineRecognizerResult {
4040
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
4141
std::vector<float> timestamps;
4242

43+
std::vector<float> ys_probs; //< log-prob scores from ASR model
44+
std::vector<float> lm_probs; //< log-prob scores from language model
45+
//
46+
/// log-domain scores from "hot-phrase" contextual boosting
47+
std::vector<float> context_scores;
48+
4349
/// ID of this segment
4450
/// When an endpoint is detected, it is incremented
4551
int32_t segment = 0;
@@ -58,6 +64,9 @@ struct OnlineRecognizerResult {
5864
* "text": "The recognition result",
5965
* "tokens": [x, x, x],
6066
* "timestamps": [x, x, x],
67+
* "ys_probs": [x, x, x],
68+
* "lm_probs": [x, x, x],
69+
* "context_scores": [x, x, x],
6170
* "segment": x,
6271
* "start_time": x,
6372
* "is_final": true|false

sherpa-onnx/csrc/online-transducer-decoder.cc

+8
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
3737
frame_offset = other.frame_offset;
3838
timestamps = other.timestamps;
3939

40+
ys_probs = other.ys_probs;
41+
lm_probs = other.lm_probs;
42+
context_scores = other.context_scores;
43+
4044
return *this;
4145
}
4246

@@ -60,6 +64,10 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
6064
frame_offset = other.frame_offset;
6165
timestamps = std::move(other.timestamps);
6266

67+
ys_probs = std::move(other.ys_probs);
68+
lm_probs = std::move(other.lm_probs);
69+
context_scores = std::move(other.context_scores);
70+
6371
return *this;
6472
}
6573

sherpa-onnx/csrc/online-transducer-decoder.h

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ struct OnlineTransducerDecoderResult {
2626
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
2727
std::vector<int32_t> timestamps;
2828

29+
std::vector<float> ys_probs;
30+
std::vector<float> lm_probs;
31+
std::vector<float> context_scores;
32+
2933
// Cache decoder_out for endpointing
3034
Ort::Value decoder_out;
3135

sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc

+15
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
7171
r->tokens = std::vector<int64_t>(start, end);
7272
}
7373

74+
7475
void OnlineTransducerGreedySearchDecoder::Decode(
7576
Ort::Value encoder_out,
7677
std::vector<OnlineTransducerDecoderResult> *result) {
78+
7779
std::vector<int64_t> encoder_out_shape =
7880
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
7981

@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
9799
break;
98100
}
99101
}
102+
100103
if (is_batch_decoder_out_cached) {
101104
auto &r = result->front();
102105
std::vector<int64_t> decoder_out_shape =
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
124127
if (blank_penalty_ > 0.0) {
125128
p_logit[0] -= blank_penalty_; // assuming blank id is 0
126129
}
130+
127131
auto y = static_cast<int32_t>(std::distance(
128132
static_cast<const float *>(p_logit),
129133
std::max_element(static_cast<const float *>(p_logit),
@@ -138,6 +142,17 @@ void OnlineTransducerGreedySearchDecoder::Decode(
138142
} else {
139143
++r.num_trailing_blanks;
140144
}
145+
146+
// export the per-token log scores
147+
if (y != 0 && y != unk_id_) {
148+
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
149+
// save time by doing it only for
150+
// emitted symbols
151+
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
152+
// now it contains normalized
153+
// probability
154+
r.ys_probs.push_back(p_logprob[y]);
155+
}
141156
}
142157
if (emitted) {
143158
Ort::Value decoder_input = model_->BuildDecoderInput(*result);

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc

+28
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
5959
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
6060
r->tokens = std::move(tokens);
6161
r->timestamps = std::move(hyp.timestamps);
62+
63+
// export per-token scores
64+
r->ys_probs = std::move(hyp.ys_probs);
65+
r->lm_probs = std::move(hyp.lm_probs);
66+
r->context_scores = std::move(hyp.context_scores);
67+
6268
r->num_trailing_blanks = hyp.num_trailing_blanks;
6369
}
6470

@@ -180,6 +186,28 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
180186
new_hyp.log_prob = p_logprob[k] + context_score -
181187
prev_lm_log_prob; // log_prob only includes the
182188
// score of the transducer
189+
// export the per-token log scores
190+
if (new_token != 0 && new_token != unk_id_) {
191+
const Hypothesis& prev_i = prev[hyp_index];
192+
// subtract 'prev[i]' path scores, which were added before
193+
// getting topk tokens
194+
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
195+
new_hyp.ys_probs.push_back(y_prob);
196+
197+
if (lm_) { // export only when LM is used
198+
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
199+
if (lm_scale_ != 0.0) {
200+
lm_prob /= lm_scale_; // remove lm-scale
201+
}
202+
new_hyp.lm_probs.push_back(lm_prob);
203+
}
204+
205+
// export only when `ContextGraph` is used
206+
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
207+
new_hyp.context_scores.push_back(context_score);
208+
}
209+
}
210+
183211
hyps.Add(std::move(new_hyp));
184212
} // for (auto k : topk)
185213
cur.push_back(std::move(hyps));

sherpa-onnx/python/csrc/online-recognizer.cc

+20-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,26 @@ static void PybindOnlineRecognizerResult(py::module *m) {
2828
[](PyClass &self) -> float { return self.start_time; })
2929
.def_property_readonly(
3030
"timestamps",
31-
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
31+
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
32+
.def_property_readonly(
33+
"ys_probs",
34+
[](PyClass &self) -> std::vector<float> { return self.ys_probs; })
35+
.def_property_readonly(
36+
"lm_probs",
37+
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
38+
.def_property_readonly(
39+
"context_scores",
40+
[](PyClass &self) -> std::vector<float> {
41+
return self.context_scores;
42+
})
43+
.def_property_readonly(
44+
"segment",
45+
[](PyClass &self) -> int32_t { return self.segment; })
46+
.def_property_readonly(
47+
"is_final",
48+
[](PyClass &self) -> bool { return self.is_final; })
49+
.def("as_json_string", &PyClass::AsJsonString,
50+
py::call_guard<py::gil_scoped_release>());
3251
}
3352

3453
static void PybindOnlineRecognizerConfig(py::module *m) {

sherpa-onnx/python/sherpa_onnx/online_recognizer.py

+12
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@ def is_ready(self, s: OnlineStream) -> bool:
503503
def get_result(self, s: OnlineStream) -> str:
504504
return self.recognizer.get_result(s).text.strip()
505505

506+
def get_result_as_json_string(self, s: OnlineStream) -> str:
507+
return self.recognizer.get_result(s).as_json_string()
508+
506509
def tokens(self, s: OnlineStream) -> List[str]:
507510
return self.recognizer.get_result(s).tokens
508511

@@ -512,6 +515,15 @@ def timestamps(self, s: OnlineStream) -> List[float]:
512515
def start_time(self, s: OnlineStream) -> float:
513516
return self.recognizer.get_result(s).start_time
514517

518+
def ys_probs(self, s: OnlineStream) -> List[float]:
519+
return self.recognizer.get_result(s).ys_probs
520+
521+
def lm_probs(self, s: OnlineStream) -> List[float]:
522+
return self.recognizer.get_result(s).lm_probs
523+
524+
def context_scores(self, s: OnlineStream) -> List[float]:
525+
return self.recognizer.get_result(s).context_scores
526+
515527
def is_endpoint(self, s: OnlineStream) -> bool:
516528
return self.recognizer.is_endpoint(s)
517529

0 commit comments

Comments
 (0)