Skip to content

Commit e8f7138

Browse files
authored
Support silero_vad version 5 (k2-fsa#1064)
1 parent 98d10a2 commit e8f7138

File tree

6 files changed

+203
-50
lines changed

6 files changed

+203
-50
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ project(sherpa-onnx)
88
# ./nodejs-addon-examples
99
# ./dart-api-examples/
1010
# ./sherpa-onnx/flutter/CHANGELOG.md
11-
set(SHERPA_ONNX_VERSION "1.10.5")
11+
set(SHERPA_ONNX_VERSION "1.10.6")
1212

1313
# Disable warning about
1414
#

nodejs-addon-examples/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
22
"dependencies": {
3-
"sherpa-onnx-node": "^1.10.3"
3+
"sherpa-onnx-node": "^1.10.6"
44
}
55
}

sherpa-onnx/csrc/silero-vad-model.cc

+187-45
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,11 @@ class SileroVadModel::Impl {
6161
#endif
6262

6363
void Reset() {
64-
// 2 - number of LSTM layer
65-
// 1 - batch size
66-
// 64 - hidden dim
67-
std::array<int64_t, 3> shape{2, 1, 64};
68-
69-
Ort::Value h =
70-
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
71-
72-
Ort::Value c =
73-
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
74-
75-
Fill<float>(&h, 0);
76-
Fill<float>(&c, 0);
77-
78-
states_.clear();
79-
80-
states_.reserve(2);
81-
states_.push_back(std::move(h));
82-
states_.push_back(std::move(c));
64+
if (is_v5_) {
65+
ResetV5();
66+
} else {
67+
ResetV4();
68+
}
8369

8470
triggered_ = false;
8571
current_sample_ = 0;
@@ -94,31 +80,7 @@ class SileroVadModel::Impl {
9480
exit(-1);
9581
}
9682

97-
auto memory_info =
98-
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
99-
100-
std::array<int64_t, 2> x_shape = {1, n};
101-
102-
Ort::Value x =
103-
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
104-
x_shape.data(), x_shape.size());
105-
106-
int64_t sr_shape = 1;
107-
Ort::Value sr =
108-
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
109-
110-
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
111-
std::move(states_[0]),
112-
std::move(states_[1])};
113-
114-
auto out =
115-
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
116-
output_names_ptr_.data(), output_names_ptr_.size());
117-
118-
states_[0] = std::move(out[1]);
119-
states_[1] = std::move(out[2]);
120-
121-
float prob = out[0].GetTensorData<float>()[0];
83+
float prob = Run(samples, n);
12284

12385
float threshold = config_.silero_vad.threshold;
12486

