Commit d77d0fd

Merge pull request #468 from snakers4/adamnsandle

Adamnsandle

2 parents 4392725 + 49b421a, commit d77d0fd

File tree: 5 files changed, +79 -45 lines changed


README.md (+18 -7)
```diff
@@ -13,7 +13,7 @@
 <br/>

 <p align="center">
-  <img src="https://user-images.githubusercontent.com/12515440/228639780-876f7801-8ec5-4daf-89f3-b45b22dd1a73.png" />
+  <img src="https://github.com/snakers4/silero-vad/assets/36505480/300bd062-4da5-4f19-9736-9c144a45d7a7" />
 </p>

```
```diff
@@ -38,20 +38,16 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 - **Lightweight**

-  JIT model is around one megabyte in size.
+  JIT model is around two megabytes in size.

 - **General**

-  Silero VAD was trained on huge corpora that include over **100** languages and it performs well on audios from different domains with various background noise and quality levels.
+  Silero VAD was trained on huge corpora that include over **6000** languages and it performs well on audios from different domains with various background noise and quality levels.

 - **Flexible sampling rate**

   Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).

-- **Flexible chunk size**
-
-  Model was trained on **30 ms**. Longer chunks are supported directly, others may work as well.
-
 - **Highly Portable**

   Silero VAD reaps benefits from the rich ecosystems built around **PyTorch** and **ONNX** running everywhere where these runtimes are available.
```
````diff
@@ -60,6 +56,21 @@ https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-
 Published under permissive license (MIT) Silero VAD has zero strings attached - no telemetry, no keys, no registration, no built-in expiration, no keys or vendor lock.

+<br/>
+<h2 align="center">Fast start</h2>
+<br/>
+
+```python3
+import torch
+torch.set_num_threads(1)
+
+model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+(get_speech_timestamps, _, read_audio, _, _) = utils
+
+wav = read_audio('path_to_audio_file')
+speech_timestamps = get_speech_timestamps(wav, model)
+```
+
 <br/>
 <h2 align="center">Typical Use Cases</h2>
 <br/>
````
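The added snippet returns `speech_timestamps` as a list of dicts. A minimal follow-up sketch, assuming the usual `start`/`end` keys and using the `return_seconds` flag documented later in this diff:

```python3
# hypothetical follow-up to the Fast start snippet: report segments in seconds
speech_timestamps_sec = get_speech_timestamps(wav, model, return_seconds=True)
for segment in speech_timestamps_sec:
    print(f"speech from {segment['start']}s to {segment['end']}s")
```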

files/silero_vad.jit: 820 KB (binary file not shown)

files/silero_vad.onnx: 494 KB (binary file not shown)

silero-vad.ipynb (+29 -15)
```diff
@@ -46,7 +46,7 @@
    "USE_ONNX = False # change this to True if you want to test onnx model\n",
    "if USE_ONNX:\n",
    "  !pip install -q onnxruntime\n",
-   "  \n",
+   "\n",
    "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
    "                              model='silero_vad',\n",
    "                              force_reload=True,\n",
```
```diff
@@ -65,16 +65,7 @@
    "id": "fXbbaUO3jsrw"
   },
   "source": [
-   "## Full Audio"
-  ]
- },
- {
-  "cell_type": "markdown",
-  "metadata": {
-   "id": "RAfJPb_a-Auj"
-  },
-  "source": [
-   "**Speech timestapms from full audio**"
+   "## Speech timestamps from full audio"
   ]
  },
  {
```
```diff
@@ -101,10 +92,33 @@
   "source": [
    "# merge all speech chunks to one audio\n",
    "save_audio('only_speech.wav',\n",
-   "           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
+   "           collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
    "Audio('only_speech.wav')"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {
+   "id": "zeO1xCqxUC6w"
+  },
+  "source": [
+   "## Entire audio inference"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {
+   "id": "LjZBcsaTT7Mk"
+  },
+  "outputs": [],
+  "source": [
+   "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
+   "# audio is split into 512-sample chunks (256 at 8 kHz), 32 ms each,\n",
+   "# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
+   "predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
+  ]
+ },
```
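To make the new cell's comment concrete, a small sketch of the chunk arithmetic; the 2.5 s tensor is a made-up example, only the 512-sample chunking comes from the diff:

```python3
import math
import torch

SAMPLING_RATE = 16000
wav = torch.randn(1, 40000)  # hypothetical 2.5 s of 16 kHz audio

# audio_forward emits one probability per 512-sample (32 ms) chunk,
# i.e. 31.25 chunks per second; the tail is zero-padded, hence the ceil
num_chunks = math.ceil(wav.shape[1] / 512)
assert num_chunks == math.ceil(wav.shape[1] * 31.25 / SAMPLING_RATE)
print(num_chunks)  # 79
```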
```diff
@@ -124,10 +138,10 @@
   "source": [
    "## using VADIterator class\n",
    "\n",
-   "vad_iterator = VADIterator(model)\n",
+   "vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
    "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "\n",
-   "window_size_samples = 1536 # number of samples in a single audio chunk\n",
+   "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
    "for i in range(0, len(wav), window_size_samples):\n",
    "    chunk = wav[i: i+ window_size_samples]\n",
    "    if len(chunk) < window_size_samples:\n",
```
" if len(chunk) < window_size_samples:\n",
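The hunk cuts off inside the loop. A hedged sketch of how such a streaming loop typically continues; the `return_seconds` argument and the final `reset_states()` call are assumptions, not shown in this hunk:

```python3
# hypothetical continuation of the VADIterator streaming loop
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break  # drop the ragged tail; the model expects fixed-size chunks
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:  # non-empty only when a speech start or end is detected
        print(speech_dict)
vad_iterator.reset_states()  # reset internal state between audios
```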
```diff
@@ -150,7 +164,7 @@
    "\n",
    "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
    "speech_probs = []\n",
-   "window_size_samples = 1536\n",
+   "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
    "for i in range(0, len(wav), window_size_samples):\n",
    "    chunk = wav[i: i+ window_size_samples]\n",
    "    if len(chunk) < window_size_samples:\n",
```

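Likewise, the per-chunk probability loop is truncated. A sketch of a plausible loop body, calling the model directly per chunk as utils_vad.py below allows:

```python3
# hypothetical loop body: collect one speech probability per chunk
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break
    speech_prob = model(chunk, SAMPLING_RATE).item()
    speech_probs.append(speech_prob)
```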
utils_vad.py (+32 -23)
```diff
@@ -1,7 +1,6 @@
 import torch
 import torchaudio
 from typing import Callable, List
-import torch.nn.functional as F
 import warnings

 languages = ['ru', 'en', 'de', 'es']
```
```diff
@@ -39,22 +38,27 @@ def _validate_input(self, x, sr: int):

         if sr not in self.sample_rates:
             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
-
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")

         return x, sr

     def reset_states(self, batch_size=1):
-        self._h = np.zeros((2, batch_size, 64)).astype('float32')
-        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
         self._last_sr = 0
         self._last_batch_size = 0

     def __call__(self, x, sr: int):

         x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
         batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32

         if not self._last_batch_size:
             self.reset_states(batch_size)
```
```diff
@@ -63,28 +67,35 @@ def __call__(self, x, sr: int):
         if (self._last_batch_size) and (self._last_batch_size != batch_size):
             self.reset_states(batch_size)

+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
         if sr in [8000, 16000]:
-            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)}
             ort_outs = self.session.run(None, ort_inputs)
-            out, self._h, self._c = ort_outs
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
         else:
             raise ValueError()

+        self._context = x[..., -context_size:]
         self._last_sr = sr
         self._last_batch_size = batch_size

-        out = torch.tensor(out)
+        out = torch.from_numpy(out)
         return out

-    def audio_forward(self, x, sr: int, num_samples: int = 512):
+    def audio_forward(self, x, sr: int):
         outs = []
         x, sr = self._validate_input(x, sr)
+        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256

         if x.shape[1] % num_samples:
             pad_num = num_samples - (x.shape[1] % num_samples)
             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

-        self.reset_states(x.shape[0])
         for i in range(0, x.shape[1], num_samples):
             wavs_batch = x[:, i:i+num_samples]
             out_chunk = self.__call__(wavs_batch, sr)
```
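The new `_context` bookkeeping prepends the tail of the previous chunk to each input, so the ONNX session actually sees 576 samples at 16 kHz (512 new plus 64 of context). A standalone sketch of just that mechanism, with shapes taken from the diff and random tensors standing in for audio:

```python3
import torch

batch_size, num_samples, context_size = 1, 512, 64  # 16 kHz values from the diff
context = torch.zeros(batch_size, context_size)      # first call: silent context

for _ in range(3):  # three consecutive streaming chunks
    chunk = torch.randn(batch_size, num_samples)     # stand-in for real audio
    x = torch.cat([context, chunk], dim=1)           # model input: 576 samples
    context = x[..., -context_size:]                 # carry the tail forward
    print(x.shape)                                   # torch.Size([1, 576])
```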
```diff
@@ -179,11 +190,11 @@ def get_speech_timestamps(audio: torch.Tensor,
                           min_speech_duration_ms: int = 250,
                           max_speech_duration_s: float = float('inf'),
                           min_silence_duration_ms: int = 100,
-                          window_size_samples: int = 512,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
                           visualize_probs: bool = False,
-                          progress_tracking_callback: Callable[[float], None] = None):
+                          progress_tracking_callback: Callable[[float], None] = None,
+                          window_size_samples: int = 512,):

     """
     This method is used for splitting long audios into speech chunks using silero VAD
```
```diff
@@ -193,14 +204,14 @@
     audio: torch.Tensor, one dimensional
         One dimensional float torch.Tensor, other types are casted to torch if possible

-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model

     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
         It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

     sampling_rate: int (default - 16000)
-        Currently silero VAD models support 8000 and 16000 sample rates
+        Currently silero VAD models support 8000 and 16000 (or multiple of 16000) sample rates

     min_speech_duration_ms: int (default - 250 milliseconds)
         Final speech chunks shorter min_speech_duration_ms are thrown out
```
```diff
@@ -213,11 +224,6 @@
     min_silence_duration_ms: int (default - 100 milliseconds)
         In the end of each speech chunk wait for min_silence_duration_ms before separating it

-    window_size_samples: int (default - 1536 samples)
-        Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
-        Values other than these may affect model perfomance!!
-
     speech_pad_ms: int (default - 30 milliseconds)
         Final speech chunks are padded by speech_pad_ms each side
```
```diff
@@ -230,6 +236,9 @@
     progress_tracking_callback: Callable[[float], None] (default - None)
         callback function taking progress in percents as an argument

+    window_size_samples: int (default - 512 samples)
+        !!! DEPRECATED, DOES NOTHING !!!
+
     Returns
     ----------
     speeches: list of dicts
```
```diff
@@ -256,10 +265,10 @@
     else:
         step = 1

-    if sampling_rate == 8000 and window_size_samples > 768:
-        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
-    if window_size_samples not in [256, 512, 768, 1024, 1536]:
-        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
+    if sampling_rate not in [8000, 16000]:
+        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiple of 16000) sample rates")
+
+    window_size_samples = 512 if sampling_rate == 16000 else 256

     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
```
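Net effect of these changes: `window_size_samples` is still accepted for backward compatibility but is overwritten internally, so callers no longer control the chunk size. A hedged usage sketch, reusing `wav` and `model` from the Fast start snippet:

```python3
# both calls behave identically now: the function derives the window itself
# (512 samples at 16 kHz, 256 at 8 kHz) and ignores the deprecated argument
ts_a = get_speech_timestamps(wav, model, sampling_rate=16000)
ts_b = get_speech_timestamps(wav, model, sampling_rate=16000,
                             window_size_samples=1536)  # deprecated, no effect
assert ts_a == ts_b
```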
```diff
@@ -450,7 +459,7 @@ def __init__(self,

     Parameters
     ----------
-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model

     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
```
