Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 206d3f7

Browse files
committedJan 6, 2025·
Fix keyword spotting.
See also #1417
1 parent 930986b commit 206d3f7

7 files changed

+40
-12
lines changed
 

‎sherpa-onnx/csrc/keyword-spotter-impl.h

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class KeywordSpotterImpl {
3838

3939
virtual bool IsReady(OnlineStream *s) const = 0;
4040

41+
virtual void Reset(OnlineStream *s) const = 0;
42+
4143
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
4244

4345
virtual KeywordResult GetResult(OnlineStream *s) const = 0;

‎sherpa-onnx/csrc/keyword-spotter-transducer-impl.h

+16
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
195195
return s->GetNumProcessedFrames() + model_->ChunkSize() <
196196
s->NumFramesReady();
197197
}
198+
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
198199

199200
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
201+
for (int32_t i = 0; i < n; ++i) {
202+
auto s = ss[i];
203+
auto r = s->GetKeywordResult(true);
204+
int32_t num_trailing_blanks = r.num_trailing_blanks;
205+
// assume subsampling_factor is 4
206+
// assume frameshift is 0.01 second
207+
float trailing_slience = num_trailing_blanks * 4 * 0.01;
208+
209+
// it resets automatically after detecting 1.5 seconds of silence
210+
float threshold = 1.5;
211+
if (trailing_slience > threshold) {
212+
Reset(s);
213+
}
214+
}
215+
200216
int32_t chunk_size = model_->ChunkSize();
201217
int32_t chunk_shift = model_->ChunkShift();
202218

‎sherpa-onnx/csrc/keyword-spotter.cc

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
157157
return impl_->IsReady(s);
158158
}
159159

160+
void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
161+
160162
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
161163
impl_->DecodeStreams(ss, n);
162164
}

‎sherpa-onnx/csrc/keyword-spotter.h

+3
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class KeywordSpotter {
129129
*/
130130
bool IsReady(OnlineStream *s) const;
131131

132+
// Remember to call it after detecting a keyword
133+
void Reset(OnlineStream *s) const;
134+
132135
/** Decode a single stream. */
133136
void DecodeStream(OnlineStream *s) const {
134137
OnlineStream *ss[1] = {s};

‎sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc

+8-6
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,15 @@ as the device_name.
106106

107107
while (spotter.IsReady(stream.get())) {
108108
spotter.DecodeStream(stream.get());
109-
}
110109

111-
const auto r = spotter.GetResult(stream.get());
112-
if (!r.keyword.empty()) {
113-
display.Print(keyword_index, r.AsJsonString());
114-
fflush(stderr);
115-
keyword_index++;
110+
const auto r = spotter.GetResult(stream.get());
111+
if (!r.keyword.empty()) {
112+
display.Print(keyword_index, r.AsJsonString());
113+
fflush(stderr);
114+
keyword_index++;
115+
116+
spotter.Reset(stream.get());
117+
}
116118
}
117119
}
118120

‎sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc

+8-6
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,15 @@ for a list of pre-trained models to download.
150150
while (!stop) {
151151
while (spotter.IsReady(s.get())) {
152152
spotter.DecodeStream(s.get());
153-
}
154153

155-
const auto r = spotter.GetResult(s.get());
156-
if (!r.keyword.empty()) {
157-
display.Print(keyword_index, r.AsJsonString());
158-
fflush(stderr);
159-
keyword_index++;
154+
const auto r = spotter.GetResult(s.get());
155+
if (!r.keyword.empty()) {
156+
display.Print(keyword_index, r.AsJsonString());
157+
fflush(stderr);
158+
keyword_index++;
159+
160+
spotter.Reset(s.get());
161+
}
160162
}
161163

162164
Pa_Sleep(20); // sleep for 20ms

‎sherpa-onnx/python/csrc/keyword-spotter.cc

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
6767
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
6868
.def("is_ready", &PyClass::IsReady,
6969
py::call_guard<py::gil_scoped_release>())
70+
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
7071
.def("decode_stream", &PyClass::DecodeStream,
7172
py::call_guard<py::gil_scoped_release>())
7273
.def(

0 commit comments

Comments
 (0)