@@ -71,9 +71,11 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
71
71
r->tokens = std::vector<int64_t >(start, end);
72
72
}
73
73
74
+
74
75
void OnlineTransducerGreedySearchDecoder::Decode (
75
76
Ort::Value encoder_out,
76
77
std::vector<OnlineTransducerDecoderResult> *result) {
78
+
77
79
std::vector<int64_t > encoder_out_shape =
78
80
encoder_out.GetTensorTypeAndShapeInfo ().GetShape ();
79
81
@@ -97,6 +99,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
97
99
break ;
98
100
}
99
101
}
102
+
100
103
if (is_batch_decoder_out_cached) {
101
104
auto &r = result->front ();
102
105
std::vector<int64_t > decoder_out_shape =
@@ -124,6 +127,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
124
127
if (blank_penalty_ > 0.0 ) {
125
128
p_logit[0 ] -= blank_penalty_; // assuming blank id is 0
126
129
}
130
+
127
131
auto y = static_cast <int32_t >(std::distance (
128
132
static_cast <const float *>(p_logit),
129
133
std::max_element (static_cast <const float *>(p_logit),
@@ -138,6 +142,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
138
142
} else {
139
143
++r.num_trailing_blanks ;
140
144
}
145
+
146
+ // export the per-token log scores
147
+ if (y != 0 && y != unk_id_) {
148
+ LogSoftmax (p_logit, vocab_size); // renormalize probabilities,
149
+ // save time by doing it only for
150
+ // emitted symbols
151
+ float *p_logprob = p_logit; // rename p_logit as p_logprob,
152
+ // now it contains normalized
153
+ // probability
154
+ r.ys_probs .push_back (p_logprob[y]);
155
+ }
156
+
141
157
}
142
158
if (emitted) {
143
159
Ort::Value decoder_input = model_->BuildDecoderInput (*result);
0 commit comments