@@ -186,6 +148,8 @@ class SileroVadModel::Impl {
186148

187149
// Returns the window size taken from config_.silero_vad.window_size
// (presumably a number of samples -- confirm against the config docs).
int32_t WindowSize() const { return config_.silero_vad.window_size; }
188150

151+
// Returns WindowSize() minus window_shift_.
// NOTE(review): despite its name, window_shift_ is subtracted from the
// window size here, so it behaves like the overlap between consecutive
// windows. When it is 0 the returned shift equals WindowSize().
int32_t WindowShift() const { return WindowSize() - window_shift_; }
152+
189153
// Returns the cached min_silence_samples_ value.
int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
190154

191155
// Returns the cached min_speech_samples_ value.
int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
@@ -205,12 +169,76 @@ class SileroVadModel::Impl {
205169

206170
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
207171
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
172+
173+
if (input_names_.size() == 4 && output_names_.size() == 3) {
174+
is_v5_ = false;
175+
} else if (input_names_.size() == 3 && output_names_.size() == 2) {
176+
is_v5_ = true;
177+
178+
// 64 for 16kHz
179+
// 32 for 8kHz
180+
window_shift_ = 64;
181+
182+
if (WindowSize() != 512) {
183+
SHERPA_ONNX_LOGE(
184+
"For silero_vad v5, we require window_size to be 512 for 16kHz");
185+
exit(-1);
186+
}
187+
} else {
188+
SHERPA_ONNX_LOGE("Unsupported silero vad model");
189+
exit(-1);
190+
}
191+
208192
Check();
209193

210194
Reset();
211195
}
212196

213-
void Check() {
197+
// Re-creates the model state used by the v5 code path: a single
// zero-filled float tensor of shape {2, 1, 128}. After this call states_
// holds exactly that one tensor.
void ResetV5() {
198+
// 2 - number of LSTM layer
199+
// 1 - batch size
200+
// 128 - hidden dim
201+
std::array<int64_t, 3> shape{2, 1, 128};
202+
203+
Ort::Value s =
204+
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
205+
206+
Fill<float>(&s, 0);
207+
// Drop whatever state was stored before and keep only the fresh tensor.
states_.clear();
208+
states_.push_back(std::move(s));
209+
}
210+
211+
// Re-creates the model state used by the v4 code path: two zero-filled
// float tensors (h and c), each of shape {2, 1, 64}. After this call
// states_ holds them in the order {h, c}.
void ResetV4() {
212+
// 2 - number of LSTM layer
213+
// 1 - batch size
214+
// 64 - hidden dim
215+
std::array<int64_t, 3> shape{2, 1, 64};
216+
217+
Ort::Value h =
218+
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
219+
220+
Ort::Value c =
221+
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
222+
223+
Fill<float>(&h, 0);
224+
Fill<float>(&c, 0);
225+
226+
states_.clear();
227+
228+
// Order matters: RunV4 passes states_[0] and states_[1] as the third and
// fourth model inputs.
states_.reserve(2);
229+
states_.push_back(std::move(h));
230+
states_.push_back(std::move(c));
231+
}
232+
233+
// Validates the loaded model's input/output names by dispatching to
// CheckV5() when is_v5_ is set and to CheckV4() otherwise.
void Check() const {
234+
if (is_v5_) {
235+
CheckV5();
236+
} else {
237+
CheckV4();
238+
}
239+
}
240+
241+
void CheckV4() const {
214242
if (input_names_.size() != 4) {
215243
SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
216244
static_cast<int32_t>(input_names_.size()));
@@ -262,6 +290,114 @@ class SileroVadModel::Impl {
262290
}
263291
}
264292

293+
// Verifies that the model exposes the names expected on the v5 path:
//   inputs:  "input", "state", "sr"   (exactly 3, in this order)
//   outputs: "output", "stateN"       (exactly 2, in this order)
// On the first mismatch it logs an error via SHERPA_ONNX_LOGE and
// terminates the process with exit(-1); it does not return an error.
void CheckV5() const {
294+
if (input_names_.size() != 3) {
295+
SHERPA_ONNX_LOGE("Expect 3 inputs. Given: %d",
296+
static_cast<int32_t>(input_names_.size()));
297+
exit(-1);
298+
}
299+
300+
if (input_names_[0] != "input") {
301+
SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input",
302+
input_names_[0].c_str());
303+
exit(-1);
304+
}
305+
306+
if (input_names_[1] != "state") {
307+
SHERPA_ONNX_LOGE("Input[1]: %s. Expected: state",
308+
input_names_[1].c_str());
309+
exit(-1);
310+
}
311+
312+
if (input_names_[2] != "sr") {
313+
SHERPA_ONNX_LOGE("Input[2]: %s. Expected: sr", input_names_[2].c_str());
314+
exit(-1);
315+
}
316+
317+
// Now for outputs
318+
if (output_names_.size() != 2) {
319+
SHERPA_ONNX_LOGE("Expect 2 outputs. Given: %d",
320+
static_cast<int32_t>(output_names_.size()));
321+
exit(-1);
322+
}
323+
324+
if (output_names_[0] != "output") {
325+
SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output",
326+
output_names_[0].c_str());
327+
exit(-1);
328+
}
329+
330+
if (output_names_[1] != "stateN") {
331+
SHERPA_ONNX_LOGE("Output[1]: %s. Expected: stateN",
332+
output_names_[1].c_str());
333+
exit(-1);
334+
}
335+
}
336+
337+
// Runs the model on n samples and returns the first value of its first
// output, delegating to RunV5() when is_v5_ is set and to RunV4()
// otherwise.
//
// @param samples Pointer to the input samples; n values are read.
// @param n       Number of samples in the window.
// @return The value the caller stores as `prob` (presumably a speech
//         probability -- it is compared against the configured threshold).
float Run(const float *samples, int32_t n) {
338+
if (is_v5_) {
339+
return RunV5(samples, n);
340+
} else {
341+
return RunV4(samples, n);
342+
}
343+
}
344+
345+
// Runs one inference step on the v5 path.
//
// Builds a {1, n} float tensor over `samples` and a 1-element tensor over
// sample_rate_, runs the session with inputs {x, state, sr}, stores the
// second output back into states_[0], and returns the first element of
// the first output.
//
// @param samples Pointer to n input samples.
// @param n       Number of samples.
// @return First float of the model's first output.
float RunV5(const float *samples, int32_t n) {
346+
auto memory_info =
347+
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
348+
349+
std::array<int64_t, 2> x_shape = {1, n};
350+
351+
// NOTE(review): the const_cast is presumably needed because this
// CreateTensor overload takes a non-const data pointer; the tensor is only
// used as a model input here.
Ort::Value x =
352+
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
353+
x_shape.data(), x_shape.size());
354+
355+
int64_t sr_shape = 1;
356+
Ort::Value sr =
357+
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
358+
359+
// Input order {x, state, sr} matches the names CheckV5() requires
// ("input", "state", "sr"). states_[0] is moved from here and is
// re-assigned from the model output below.
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states_[0]),
360+
std::move(sr)};
361+
362+
auto out =
363+
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
364+
output_names_ptr_.data(), output_names_ptr_.size());
365+
366+
// out[1] is the updated state; carry it over to the next call.
states_[0] = std::move(out[1]);
367+
368+
float prob = out[0].GetTensorData<float>()[0];
369+
return prob;
370+
}
371+
372+
// Runs one inference step on the v4 path.
//
// Builds a {1, n} float tensor over `samples` and a 1-element tensor over
// sample_rate_, runs the session with inputs {x, sr, states_[0],
// states_[1]}, stores the second and third outputs back into states_[0]
// and states_[1], and returns the first element of the first output.
//
// @param samples Pointer to n input samples.
// @param n       Number of samples.
// @return First float of the model's first output.
float RunV4(const float *samples, int32_t n) {
373+
auto memory_info =
374+
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
375+
376+
std::array<int64_t, 2> x_shape = {1, n};
377+
378+
// NOTE(review): the const_cast is presumably needed because this
// CreateTensor overload takes a non-const data pointer; the tensor is only
// used as a model input here.
Ort::Value x =
379+
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
380+
x_shape.data(), x_shape.size());
381+
382+
int64_t sr_shape = 1;
383+
Ort::Value sr =
384+
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
385+
386+
// Both state tensors are moved from here and re-assigned from the model
// outputs below.
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
387+
std::move(states_[0]),
388+
std::move(states_[1])};
389+
390+
auto out =
391+
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
392+
output_names_ptr_.data(), output_names_ptr_.size());
393+
394+
// out[1] and out[2] are the updated states for the next call.
states_[0] = std::move(out[1]);
395+
states_[1] = std::move(out[2]);
396+
397+
float prob = out[0].GetTensorData<float>()[0];
398+
return prob;
399+
}
400+
265401
private:
266402
VadModelConfig config_;
267403

