
Commit 6c5d358

export per-token scores also for greedy-search (online-transducer)
- export un-scaled lm_probs (modified-beam search, online-transducer)
- polishing
1 parent: 3d4f212

5 files changed: +32 -10 lines changed

sherpa-onnx/csrc/hypothesis.h (+2 -1)

@@ -30,7 +30,8 @@ struct Hypothesis {
 
   // The acoustic probability for each token in ys.
   // Used for keyword spotting task.
-  // For transducer mofified beam-search, this is filled with log_posterior scores.
+  // For transducer mofified beam-search and greedy-search,
+  // this is filled with log_posterior scores.
   std::vector<float> ys_probs;
 
   // lm_probs[i] contains the lm score for each token in ys.
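Conceptually, ys_probs (and lm_probs) run parallel to the decoded token sequence ys, one score per emitted token. A tiny Python sketch with invented values, purely for illustration (not code from this commit):

# each position i pairs a token id with its per-token log-posterior
ys = [25, 102, 7]                  # decoded non-blank token ids
ys_probs = [-0.12, -0.80, -0.33]   # filled by greedy or modified beam search
for tok, score in zip(ys, ys_probs):
    print(tok, score)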

sherpa-onnx/csrc/online-recognizer.cc (+7 -5)

@@ -21,9 +21,9 @@ namespace sherpa_onnx {
 /// Helper for `OnlineRecognizerResult::AsJsonString()`
 template<typename T>
 const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6) {
-  std::ostringstrean oss;
+  std::ostringstream oss;
   oss << std::fixed << std::setprecision(precision);
-  oss << "[ " <<
+  oss << "[ ";
   std::string sep = "";
   for (auto item : vec) {
     oss << sep << item;
@@ -35,9 +35,11 @@ const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6)
 
 /// Helper for `OnlineRecognizerResult::AsJsonString()`
 template<> // explicit specialization for T = std::string
-const std::string& VecToString<std::string>(const std::vector<T>& vec, int32_t) { // ignore 2nd arg
-  std::ostringstrean oss;
-  oss << "[ " <<
+const std::string& VecToString<std::string>(const std::vector<std::string>& vec,
+                                            int32_t) // ignore 2nd arg
+{
+  std::ostringstream oss;
+  oss << "[ ";
   std::string sep = "";
   for (auto item : vec) {
     oss << sep << "\"" << item << "\"";

sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc (+16 -0)

@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
   r->tokens = std::vector<int64_t>(start, end);
 }
 
+
 void OnlineTransducerGreedySearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
+
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
 
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
       break;
     }
   }
+
   if (is_batch_decoder_out_cached) {
     auto &r = result->front();
     std::vector<int64_t> decoder_out_shape =
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
     if (blank_penalty_ > 0.0) {
       p_logit[0] -= blank_penalty_;  // assuming blank id is 0
     }
+
     auto y = static_cast<int32_t>(std::distance(
         static_cast<const float *>(p_logit),
         std::max_element(static_cast<const float *>(p_logit),
@@ -138,6 +142,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
     } else {
       ++r.num_trailing_blanks;
     }
+
+    // export the per-token log scores
+    if (y != 0 && y != unk_id_) {
+      LogSoftmax(p_logit, vocab_size);  // renormalize probabilities,
+                                        // save time by doing it only for
+                                        // emitted symbols
+      float *p_logprob = p_logit;  // rename p_logit as p_logprob,
+                                   // now it contains normalized
+                                   // probability
+      r.ys_probs.push_back(p_logprob[y]);
+    }
+
   }
   if (emitted) {
     Ort::Value decoder_input = model_->BuildDecoderInput(*result);
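Put differently, each time greedy search emits a non-blank, non-unk token, the raw logits for that frame are renormalized with log-softmax and the log-probability of the emitted token is appended to ys_probs. A minimal Python sketch of that computation, using toy numbers of my own (illustration only, not code from this commit):

import math

def log_softmax(logits):
    # numerically stable log-softmax over the vocabulary dimension
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

logits = [2.0, -1.0, 0.5, 3.0]  # toy joiner output for one frame, index 0 = blank
y = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax, here y = 3
if y != 0:  # blank (and unk) tokens are not exported
    ys_prob = log_softmax(logits)[y]
    print(y, ys_prob)  # 3, roughly -0.384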

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc (+4 -4)

@@ -60,7 +60,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
   r->tokens = std::move(tokens);
   r->timestamps = std::move(hyp.timestamps);
 
-
   // export per-token scores
   r->ys_probs = std::move(hyp.ys_probs);
   r->lm_probs = std::move(hyp.lm_probs);
@@ -149,8 +148,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     }
     p_logprob = p_logit;  // we changed p_logprob in the above for loop
 
-    // KarelVesely: Sholud the context score be added already before taking topk tokens ?
-
     for (int32_t b = 0; b != batch_size; ++b) {
       int32_t frame_offset = (*result)[b].frame_offset;
       int32_t start = hyps_row_splits[b];
@@ -190,14 +187,17 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
                           prev_lm_log_prob;  // log_prob only includes the
                                              // score of the transducer
         // export the per-token log scores
-        {
+        if (new_token != 0 && new_token != unk_id_) {
           const Hypothesis& prev_i = prev[hyp_index];
           // subtract 'prev[i]' path scores, which were added before
           // for getting topk tokens
           float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
           new_hyp.ys_probs.push_back(y_prob);
 
           float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
+          if (lm_scale_ != 0.0) {
+            lm_prob /= lm_scale_;  // remove lm-scale
+          }
           new_hyp.lm_probs.push_back(lm_prob);
 
           new_hyp.context_scores.push_back(context_score);
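The arithmetic behind the two exported values, sketched in Python with made-up numbers (illustration only; lm_scale stands in for the decoder's lm_scale_ member):

lm_scale = 0.5

# path scores of the parent hypothesis prev[i]
prev_log_prob = -3.2      # transducer (acoustic) path score
prev_lm_log_prob = -1.0   # LM path score, already multiplied by lm_scale

# topk entry for the new token: parent path scores + this token's log-prob
p_logprob_k = -4.9

# per-token acoustic score: strip the parent path scores that were added
# before taking topk
y_prob = p_logprob_k - prev_log_prob - prev_lm_log_prob   # about -0.7

# per-token LM score: difference of the LM path scores, with lm_scale removed
new_lm_log_prob = prev_lm_log_prob + lm_scale * (-0.8)    # LM scored the token -0.8
lm_prob = new_lm_log_prob - prev_lm_log_prob
if lm_scale != 0.0:
    lm_prob /= lm_scale                                   # recovers roughly -0.8
print(y_prob, lm_prob)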

sherpa-onnx/python/sherpa_onnx/online_recognizer.py (+3 -0)

@@ -503,6 +503,9 @@ def is_ready(self, s: OnlineStream) -> bool:
     def get_result(self, s: OnlineStream) -> str:
         return self.recognizer.get_result(s).text.strip()
 
+    def get_result_as_json_string(self, s: OnlineStream) -> str:
+        return self.recognizer.get_result(s).as_json_string()
+
     def tokens(self, s: OnlineStream) -> List[str]:
         return self.recognizer.get_result(s).tokens
