@@ -61,25 +61,11 @@ class SileroVadModel::Impl {
61
61
#endif
62
62
63
63
void Reset () {
64
- // 2 - number of LSTM layer
65
- // 1 - batch size
66
- // 64 - hidden dim
67
- std::array<int64_t , 3 > shape{2 , 1 , 64 };
68
-
69
- Ort::Value h =
70
- Ort::Value::CreateTensor<float >(allocator_, shape.data (), shape.size ());
71
-
72
- Ort::Value c =
73
- Ort::Value::CreateTensor<float >(allocator_, shape.data (), shape.size ());
74
-
75
- Fill<float >(&h, 0 );
76
- Fill<float >(&c, 0 );
77
-
78
- states_.clear ();
79
-
80
- states_.reserve (2 );
81
- states_.push_back (std::move (h));
82
- states_.push_back (std::move (c));
64
+ if (is_v5_) {
65
+ ResetV5 ();
66
+ } else {
67
+ ResetV4 ();
68
+ }
83
69
84
70
triggered_ = false ;
85
71
current_sample_ = 0 ;
@@ -94,31 +80,7 @@ class SileroVadModel::Impl {
94
80
exit (-1 );
95
81
}
96
82
97
- auto memory_info =
98
- Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
99
-
100
- std::array<int64_t , 2 > x_shape = {1 , n};
101
-
102
- Ort::Value x =
103
- Ort::Value::CreateTensor (memory_info, const_cast <float *>(samples), n,
104
- x_shape.data (), x_shape.size ());
105
-
106
- int64_t sr_shape = 1 ;
107
- Ort::Value sr =
108
- Ort::Value::CreateTensor (memory_info, &sample_rate_, 1 , &sr_shape, 1 );
109
-
110
- std::array<Ort::Value, 4 > inputs = {std::move (x), std::move (sr),
111
- std::move (states_[0 ]),
112
- std::move (states_[1 ])};
113
-
114
- auto out =
115
- sess_->Run ({}, input_names_ptr_.data (), inputs.data (), inputs.size (),
116
- output_names_ptr_.data (), output_names_ptr_.size ());
117
-
118
- states_[0 ] = std::move (out[1 ]);
119
- states_[1 ] = std::move (out[2 ]);
120
-
121
- float prob = out[0 ].GetTensorData <float >()[0 ];
83
+ float prob = Run (samples, n);
122
84
123
85
float threshold = config_.silero_vad .threshold ;
124
86
@@ -186,6 +148,8 @@ class SileroVadModel::Impl {
186
148
187
149
// Returns the window size taken from config_.silero_vad.window_size.
int32_t WindowSize() const {
  return config_.silero_vad.window_size;
}
188
150
151
+ int32_t WindowShift () const { return WindowSize () - window_shift_; }
152
+
189
153
// Returns the cached min_silence_samples_ value.
int32_t MinSilenceDurationSamples() const {
  return min_silence_samples_;
}
190
154
191
155
// Returns the cached min_speech_samples_ value.
int32_t MinSpeechDurationSamples() const {
  return min_speech_samples_;
}
@@ -205,12 +169,76 @@ class SileroVadModel::Impl {
205
169
206
170
GetInputNames (sess_.get (), &input_names_, &input_names_ptr_);
207
171
GetOutputNames (sess_.get (), &output_names_, &output_names_ptr_);
172
+
173
+ if (input_names_.size () == 4 && output_names_.size () == 3 ) {
174
+ is_v5_ = false ;
175
+ } else if (input_names_.size () == 3 && output_names_.size () == 2 ) {
176
+ is_v5_ = true ;
177
+
178
+ // 64 for 16kHz
179
+ // 32 for 8kHz
180
+ window_shift_ = 64 ;
181
+
182
+ if (WindowSize () != 512 ) {
183
+ SHERPA_ONNX_LOGE (
184
+ " For silero_vad v5, we require window_size to be 512 for 16kHz" );
185
+ exit (-1 );
186
+ }
187
+ } else {
188
+ SHERPA_ONNX_LOGE (" Unsupported silero vad model" );
189
+ exit (-1 );
190
+ }
191
+
208
192
Check ();
209
193
210
194
Reset ();
211
195
}
212
196
213
- void Check () {
197
+ void ResetV5 () {
198
+ // 2 - number of LSTM layer
199
+ // 1 - batch size
200
+ // 128 - hidden dim
201
+ std::array<int64_t , 3 > shape{2 , 1 , 128 };
202
+
203
+ Ort::Value s =
204
+ Ort::Value::CreateTensor<float >(allocator_, shape.data (), shape.size ());
205
+
206
+ Fill<float >(&s, 0 );
207
+ states_.clear ();
208
+ states_.push_back (std::move (s));
209
+ }
210
+
211
+ void ResetV4 () {
212
+ // 2 - number of LSTM layer
213
+ // 1 - batch size
214
+ // 64 - hidden dim
215
+ std::array<int64_t , 3 > shape{2 , 1 , 64 };
216
+
217
+ Ort::Value h =
218
+ Ort::Value::CreateTensor<float >(allocator_, shape.data (), shape.size ());
219
+
220
+ Ort::Value c =
221
+ Ort::Value::CreateTensor<float >(allocator_, shape.data (), shape.size ());
222
+
223
+ Fill<float >(&h, 0 );
224
+ Fill<float >(&c, 0 );
225
+
226
+ states_.clear ();
227
+
228
+ states_.reserve (2 );
229
+ states_.push_back (std::move (h));
230
+ states_.push_back (std::move (c));
231
+ }
232
+
233
+ void Check () const {
234
+ if (is_v5_) {
235
+ CheckV5 ();
236
+ } else {
237
+ CheckV4 ();
238
+ }
239
+ }
240
+
241
+ void CheckV4 () const {
214
242
if (input_names_.size () != 4 ) {
215
243
SHERPA_ONNX_LOGE (" Expect 4 inputs. Given: %d" ,
216
244
static_cast <int32_t >(input_names_.size ()));
@@ -262,6 +290,114 @@ class SileroVadModel::Impl {
262
290
}
263
291
}
264
292
293
+ void CheckV5 () const {
294
+ if (input_names_.size () != 3 ) {
295
+ SHERPA_ONNX_LOGE (" Expect 3 inputs. Given: %d" ,
296
+ static_cast <int32_t >(input_names_.size ()));
297
+ exit (-1 );
298
+ }
299
+
300
+ if (input_names_[0 ] != " input" ) {
301
+ SHERPA_ONNX_LOGE (" Input[0]: %s. Expected: input" ,
302
+ input_names_[0 ].c_str ());
303
+ exit (-1 );
304
+ }
305
+
306
+ if (input_names_[1 ] != " state" ) {
307
+ SHERPA_ONNX_LOGE (" Input[1]: %s. Expected: state" ,
308
+ input_names_[1 ].c_str ());
309
+ exit (-1 );
310
+ }
311
+
312
+ if (input_names_[2 ] != " sr" ) {
313
+ SHERPA_ONNX_LOGE (" Input[2]: %s. Expected: sr" , input_names_[2 ].c_str ());
314
+ exit (-1 );
315
+ }
316
+
317
+ // Now for outputs
318
+ if (output_names_.size () != 2 ) {
319
+ SHERPA_ONNX_LOGE (" Expect 2 outputs. Given: %d" ,
320
+ static_cast <int32_t >(output_names_.size ()));
321
+ exit (-1 );
322
+ }
323
+
324
+ if (output_names_[0 ] != " output" ) {
325
+ SHERPA_ONNX_LOGE (" Output[0]: %s. Expected: output" ,
326
+ output_names_[0 ].c_str ());
327
+ exit (-1 );
328
+ }
329
+
330
+ if (output_names_[1 ] != " stateN" ) {
331
+ SHERPA_ONNX_LOGE (" Output[1]: %s. Expected: stateN" ,
332
+ output_names_[1 ].c_str ());
333
+ exit (-1 );
334
+ }
335
+ }
336
+
337
+ float Run (const float *samples, int32_t n) {
338
+ if (is_v5_) {
339
+ return RunV5 (samples, n);
340
+ } else {
341
+ return RunV4 (samples, n);
342
+ }
343
+ }
344
+
345
+ float RunV5 (const float *samples, int32_t n) {
346
+ auto memory_info =
347
+ Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
348
+
349
+ std::array<int64_t , 2 > x_shape = {1 , n};
350
+
351
+ Ort::Value x =
352
+ Ort::Value::CreateTensor (memory_info, const_cast <float *>(samples), n,
353
+ x_shape.data (), x_shape.size ());
354
+
355
+ int64_t sr_shape = 1 ;
356
+ Ort::Value sr =
357
+ Ort::Value::CreateTensor (memory_info, &sample_rate_, 1 , &sr_shape, 1 );
358
+
359
+ std::array<Ort::Value, 3 > inputs = {std::move (x), std::move (states_[0 ]),
360
+ std::move (sr)};
361
+
362
+ auto out =
363
+ sess_->Run ({}, input_names_ptr_.data (), inputs.data (), inputs.size (),
364
+ output_names_ptr_.data (), output_names_ptr_.size ());
365
+
366
+ states_[0 ] = std::move (out[1 ]);
367
+
368
+ float prob = out[0 ].GetTensorData <float >()[0 ];
369
+ return prob;
370
+ }
371
+
372
+ float RunV4 (const float *samples, int32_t n) {
373
+ auto memory_info =
374
+ Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
375
+
376
+ std::array<int64_t , 2 > x_shape = {1 , n};
377
+
378
+ Ort::Value x =
379
+ Ort::Value::CreateTensor (memory_info, const_cast <float *>(samples), n,
380
+ x_shape.data (), x_shape.size ());
381
+
382
+ int64_t sr_shape = 1 ;
383
+ Ort::Value sr =
384
+ Ort::Value::CreateTensor (memory_info, &sample_rate_, 1 , &sr_shape, 1 );
385
+
386
+ std::array<Ort::Value, 4 > inputs = {std::move (x), std::move (sr),
387
+ std::move (states_[0 ]),
388
+ std::move (states_[1 ])};
389
+
390
+ auto out =
391
+ sess_->Run ({}, input_names_ptr_.data (), inputs.data (), inputs.size (),
392
+ output_names_ptr_.data (), output_names_ptr_.size ());
393
+
394
+ states_[0 ] = std::move (out[1 ]);
395
+ states_[1 ] = std::move (out[2 ]);
396
+
397
+ float prob = out[0 ].GetTensorData <float >()[0 ];
398
+ return prob;
399
+ }
400
+
265
401
private:
266
402
VadModelConfig config_;
267
403
@@ -286,6 +422,10 @@ class SileroVadModel::Impl {
286
422
int32_t current_sample_ = 0 ;
287
423
int32_t temp_start_ = 0 ;
288
424
int32_t temp_end_ = 0 ;
425
+
426
+ int32_t window_shift_ = 0 ;
427
+
428
+ bool is_v5_ = false ;
289
429
};
290
430
291
431
SileroVadModel::SileroVadModel (const VadModelConfig &config)
@@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) {
306
446
307
447
// Forwards to the implementation object's WindowSize().
int32_t SileroVadModel::WindowSize() const {
  return impl_->WindowSize();
}
308
448
449
+ int32_t SileroVadModel::WindowShift () const { return impl_->WindowShift (); }
450
+
309
451
int32_t SileroVadModel::MinSilenceDurationSamples () const {
310
452
return impl_->MinSilenceDurationSamples ();
311
453
}
0 commit comments