6
6
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
7
7
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
8
8
9
+ #include < algorithm>
9
10
#include < fstream>
10
11
#include < ios>
11
12
#include < memory>
32
33
namespace sherpa_onnx {
33
34
34
35
// defined in ./online-recognizer-transducer-impl.h
35
- // static may or may not be here? TODDOs
36
- static OnlineRecognizerResult Convert (const OnlineTransducerDecoderResult &src,
37
- const SymbolTable &sym_table,
38
- float frame_shift_ms,
39
- int32_t subsampling_factor,
40
- int32_t segment,
41
- int32_t frames_since_start);
36
+ OnlineRecognizerResult Convert (const OnlineTransducerDecoderResult &src,
37
+ const SymbolTable &sym_table,
38
+ float frame_shift_ms, int32_t subsampling_factor,
39
+ int32_t segment, int32_t frames_since_start);
42
40
43
41
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
44
- public:
42
+ public:
45
43
explicit OnlineRecognizerTransducerNeMoImpl (
46
44
const OnlineRecognizerConfig &config)
47
45
: config_(config),
48
46
symbol_table_(config.model_config.tokens),
49
47
endpoint_(config_.endpoint_config),
50
- model_(std::make_unique<OnlineTransducerNeMoModel>(
51
- config.model_config)) {
48
+ model_(
49
+ std::make_unique<OnlineTransducerNeMoModel>( config.model_config)) {
52
50
if (config.decoding_method == " greedy_search" ) {
53
51
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
54
52
model_.get (), config_.blank_penalty );
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
73
71
model_.get (), config_.blank_penalty );
74
72
} else {
75
73
SHERPA_ONNX_LOGE (" Unsupported decoding method: %s" ,
76
- config.decoding_method .c_str ());
74
+ config.decoding_method .c_str ());
77
75
exit (-1 );
78
76
}
79
77
@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
83
81
84
82
std::unique_ptr<OnlineStream> CreateStream () const override {
85
83
auto stream = std::make_unique<OnlineStream>(config_.feat_config );
86
- stream->SetStates (model_->GetInitStates ());
87
84
InitOnlineStream (stream.get ());
88
85
return stream;
89
86
}
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
94
91
}
95
92
96
93
OnlineRecognizerResult GetResult (OnlineStream *s) const override {
97
- OnlineTransducerDecoderResult decoder_result = s->GetResult ();
98
- decoder_->StripLeadingBlanks (&decoder_result);
99
-
100
94
// TODO(fangjun): Remember to change these constants if needed
101
95
int32_t frame_shift_ms = 10 ;
102
- int32_t subsampling_factor = 8 ;
103
- return Convert (decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
104
- s->GetCurrentSegment (), s->GetNumFramesSinceStart ());
96
+ int32_t subsampling_factor = model_->SubsamplingFactor ();
97
+ return Convert (s->GetResult (), symbol_table_, frame_shift_ms,
98
+ subsampling_factor, s->GetCurrentSegment (),
99
+ s->GetNumFramesSinceStart ());
105
100
}
106
101
107
102
bool IsEndpoint (OnlineStream *s) const override {
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
114
109
// frame shift is 10 milliseconds
115
110
float frame_shift_in_seconds = 0.01 ;
116
111
117
- // subsampling factor is 8
118
- int32_t trailing_silence_frames = s->GetResult ().num_trailing_blanks * 8 ;
112
+ int32_t trailing_silence_frames =
113
+ s->GetResult ().num_trailing_blanks * model_-> SubsamplingFactor () ;
119
114
120
115
return endpoint_.IsEndpoint (num_processed_frames, trailing_silence_frames,
121
116
frame_shift_in_seconds);
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
126
121
// segment is incremented only when the last
127
122
// result is not empty
128
123
const auto &r = s->GetResult ();
129
- if (!r.tokens .empty () && r. tokens . back () != 0 ) {
124
+ if (!r.tokens .empty ()) {
130
125
s->GetCurrentSegment () += 1 ;
131
126
}
132
127
}
133
128
134
- // we keep the decoder_out
135
- decoder_-> UpdateDecoderOut (&s-> GetResult ());
136
- Ort::Value decoder_out = std::move ( s->GetResult (). decoder_out );
129
+ s-> SetResult ({});
130
+
131
+ s->SetStates (model_-> GetEncoderInitStates () );
137
132
138
- auto r = decoder_->GetEmptyResult ();
139
-
140
- s->SetResult (r);
141
- s->GetResult ().decoder_out = std::move (decoder_out);
133
+ s->SetNeMoDecoderStates (model_->GetDecoderInitStates ());
142
134
143
135
// Note: We only update counters. The underlying audio samples
144
136
// are not discarded.
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
151
143
152
144
int32_t feature_dim = ss[0 ]->FeatureDim ();
153
145
154
- std::vector<OnlineTransducerDecoderResult> result (n);
155
146
std::vector<float > features_vec (n * chunk_size * feature_dim);
156
147
std::vector<std::vector<Ort::Value>> encoder_states (n);
157
-
148
+
158
149
for (int32_t i = 0 ; i != n; ++i) {
159
150
const auto num_processed_frames = ss[i]->GetNumProcessedFrames ();
160
151
std::vector<float > features =
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
166
157
std::copy (features.begin (), features.end (),
167
158
features_vec.data () + i * chunk_size * feature_dim);
168
159
169
- result[i] = std::move (ss[i]->GetResult ());
170
160
encoder_states[i] = std::move (ss[i]->GetStates ());
171
-
172
161
}
173
162
174
163
auto memory_info =
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
180
169
features_vec.size (), x_shape.data (),
181
170
x_shape.size ());
182
171
183
- // Batch size is 1
184
- auto states = std::move (encoder_states[0 ]);
185
- int32_t num_states = states.size (); // num_states = 3
172
+ auto states = model_->StackStates (std::move (encoder_states));
173
+ int32_t num_states = states.size (); // num_states = 3
186
174
auto t = model_->RunEncoder (std::move (x), std::move (states));
187
175
// t[0] encoder_out, float tensor, (batch_size, dim, T)
188
176
// t[1] next states
189
-
177
+
190
178
std::vector<Ort::Value> out_states;
191
179
out_states.reserve (num_states);
192
-
180
+
193
181
for (int32_t k = 1 ; k != num_states + 1 ; ++k) {
194
182
out_states.push_back (std::move (t[k]));
195
183
}
196
184
185
+ auto unstacked_states = model_->UnStackStates (std::move (out_states));
186
+ for (int32_t i = 0 ; i != n; ++i) {
187
+ ss[i]->SetStates (std::move (unstacked_states[i]));
188
+ }
189
+
197
190
Ort::Value encoder_out = Transpose12 (model_->Allocator (), &t[0 ]);
198
-
199
- // defined in online-transducer-greedy-search-nemo-decoder.h
200
- // get intial states of decoder.
201
- std::vector<Ort::Value> &decoder_states = ss[0 ]->GetNeMoDecoderStates ();
202
-
203
- // Subsequent decoder states (for each chunks) are updated inside the Decode method.
204
- // This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
205
- decoder_states = decoder_->Decode (std::move (encoder_out),
206
- std::move (decoder_states),
207
- &result, ss, n);
208
-
209
- ss[0 ]->SetResult (result[0 ]);
210
-
211
- ss[0 ]->SetStates (std::move (out_states));
191
+
192
+ decoder_->Decode (std::move (encoder_out), ss, n);
212
193
}
213
194
214
195
void InitOnlineStream (OnlineStream *stream) const {
215
- auto r = decoder_->GetEmptyResult ();
196
+ // set encoder states
197
+ stream->SetStates (model_->GetEncoderInitStates ());
216
198
217
- stream-> SetResult (r);
218
- stream->SetNeMoDecoderStates (model_->GetDecoderInitStates (1 ));
199
+ // set decoder states
200
+ stream->SetNeMoDecoderStates (model_->GetDecoderInitStates ());
219
201
}
220
202
221
203
private:
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
250
232
symbol_table_.NumSymbols (), vocab_size);
251
233
exit (-1 );
252
234
}
253
-
254
235
}
255
236
256
237
private:
@@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
259
240
std::unique_ptr<OnlineTransducerNeMoModel> model_;
260
241
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
261
242
Endpoint endpoint_;
262
-
263
243
};
264
244
265
245
} // namespace sherpa_onnx
266
246
267
- #endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
247
+ #endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
0 commit comments