// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <algorithm>
#include <array>
#include <fstream>
#include <ios>
#include <memory>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
// TODO: decide whether this forward declaration should be marked static
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                                      const SymbolTable &sym_table,
                                      float frame_shift_ms,
                                      int32_t subsampling_factor,
                                      int32_t segment,
                                      int32_t frames_since_start);

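// Streaming recognizer implementation for NeMo transducer models.
//
// A minimal usage sketch (normally this class is driven through the public
// OnlineRecognizer API; the paths and the sample buffer below are
// placeholders):
//
//   OnlineRecognizerConfig config;
//   config.model_config.tokens = "tokens.txt";
//   // ... set encoder/decoder/joiner model paths ...
//   OnlineRecognizerTransducerNeMoImpl recognizer(config);
//   auto s = recognizer.CreateStream();
//   std::vector<float> samples = /* 16 kHz mono audio */;
//   s->AcceptWaveform(16000, samples.data(), samples.size());
//   OnlineStream *ss[] = {s.get()};
//   while (recognizer.IsReady(s.get())) {
//     recognizer.DecodeStreams(ss, 1);
//   }
//   auto result = recognizer.GetResult(s.get());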
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
 public:
  explicit OnlineRecognizerTransducerNeMoImpl(
      const OnlineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config.model_config.tokens),
        endpoint_(config_.endpoint_config),
        model_(std::make_unique<OnlineTransducerNeMoModel>(
            config.model_config)) {
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
    PostInit();
  }

#if __ANDROID_API__ >= 9
  explicit OnlineRecognizerTransducerNeMoImpl(
      AAssetManager *mgr, const OnlineRecognizerConfig &config)
      : config_(config),
        symbol_table_(mgr, config.model_config.tokens),
        endpoint_(config_.endpoint_config),
        model_(std::make_unique<OnlineTransducerNeMoModel>(
            mgr, config.model_config)) {
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
          model_.get(), config_.blank_penalty);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }

    PostInit();
  }
#endif

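  // Creates a new stream, seeds it with the model's initial encoder states,
  // and attaches an empty decoding result plus fresh NeMo decoder states.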
  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    stream->SetStates(model_->GetInitStates());
    InitOnlineStream(stream.get());
    return stream;
  }

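  // A stream is ready for decoding only when at least one full encoder chunk
  // of unprocessed frames is available. For example, with an illustrative
  // ChunkSize() of 16 and 10 frames already processed, NumFramesReady() must
  // exceed 10 + 16 = 26 before DecodeStreams() may be called. (The actual
  // chunk size comes from the model.)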
  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
    OnlineTransducerDecoderResult decoder_result = s->GetResult();
    decoder_->StripLeadingBlanks(&decoder_result);

    // TODO(fangjun): Remember to change these constants if needed
    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = 8;
    return Convert(decoder_result, symbol_table_, frame_shift_ms,
                   subsampling_factor, s->GetCurrentSegment(),
                   s->GetNumFramesSinceStart());
  }

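  // Endpointing is based on the amount of trailing silence. Each trailing
  // blank token emitted by the decoder covers subsampling_factor (8) feature
  // frames of 10 ms each, so e.g. 5 trailing blanks correspond to
  // 5 * 8 * 0.01 s = 0.4 s of silence.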
  bool IsEndpoint(OnlineStream *s) const override {
    if (!config_.enable_endpoint) {
      return false;
    }

    int32_t num_processed_frames = s->GetNumProcessedFrames();

    // frame shift is 10 milliseconds
    float frame_shift_in_seconds = 0.01;

    // subsampling factor is 8
    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;

    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                                frame_shift_in_seconds);
  }

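  // Finalizes the current segment and prepares the stream for the next one.
  // The decoder output is carried over so that decoding can continue
  // seamlessly after an endpoint.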
  void Reset(OnlineStream *s) const override {
    {
      // segment is incremented only when the last
      // result is not empty
      const auto &r = s->GetResult();
      if (!r.tokens.empty() && r.tokens.back() != 0) {
        s->GetCurrentSegment() += 1;
      }
    }

    // we keep the decoder_out
    decoder_->UpdateDecoderOut(&s->GetResult());
    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);

    auto r = decoder_->GetEmptyResult();

    s->SetResult(r);
    s->GetResult().decoder_out = std::move(decoder_out);

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
    s->Reset();
  }

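  // Decodes one encoder chunk for each of the n streams:
  //  1. collect chunk_size feature frames per stream into a batched buffer,
  //  2. run the streaming encoder with the cached encoder states,
  //  3. run greedy search on the encoder output,
  //  4. write the updated results and states back to the streams.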
  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
    int32_t chunk_size = model_->ChunkSize();
    int32_t chunk_shift = model_->ChunkShift();

    int32_t feature_dim = ss[0]->FeatureDim();

    std::vector<OnlineTransducerDecoderResult> result(n);
    std::vector<float> features_vec(n * chunk_size * feature_dim);
    std::vector<std::vector<Ort::Value>> encoder_states(n);

    for (int32_t i = 0; i != n; ++i) {
      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
      std::vector<float> features =
          ss[i]->GetFrames(num_processed_frames, chunk_size);

      // Question: should num_processed_frames include chunk_shift?
      ss[i]->GetNumProcessedFrames() += chunk_shift;

      std::copy(features.begin(), features.end(),
                features_vec.data() + i * chunk_size * feature_dim);

      result[i] = std::move(ss[i]->GetResult());
      encoder_states[i] = std::move(ss[i]->GetStates());
    }
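
    // Pack the collected features into a single batched tensor of shape
    // (n, chunk_size, feature_dim).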
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
                                            features_vec.size(),
                                            x_shape.data(), x_shape.size());

    // Note: batch size is effectively 1 here; only the encoder states of the
    // first stream are used.
    auto states = std::move(encoder_states[0]);
    int32_t num_states = states.size();  // num_states = 3
    auto t = model_->RunEncoder(std::move(x), std::move(states));
    // t[0]: encoder_out, float tensor, (batch_size, dim, T)
    // t[1:]: next states

    std::vector<Ort::Value> out_states;
    out_states.reserve(num_states);

    for (int32_t k = 1; k != num_states + 1; ++k) {
      out_states.push_back(std::move(t[k]));
    }

    Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

    // defined in online-transducer-greedy-search-nemo-decoder.h
    // Get the initial decoder states.
    std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();

    // The decoder states for each subsequent chunk are updated inside the
    // Decode() method. It returns the decoder states from the LAST chunk;
    // we probably don't need them, so they could be discarded.
    decoder_states = decoder_->Decode(std::move(encoder_out),
                                      std::move(decoder_states),
                                      &result, ss, n);

    ss[0]->SetResult(result[0]);

    ss[0]->SetStates(std::move(out_states));
  }

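  // Attaches an empty decoding result and fresh NeMo decoder states to a
  // newly created stream.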
  void InitOnlineStream(OnlineStream *stream) const {
    auto r = decoder_->GetEmptyResult();

    stream->SetResult(r);
    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
  }

 private:
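  // Aligns the feature extraction configuration with NeMo's preprocessor
  // settings and validates the token table against the model's vocabulary.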
  void PostInit() {
    config_.feat_config.nemo_normalize_type =
        model_->FeatureNormalizationMethod();

    config_.feat_config.low_freq = 0;
    // config_.feat_config.high_freq = 8000;
    config_.feat_config.is_librosa = true;
    config_.feat_config.remove_dc_offset = false;
    // config_.feat_config.window_type = "hann";
    config_.feat_config.dither = 0;

    int32_t vocab_size = model_->VocabSize();

    // check the blank ID
    if (!symbol_table_.Contains("<blk>")) {
      SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
      exit(-1);
    }

    if (symbol_table_["<blk>"] != vocab_size - 1) {
      SHERPA_ONNX_LOGE("<blk> is not the last token!");
      exit(-1);
    }

    if (symbol_table_.NumSymbols() != vocab_size) {
      SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                       symbol_table_.NumSymbols(), vocab_size);
      exit(-1);
    }
  }

 private:
  OnlineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OnlineTransducerNeMoModel> model_;
  std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
  Endpoint endpoint_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_