Skip to content

Commit 209eaaa

Browse files
authored
Limit number of tokens per second for whisper. (#1958)
Otherwise, the decoder spends a lot of time in the decoding loop when the EOT token is never predicted.
1 parent 4917753 commit 209eaaa

4 files changed

+14
-6
lines changed

sherpa-onnx/csrc/offline-recognizer-whisper-impl.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
131131
auto cross_kv = model_->ForwardEncoder(std::move(mel));
132132

133133
auto results = decoder_->Decode(std::move(cross_kv.first),
134-
std::move(cross_kv.second));
134+
std::move(cross_kv.second), num_frames);
135135

136136
auto r = Convert(results[0], symbol_table_);
137137
s->SetResult(r);

sherpa-onnx/csrc/offline-whisper-decoder.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ class OfflineWhisperDecoder {
3333
* @return Return a vector of size `N` containing the decoded results.
3434
*/
3535
virtual std::vector<OfflineWhisperDecoderResult> Decode(
36-
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
36+
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
37+
int32_t num_feature_frames) = 0;
3738

3839
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
3940
};

sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc

+8-2
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig(
1919

2020
std::vector<OfflineWhisperDecoderResult>
2121
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
22-
Ort::Value cross_v) {
22+
Ort::Value cross_v,
23+
int32_t num_feature_frames) {
2324
auto memory_info =
2425
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
2526

@@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
99100
int32_t n_text_ctx = model_->TextCtx();
100101

101102
std::vector<int32_t> predicted_tokens;
102-
for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
103+
104+
// assume at most 6 tokens per second
105+
// Multiply before dividing: `/ 100 * 6` truncates first, giving a budget of 0 for fewer than 100 frames.
int32_t num_possible_tokens = num_feature_frames * 6 / 100;
106+
num_possible_tokens = std::min<int32_t>(num_possible_tokens, n_text_ctx / 2);
107+
108+
for (int32_t i = 0; i < num_possible_tokens; ++i) {
103109
if (max_token_id == model_->EOT()) {
104110
break;
105111
}

sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h

+3-2
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
1818
OfflineWhisperModel *model)
1919
: config_(config), model_(model) {}
2020

21-
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
22-
Ort::Value cross_v) override;
21+
std::vector<OfflineWhisperDecoderResult> Decode(
22+
Ort::Value cross_k, Ort::Value cross_v,
23+
int32_t num_feature_frames) override;
2324

2425
void SetConfig(const OfflineWhisperModelConfig &config) override;
2526

0 commit comments

Comments
 (0)