@@ -286,6 +422,10 @@ class SileroVadModel::Impl {
286422
int32_t current_sample_ = 0;
287423
int32_t temp_start_ = 0;
288424
int32_t temp_end_ = 0;
425+
426+
int32_t window_shift_ = 0;
427+
428+
bool is_v5_ = false;
289429
};
290430

291431
SileroVadModel::SileroVadModel(const VadModelConfig &config)
@@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) {
306446

307447
// Forwards to the implementation object's WindowSize().
int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); }
308448

449+
// Forwards to the implementation object's WindowShift().
int32_t SileroVadModel::WindowShift() const { return impl_->WindowShift(); }
450+
309451
// Forwards to the implementation object's MinSilenceDurationSamples().
int32_t SileroVadModel::MinSilenceDurationSamples() const {
310452
return impl_->MinSilenceDurationSamples();
311453
}

sherpa-onnx/csrc/silero-vad-model.h

+5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ class SileroVadModel : public VadModel {
3939

4040
int32_t WindowSize() const override;
4141

42+
// For silero vad V4, it is WindowSize().
43+
// For silero vad V5, it is WindowSize()-64 for 16kHz and
44+
// WindowSize()-32 for 8kHz
45+
int32_t WindowShift() const override;
46+
4247
int32_t MinSilenceDurationSamples() const override;
4348
int32_t MinSpeechDurationSamples() const override;
4449

sherpa-onnx/csrc/vad-model.h

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class VadModel {
4040

4141
virtual int32_t WindowSize() const = 0;
4242

43+
virtual int32_t WindowShift() const = 0;
44+
4345
virtual int32_t MinSilenceDurationSamples() const = 0;
4446
virtual int32_t MinSpeechDurationSamples() const = 0;
4547
virtual void SetMinSilenceDuration(float s) = 0;

sherpa-onnx/csrc/voice-activity-detector.cc

+7-3
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl {
3838
}
3939

4040
int32_t window_size = model_->WindowSize();
41+
int32_t window_shift = model_->WindowShift();
4142

4243
// note n is usually window_size and there is no need to use
4344
// an extra buffer here
4445
last_.insert(last_.end(), samples, samples + n);
45-
int32_t k = static_cast<int32_t>(last_.size()) / window_size;
46+
47+
// Note: For v4, window_shift == window_size
48+
int32_t k =
49+
(static_cast<int32_t>(last_.size()) - window_size) / window_shift + 1;
4650
const float *p = last_.data();
4751
bool is_speech = false;
4852

49-
for (int32_t i = 0; i != k; ++i, p += window_size) {
50-
buffer_.Push(p, window_size);
53+
for (int32_t i = 0; i != k; ++i, p += window_shift) {
54+
buffer_.Push(p, window_shift);
5155
// NOTE(fangjun): Please don't use a very large n.
5256
bool this_window_is_speech = model_->IsSpeech(p, window_size);
5357
is_speech = is_speech || this_window_is_speech;

0 commit comments

Comments
 (0)