@@ -1,7 +1,6 @@
 import torch
 import torchaudio
 from typing import Callable, List
-import torch.nn.functional as F
 import warnings

 languages = ['ru', 'en', 'de', 'es']
@@ -39,22 +38,27 @@ def _validate_input(self, x, sr: int):

         if sr not in self.sample_rates:
             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiple of 16000)")
-
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")

         return x, sr

     def reset_states(self, batch_size=1):
-        self._h = np.zeros((2, batch_size, 64)).astype('float32')
-        self._c = np.zeros((2, batch_size, 64)).astype('float32')
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
         self._last_sr = 0
         self._last_batch_size = 0

     def __call__(self, x, sr: int):

         x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[-1] != num_samples:
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
         batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32

         if not self._last_batch_size:
             self.reset_states(batch_size)
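The two LSTM tensors `h` and `c` (each (2, batch, 64)) are merged here into a single (2, batch, 128) `state` tensor, and a new rolling `context` buffer carries the tail of each chunk into the next one. A minimal shape-level sketch of the new bookkeeping, assuming 16 kHz and the attribute names from this diff:

import torch

batch_size, num_samples, context_size = 1, 512, 64

state = torch.zeros((2, batch_size, 128)).float()  # replaces h and c, each (2, batch, 64)
context = torch.zeros(batch_size, context_size)    # last 64 samples of the previous chunk

chunk = torch.randn(batch_size, num_samples)       # one 512-sample window at 16 kHz
model_input = torch.cat([context, chunk], dim=1)   # (1, 576): context is prepended
context = model_input[..., -context_size:]         # carried over to the next call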
@@ -63,28 +67,35 @@ def __call__(self, x, sr: int):
         if (self._last_batch_size) and (self._last_batch_size != batch_size):
             self.reset_states(batch_size)

+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
         if sr in [8000, 16000]:
-            ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)}
             ort_outs = self.session.run(None, ort_inputs)
-            out, self._h, self._c = ort_outs
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
         else:
             raise ValueError()

+        self._context = x[..., -context_size:]
         self._last_sr = sr
         self._last_batch_size = batch_size

-        out = torch.tensor(out)
+        out = torch.from_numpy(out)
         return out

-    def audio_forward(self, x, sr: int, num_samples: int = 512):
+    def audio_forward(self, x, sr: int):
         outs = []
         x, sr = self._validate_input(x, sr)
+        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256

         if x.shape[1] % num_samples:
             pad_num = num_samples - (x.shape[1] % num_samples)
             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

-        self.reset_states(x.shape[0])
         for i in range(0, x.shape[1], num_samples):
             wavs_batch = x[:, i:i+num_samples]
             out_chunk = self.__call__(wavs_batch, sr)
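With `audio_forward` now resetting state itself and deriving the window from the sample rate, manual chunked inference reduces to the loop below; a sketch assuming `model` is a preloaded instance of the wrapper patched above:

import torch

sr = 16000
audio = torch.randn(1, sr * 3)                   # 3 s of synthetic audio, batch of 1

model.reset_states()                             # `model` is an assumed preloaded wrapper
probs = []
for i in range(0, audio.shape[1], 512):          # window is now fixed: 512 @ 16 kHz
    chunk = audio[:, i:i + 512]
    if chunk.shape[-1] < 512:                    # __call__ rejects short chunks,
        chunk = torch.nn.functional.pad(chunk, (0, 512 - chunk.shape[-1]))  # so pad the tail
    probs.append(model(chunk, sr))               # one speech probability per chunk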
@@ -179,11 +190,11 @@ def get_speech_timestamps(audio: torch.Tensor,
                           min_speech_duration_ms: int = 250,
                           max_speech_duration_s: float = float('inf'),
                           min_silence_duration_ms: int = 100,
-                          window_size_samples: int = 512,
                           speech_pad_ms: int = 30,
                           return_seconds: bool = False,
                           visualize_probs: bool = False,
-                          progress_tracking_callback: Callable[[float], None] = None):
+                          progress_tracking_callback: Callable[[float], None] = None,
+                          window_size_samples: int = 512,):

     """
     This method is used for splitting long audios into speech chunks using silero VAD
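Note that `window_size_samples` moved from the middle of the signature to its end, so purely positional call sites change meaning: the 8th positional argument used to be `window_size_samples` and is now `speech_pad_ms`. Passing it by keyword, as in this sketch with hypothetical `wav` and `model` placeholders, stays compatible in both directions:

timestamps = get_speech_timestamps(wav, model,              # hypothetical preloaded inputs
                                   min_silence_duration_ms=100,
                                   window_size_samples=512)  # deprecated: accepted but ignored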
@@ -193,14 +204,14 @@ def get_speech_timestamps(audio: torch.Tensor,
     audio: torch.Tensor, one dimensional
         One dimensional float torch.Tensor; other types are cast to torch if possible

-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model

     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
         It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

     sampling_rate: int (default - 16000)
-        Currently silero VAD models support 8000 and 16000 sample rates
+        Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates

     min_speech_duration_ms: int (default - 250 milliseconds)
         Final speech chunks shorter than min_speech_duration_ms are thrown out
@@ -213,11 +224,6 @@ def get_speech_timestamps(audio: torch.Tensor,
     min_silence_duration_ms: int (default - 100 milliseconds)
         At the end of each speech chunk, wait for min_silence_duration_ms before separating it

-    window_size_samples: int (default - 1536 samples)
-        Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
-        Values other than these may affect model perfomance!!
-
     speech_pad_ms: int (default - 30 milliseconds)
         Final speech chunks are padded by speech_pad_ms on each side

@@ -230,6 +236,9 @@ def get_speech_timestamps(audio: torch.Tensor,
     progress_tracking_callback: Callable[[float], None] (default - None)
         callback function taking progress in percent as an argument

+    window_size_samples: int (default - 512 samples)
+        !!! DEPRECATED, DOES NOTHING !!!
+
     Returns
     ----------
     speeches: list of dicts
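Putting the documented parameters together, a typical call now omits the window entirely. The hub entry point below follows the silero-vad README and is an assumption, not part of this diff:

import torch

# load the ONNX model plus helpers; tuple order per the silero-vad README
model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad', onnx=True)
get_speech_timestamps, _, read_audio, _, _ = utils

wav = read_audio('example.wav', sampling_rate=16000)   # hypothetical input file
speech_timestamps = get_speech_timestamps(wav, model,
                                          threshold=0.5,
                                          sampling_rate=16000,
                                          return_seconds=True)
print(speech_timestamps)  # e.g. [{'start': 0.5, 'end': 2.1}, ...]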
@@ -256,10 +265,10 @@ def get_speech_timestamps(audio: torch.Tensor,
     else:
         step = 1

-    if sampling_rate == 8000 and window_size_samples > 768:
-        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
-    if window_size_samples not in [256, 512, 768, 1024, 1536]:
-        warnings.warn('Unusual window_size_samples! Supported window_size_samples: \n - [512, 1024, 1536] for 16000 sampling_rate \n - [256, 512, 768] for 8000 sampling_rate')
+    if sampling_rate not in [8000, 16000]:
+        raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiples of 16000) sample rates")
+
+    window_size_samples = 512 if sampling_rate == 16000 else 256

     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
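For concreteness, the constants these added lines pin down at the two supported rates (plain arithmetic restating the code above, not library logic): both fixed windows span 32 ms, and the default 250 ms minimum speech length works out to 4000 samples at 16 kHz and 2000 at 8 kHz:

for sr, window in [(16000, 512), (8000, 256)]:
    assert window / sr == 0.032       # both fixed windows are 32 ms long
    assert sr * 250 / 1000 == sr / 4  # min_speech_samples at the 250 ms default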
@@ -450,7 +459,7 @@ def __init__(self,

     Parameters
     ----------
-    model: preloaded .jit silero VAD model
+    model: preloaded .jit/.onnx silero VAD model

     threshold: float (default - 0.5)
         Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
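Since this docstring now covers .onnx models too, the same ONNX model works with the streaming VADIterator defined later in this file. A sketch reusing the hub-loaded `model` and `wav` from the earlier example, with the constructor and call signatures assumed from the docstring:

vad_iterator = VADIterator(model, threshold=0.5, sampling_rate=16000)

for i in range(0, len(wav), 512):
    chunk = wav[i:i + 512]
    if len(chunk) < 512:
        break                                   # the model rejects short tail chunks
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict)                      # {'start': ...} or {'end': ...}

vad_iterator.reset_states()                     # ready for the next audio stream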