Commit 3f472a9

sangeet2020 authored
Add C++ runtime for *streaming* faster conformer transducer from NeMo. (#889)
Co-authored-by: sangeet2020 <15uec053@gmail.com>
1 parent 49d66ec commit 3f472a9

10 files changed: +1119 −2 lines changed

sherpa-onnx/csrc/CMakeLists.txt (+2)

@@ -74,6 +74,8 @@ set(sources
   online-transducer-model-config.cc
   online-transducer-model.cc
   online-transducer-modified-beam-search-decoder.cc
+  online-transducer-nemo-model.cc
+  online-transducer-greedy-search-nemo-decoder.cc
   online-wenet-ctc-model-config.cc
   online-wenet-ctc-model.cc
   online-zipformer-transducer-model.cc

sherpa-onnx/csrc/online-recognizer-impl.cc (+28 −2)

@@ -7,13 +7,28 @@
 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"

 namespace sherpa_onnx {

 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
+
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(config);
+    } else {
+      SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
+    }
   }

   if (!config.model_config.paraformer.encoder.empty()) {

@@ -34,7 +49,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     AAssetManager *mgr, const OnlineRecognizerConfig &config) {
   if (!config.model_config.transducer.encoder.empty()) {
-    return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+
+    auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
+    auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
+
+    size_t node_count = sess->GetOutputCount();
+
+    if (node_count == 1) {
+      return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
+    } else {
+      return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
+    }
   }

   if (!config.model_config.paraformer.encoder.empty()) {
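
The factory now decides between the two transducer implementations by loading the decoder model and counting its outputs: an icefall-style decoder exposes a single output (decoder_out), while a streaming NeMo decoder also returns its recurrent states. Below is a minimal standalone sketch of the same probe using only the onnxruntime C++ API; it is not part of this commit, and "decoder.onnx" is a placeholder path (on Windows the session constructor takes a wide-character path instead).

// probe_decoder.cc -- sketch only, not part of this commit.
#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

  // "decoder.onnx" is a placeholder for a transducer decoder model.
  Ort::Session sess(env, "decoder.onnx", Ort::SessionOptions{});

  // One output (decoder_out) -> standard transducer decoder.
  // More than one output -> NeMo decoder that also returns its states.
  size_t node_count = sess.GetOutputCount();
  std::cout << "decoder outputs: " << node_count << " -> "
            << (node_count == 1 ? "OnlineRecognizerTransducerImpl"
                                : "OnlineRecognizerTransducerNeMoImpl")
            << "\n";
  return 0;
}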

sherpa-onnx/csrc/online-recognizer-transducer-impl.h (+1)

@@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
   r.timestamps.reserve(src.tokens.size());

   for (auto i : src.tokens) {
+    if (i == -1) continue;
     auto sym = sym_table[i];

     r.text.append(sym);
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h (new file, +267)

@@ -0,0 +1,267 @@
+// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
+//
+// Copyright (c) 2022-2024 Xiaomi Corporation
+// Copyright (c) 2024 Sangeet Sagar
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
+
+#include <fstream>
+#include <ios>
+#include <memory>
+#include <regex>  // NOLINT
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-recognizer-impl.h"
+#include "sherpa-onnx/csrc/online-recognizer.h"
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+#include "sherpa-onnx/csrc/transpose.h"
+#include "sherpa-onnx/csrc/utils.h"
+
+namespace sherpa_onnx {
+
+// Defined in ./online-recognizer-transducer-impl.h
+// TODO: decide whether this forward declaration needs to be static.
+static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
+                                      const SymbolTable &sym_table,
+                                      float frame_shift_ms,
+                                      int32_t subsampling_factor,
+                                      int32_t segment,
+                                      int32_t frames_since_start);
+
+class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
+ public:
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+    PostInit();
+  }
+
+#if __ANDROID_API__ >= 9
+  explicit OnlineRecognizerTransducerNeMoImpl(
+      AAssetManager *mgr, const OnlineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(mgr, config.model_config.tokens),
+        endpoint_(config_.endpoint_config),
+        model_(std::make_unique<OnlineTransducerNeMoModel>(
+            mgr, config.model_config)) {
+    if (config.decoding_method == "greedy_search") {
+      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
+          model_.get(), config_.blank_penalty);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config.decoding_method.c_str());
+      exit(-1);
+    }
+
+    PostInit();
+  }
+#endif
+
+  std::unique_ptr<OnlineStream> CreateStream() const override {
+    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
+    stream->SetStates(model_->GetInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  bool IsReady(OnlineStream *s) const override {
+    return s->GetNumProcessedFrames() + model_->ChunkSize() <
+           s->NumFramesReady();
+  }
+
+  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
+    OnlineTransducerDecoderResult decoder_result = s->GetResult();
+    decoder_->StripLeadingBlanks(&decoder_result);
+
+    // TODO(fangjun): Remember to change these constants if needed
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = 8;
+    return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
+                   s->GetCurrentSegment(), s->GetNumFramesSinceStart());
+  }
+
+  bool IsEndpoint(OnlineStream *s) const override {
+    if (!config_.enable_endpoint) {
+      return false;
+    }
+
+    int32_t num_processed_frames = s->GetNumProcessedFrames();
+
+    // frame shift is 10 milliseconds
+    float frame_shift_in_seconds = 0.01;
+
+    // subsampling factor is 8
+    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
+
+    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
+                                frame_shift_in_seconds);
+  }
+
+  void Reset(OnlineStream *s) const override {
+    {
+      // segment is incremented only when the last
+      // result is not empty
+      const auto &r = s->GetResult();
+      if (!r.tokens.empty() && r.tokens.back() != 0) {
+        s->GetCurrentSegment() += 1;
+      }
+    }
+
+    // we keep the decoder_out
+    decoder_->UpdateDecoderOut(&s->GetResult());
+    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
+
+    auto r = decoder_->GetEmptyResult();
+
+    s->SetResult(r);
+    s->GetResult().decoder_out = std::move(decoder_out);
+
+    // Note: We only update counters. The underlying audio samples
+    // are not discarded.
+    s->Reset();
+  }
+
+  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+
+    int32_t feature_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineTransducerDecoderResult> result(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> encoder_states(n);
+
+    for (int32_t i = 0; i != n; ++i) {
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
+      std::vector<float> features =
+          ss[i]->GetFrames(num_processed_frames, chunk_size);
+
+      // Question: should num_processed_frames include chunk_shift?
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      result[i] = std::move(ss[i]->GetResult());
+      encoder_states[i] = std::move(ss[i]->GetStates());
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(), x_shape.data(),
+                                            x_shape.size());
+
+    // Batch size is 1
+    auto states = std::move(encoder_states[0]);
+    int32_t num_states = states.size();  // num_states = 3
+    auto t = model_->RunEncoder(std::move(x), std::move(states));
+    // t[0]: encoder_out, float tensor, (batch_size, dim, T)
+    // t[1:]: next encoder states
+
+    std::vector<Ort::Value> out_states;
+    out_states.reserve(num_states);
+
+    for (int32_t k = 1; k != num_states + 1; ++k) {
+      out_states.push_back(std::move(t[k]));
+    }
+
+    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
+
+    // Defined in online-transducer-greedy-search-nemo-decoder.h.
+    // Get the initial decoder states.
+    std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
+
+    // Decoder states for each chunk are updated inside the Decode method.
+    // This returns the decoder states from the LAST chunk; we probably
+    // don't need them, so they could be discarded.
+    decoder_states = decoder_->Decode(std::move(encoder_out),
+                                      std::move(decoder_states),
+                                      &result, ss, n);
+
+    ss[0]->SetResult(result[0]);
+
+    ss[0]->SetStates(std::move(out_states));
+  }
+
+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    stream->SetResult(r);
+    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
+  }
+
+ private:
+  void PostInit() {
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    config_.feat_config.low_freq = 0;
+    // config_.feat_config.high_freq = 8000;
+    config_.feat_config.is_librosa = true;
+    config_.feat_config.remove_dc_offset = false;
+    // config_.feat_config.window_type = "hann";
+    config_.feat_config.dither = 0;
+    config_.feat_config.nemo_normalize_type =
+        model_->FeatureNormalizationMethod();
+
+    int32_t vocab_size = model_->VocabSize();
+
+    // check the blank ID
+    if (!symbol_table_.Contains("<blk>")) {
+      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
+      exit(-1);
+    }
+
+    if (symbol_table_["<blk>"] != vocab_size - 1) {
+      SHERPA_ONNX_LOGE("<blk> is not the last token!");
+      exit(-1);
+    }
+
+    if (symbol_table_.NumSymbols() != vocab_size) {
+      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
+                       symbol_table_.NumSymbols(), vocab_size);
+      exit(-1);
+    }
+  }
+
+ private:
+  OnlineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OnlineTransducerNeMoModel> model_;
+  std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
+  Endpoint endpoint_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
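
For context, here is a minimal usage sketch (not part of this commit) showing how the new implementation is reached through the public C++ API. It assumes a NeMo streaming fast conformer transducer exported to ONNX; encoder.onnx, decoder.onnx, joiner.onnx, and tokens.txt are placeholder file names.

// run_nemo_streaming.cc -- sketch only; file names are placeholders.
#include <iostream>
#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";
  config.model_config.transducer.decoder = "decoder.onnx";
  config.model_config.transducer.joiner = "joiner.onnx";
  config.model_config.tokens = "tokens.txt";
  config.decoding_method = "greedy_search";

  // The factory in online-recognizer-impl.cc selects
  // OnlineRecognizerTransducerNeMoImpl because the NeMo decoder
  // exposes more than one output.
  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto stream = recognizer.CreateStream();

  std::vector<float> samples(16000, 0.0f);  // 1 second of silence at 16 kHz
  stream->AcceptWaveform(16000, samples.data(),
                         static_cast<int32_t>(samples.size()));

  while (recognizer.IsReady(stream.get())) {
    recognizer.DecodeStream(stream.get());
  }

  std::cout << recognizer.GetResult(stream.get()).text << "\n";
  return 0;
}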

sherpa-onnx/csrc/online-stream.cc (+15)

@@ -90,6 +90,12 @@ class OnlineStream::Impl {

   std::vector<Ort::Value> &GetStates() { return states_; }

+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
+    decoder_states_ = std::move(decoder_states);
+  }
+
+  std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }
+
   const ContextGraphPtr &GetContextGraph() const { return context_graph_; }

   std::vector<float> &GetParaformerFeatCache() {

@@ -129,6 +135,7 @@ class OnlineStream::Impl {
   TransducerKeywordResult empty_keyword_result_;
   OnlineCtcDecoderResult ctc_result_;
   std::vector<Ort::Value> states_;  // states for transducer or ctc models
+  std::vector<Ort::Value> decoder_states_;  // states for nemo transducer models
   std::vector<float> paraformer_feat_cache_;
   std::vector<float> paraformer_encoder_out_cache_;
   std::vector<float> paraformer_alpha_cache_;

@@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }

+void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
+  return impl_->SetNeMoDecoderStates(std::move(decoder_states));
+}
+
+std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
+  return impl_->GetNeMoDecoderStates();
+}
+
 const ContextGraphPtr &OnlineStream::GetContextGraph() const {
   return impl_->GetContextGraph();
 }

sherpa-onnx/csrc/online-stream.h (+3)

@@ -91,6 +91,9 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();

+  void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
+  std::vector<Ort::Value> &GetNeMoDecoderStates();
+
   /**
    * Get the context graph corresponding to this stream.
    *
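
These declarations give each stream its own slot for the NeMo transducer decoder's states, next to the encoder/CTC states_ it already holds. A rough sketch of the intended round trip, reusing the classes introduced in this commit (the Decode call is abbreviated in a comment):

// Sketch only: how the NeMo recognizer threads decoder states through a stream.
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"

namespace sherpa_onnx {

void SeedDecoderStates(OnlineStream *s, OnlineTransducerNeMoModel *model) {
  // Done once per stream, right after creation (batch size 1).
  s->SetNeMoDecoderStates(model->GetDecoderInitStates(1));
}

void StepDecoderStates(OnlineStream *s) {
  // Per chunk: hand the stored states to the greedy-search decoder and
  // keep whatever it returns for the next chunk, e.g.
  //   states = decoder->Decode(std::move(encoder_out), std::move(states),
  //                            &result, &s, 1);
  std::vector<Ort::Value> &states = s->GetNeMoDecoderStates();
  (void)states;
}

}  // namespace sherpa_onnx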
