-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasr_module.py
297 lines (256 loc) · 11.7 KB
/
asr_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# asr_module.py
import copy
import logging
import os
import queue
import tempfile
import threading
import wave

import numpy as np
import pyaudio
from PyQt5.QtCore import pyqtSignal, QObject
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from silero_vad import load_silero_vad, get_speech_timestamps

from logging_config import gcww
from vpr_module import VoicePrintRecognition
from vits_module import vitsSpeaker

# Configure module logger
logger = logging.getLogger("asr_module")
class SpeechRecognition(QObject):
    """Streaming speech recognition (ASR) pipeline.

    A producer thread captures microphone audio into a queue; a consumer
    thread runs Silero VAD and voiceprint matching over sliding windows,
    transcribes buffered speech with a SenseVoice model, and publishes
    results through Qt signals.
    """

    # Qt signals
    update_text_signal = pyqtSignal(tuple)  # carries (audio frames, recognized text)
    recording_ended_signal = pyqtSignal()  # emitted when the pending result should be sent
    detect_speech_signal = pyqtSignal(bool)  # emitted with the current speech-detection state
    # Voiceprint-registration flow signal
    open_vp_register_signal = pyqtSignal(bool)  # toggles "registered users only" recognition

    def __init__(self, main_settings):
        """Initialize the speech recognizer.

        Args:
            main_settings (dict): configuration dictionary loaded from file
        """
        super().__init__()
        self.settings = main_settings
        # Seconds of silence after which an unsent transcription is auto-sent
        # (a value of -1 disables auto-send; see audio_consumer).
        self.asr_auto_send_silence_time = gcww(
            self.settings, "asr_auto_send_silence_time", 2.7, logger
        )
        model_dir = gcww(self.settings, "asr_model_dir", "./SenseVoiceSmall", logger)
        # Recording parameters: 16 kHz mono 16-bit PCM in 1024-sample chunks.
        self.FORMAT = pyaudio.paInt16
        self.CHANNELS = 1
        self.RATE = 16000
        self.CHUNK = 1024
        # Silero VAD model for voice-activity detection.
        self.vad_model = load_silero_vad()
        # SenseVoice ASR model.
        self.asr_model = AutoModel(
            model=model_dir,
            trust_remote_code=True,
            device="cuda:0",  # GPU by default
        )
        # Voiceprint manager (speaker identification).
        self.vpr_manager = VoicePrintRecognition(main_settings)
        # TTS speaker; microphone buffering is paused while its audio plays so
        # the recognizer does not transcribe its own synthesized speech.
        self.vits_speaker = vitsSpeaker(main_settings)
        self.vits_speaker.audio_start_play.connect(self.on_audio_start_play)
        self.vits_speaker.audio_played.connect(self.on_audio_played)
        # State flags.
        self._is_running = False  # worker threads loop while True
        self.audio_buffer_startup = True  # True -> consumer appends frames to audio_buffer
        self.transcribe_but_not_send = False  # a transcription exists but was not sent yet
        # Queue shared by the producer/consumer threads.
        self.audio_queue = queue.Queue()
        self.audio_lock = threading.Lock()
        # Whether to transcribe only audio from registered voiceprints.
        self.only_asr_register_user = True
        self.open_vp_register_signal.connect(self.set_only_register_user)

    def set_only_register_user(self, value):
        """Slot: enable/disable "registered users only" recognition."""
        self.only_asr_register_user = value

    def on_audio_start_play(self):
        """Slot: TTS playback started -- stop buffering microphone audio."""
        self.audio_buffer_startup = False
        logger.debug(f"vits音频播放开始: {self.audio_buffer_startup}")

    def on_audio_played(self):
        """Slot: TTS playback finished -- resume buffering microphone audio."""
        self.audio_buffer_startup = True
        logger.debug(f"vits音频播放结束: {self.audio_buffer_startup}")

    def detect_speech(self, audio_data, sample_rate=16000):
        """Detect whether the audio contains speech using Silero VAD.

        (Docstring fix: the original claimed WebRTC VAD, but the code uses
        Silero VAD via get_speech_timestamps.)

        Args:
            audio_data (np.ndarray): mono 16 kHz audio (int16 or float32)
            sample_rate (int): must be 16000

        Returns:
            bool: True if at least one speech segment was detected
        """
        # Input validation.
        if not isinstance(audio_data, np.ndarray) or audio_data.size == 0:
            return False
        # Normalize int16 PCM to float32 in [-1, 1).
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32768.0
        # Speech segmentation (recommended threshold range: 0.3-0.5).
        speech_segments = get_speech_timestamps(
            audio_data,
            self.vad_model,
            sampling_rate=sample_rate,
            threshold=0.3,
            return_seconds=True,
        )
        return len(speech_segments) > 0

    # Recording thread
    def audio_producer(self):
        """Microphone capture loop (audio producer).

        Reads fixed-size chunks from the default input device and pushes the
        raw bytes onto the shared queue until self._is_running goes False.
        """
        try:
            audio = pyaudio.PyAudio()
            stream = audio.open(
                format=self.FORMAT,
                channels=self.CHANNELS,
                rate=self.RATE,
                input=True,
                frames_per_buffer=self.CHUNK,
            )
            logger.debug("audio_producer录音线程启动")
            try:
                while self._is_running:
                    data = stream.read(self.CHUNK, exception_on_overflow=False)
                    self.audio_queue.put(data)
            finally:
                # Fix: always release the stream and the audio device, even
                # when stream.read raises (the original leaked them on error).
                stream.stop_stream()
                stream.close()
                audio.terminate()
            logger.debug("audio_producer录音线程停止")
        except Exception as e:
            logger.error(f"audio_producer录音线程异常: {e}")
        logger.debug("audio_producer正常退出")

    def audio_consumer(self):
        """Process queued audio (audio consumer).

        Runs VAD once per detection window (frames_per_window chunks),
        accumulates speech into audio_buffer, transcribes the buffer when a
        silence window follows speech, and auto-sends the pending result
        after asr_auto_send_silence_time seconds of silence.
        """
        silence_timer = 0  # accumulated silence, in seconds
        audio_buffer = []  # frames believed to contain the current utterance
        temp_frames = []  # rolling cache of up to two detection windows
        frames_per_window = 16  # chunks per detection window
        frame_window_ms = frames_per_window * (self.CHUNK / self.RATE) * 1000.0
        while self._is_running:
            try:
                data = self.audio_queue.get(timeout=0.1)
                if len(data) == 0:
                    logger.warning("audio_producer收到空帧,跳过处理")
                    continue
                # Convert raw bytes to int16 samples.
                frame = np.frombuffer(data, dtype=np.int16)
                # Cache the frame in the rolling detection window.
                temp_frames.append(frame)
                if self.audio_buffer_startup:
                    audio_buffer.append(frame)
                # Cap temp_frames at two detection windows.
                if len(temp_frames) > 2 * frames_per_window:
                    temp_frames.pop(0)
                # Only run VAD once a full detection window is available.
                if len(temp_frames) >= frames_per_window:
                    is_active = self.detect_speech(
                        np.concatenate(temp_frames[-frames_per_window:])
                    )
                    if is_active:
                        user_name = self.vpr_manager.match_voiceprint(
                            temp_frames[-frames_per_window:]
                        )
                        if user_name != "Unknown":
                            # Registered voiceprint: report speech activity.
                            self.detect_speech_signal.emit(True)
                            if not self.audio_buffer_startup:
                                self.audio_buffer_startup = True
                                # Seed the buffer with cached window audio to
                                # avoid losing the head of the utterance.
                                # NOTE(review): this takes the OLDEST cached
                                # window, not the one just detected — confirm
                                # against the original intent.
                                audio_buffer = temp_frames[:frames_per_window]
                        # Any detected speech resets the silence timer.
                        silence_timer = 0
                    else:
                        # Silence detected: accumulate silence duration.
                        self.detect_speech_signal.emit(False)
                        silence_timer += (
                            frame_window_ms / frames_per_window
                        ) / 1000.0  # per-chunk duration, converted to seconds
                        if audio_buffer:
                            # Transcribe only if the buffer actually holds speech.
                            if self.detect_speech(np.concatenate(audio_buffer)):
                                should_transcribe = True
                                if self.only_asr_register_user:
                                    user_name = self.vpr_manager.match_voiceprint(
                                        audio_buffer
                                    )
                                    should_transcribe = user_name != "Unknown"
                                    logger.debug(
                                        f"人声是否属于注册用户: {should_transcribe}"
                                    )
                                if should_transcribe:
                                    self.audio_transcribe(audio_buffer)
                                audio_buffer.clear()
                            else:
                                audio_buffer.pop(0)  # drop the oldest frame
                        if (
                            silence_timer >= self.asr_auto_send_silence_time
                            and self.transcribe_but_not_send
                            and self.asr_auto_send_silence_time != -1
                        ):
                            logger.info("静默时间超限,触发结果发送")
                            self.recording_ended_signal.emit()
                            self.transcribe_but_not_send = False  # reset pending flag
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"audio_consumer音频处理异常: {e}")
                break
        logger.debug("audio_consumer正常退出")

    # Transcribe and record
    def audio_transcribe(self, frames):
        """Run ASR over a buffered speech segment and emit the result.

        Args:
            frames (list[np.ndarray]): int16 audio frames containing speech
        """
        audio_data = b"".join(frames)
        # Write the PCM data to a temporary WAV file for the ASR model.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_wav:
            with wave.open(temp_wav, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)  # 16-bit samples
                wf.setframerate(self.RATE)
                wf.writeframes(audio_data)
        try:
            res = self.asr_model.generate(
                input=temp_wav.name,
                cache={},
                language="auto",  # auto-detect language
                use_itn=True,
                ban_emo_unk=True,  # suppress emotion/unknown tokens
            )
        finally:
            # Fix: delete=False means the file is never removed automatically;
            # without this cleanup every utterance leaked one WAV file.
            try:
                os.remove(temp_wav.name)
            except OSError:
                pass
        if res and res[0]["text"]:
            text = rich_transcription_postprocess(res[0]["text"])
            logger.debug(f"实时转录结果: {text}")
            self.transcribe_but_not_send = True
            # Deep-copy so the emitted frames cannot be mutated afterwards by
            # the consumer thread (which clears/pops the shared buffer).
            copied_frames = copy.deepcopy(frames)
            self.update_text_signal.emit((copied_frames, text))

    # Start streaming speech recognition
    def start_streaming(self):
        """Start the producer/consumer threads.

        NOTE: this call blocks until both threads exit, i.e. until
        stop_streaming() is invoked from another thread.
        """
        if not self._is_running:
            logger.info("启动流式语音识别")
            self.audio_queue.queue.clear()
            self._is_running = True
            self.transcribe_but_not_send = False
            producer_thread = threading.Thread(target=self.audio_producer)
            consumer_thread = threading.Thread(target=self.audio_consumer)
            producer_thread.start()
            consumer_thread.start()
            producer_thread.join()
            consumer_thread.join()
        else:
            logger.info("流式语音识别已经启动")

    # Stop streaming speech recognition
    def stop_streaming(self):
        """Stop the worker threads and flush any pending transcription."""
        self._is_running = False
        if self.transcribe_but_not_send:  # unsent content remains
            logger.info("手动点击按钮,触发语音转录")
            self.recording_ended_signal.emit()
            self.transcribe_but_not_send = False  # reset pending flag
        logger.info("停止流式语音识别")
if __name__ == "__main__":
    import yaml
    import logging_config

    # Set up logging before anything else emits records.
    logging_config.setup_logging()

    config_path = "./config.yaml"
    with open(config_path, "r", encoding="utf-8") as cfg_file:
        app_settings = yaml.safe_load(cfg_file)

    # Build the recognizer from the loaded settings and start the
    # (blocking) streaming loop.
    asr = SpeechRecognition(app_settings)
    asr.start_streaming()