Skip to content

Commit 0337b93

Browse files
authored
Fix nemo streaming transducer greedy search (k2-fsa#944)
1 parent c2eb339 commit 0337b93

18 files changed

+320
-290
lines changed

.github/scripts/test-online-transducer.sh

+39
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,45 @@ echo "PATH: $PATH"
1515

1616
which $EXE
1717

18+
log "------------------------------------------------------------"
19+
log "Run NeMo transducer (English)"
20+
log "------------------------------------------------------------"
21+
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
22+
curl -SL -O $repo_url
23+
tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
24+
rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
25+
repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms
26+
27+
log "Start testing ${repo_url}"
28+
29+
waves=(
30+
$repo/test_wavs/0.wav
31+
$repo/test_wavs/1.wav
32+
$repo/test_wavs/8k.wav
33+
)
34+
35+
for wave in ${waves[@]}; do
36+
time $EXE \
37+
--tokens=$repo/tokens.txt \
38+
--encoder=$repo/encoder.onnx \
39+
--decoder=$repo/decoder.onnx \
40+
--joiner=$repo/joiner.onnx \
41+
--num-threads=2 \
42+
$wave
43+
done
44+
45+
time $EXE \
46+
--tokens=$repo/tokens.txt \
47+
--encoder=$repo/encoder.onnx \
48+
--decoder=$repo/decoder.onnx \
49+
--joiner=$repo/joiner.onnx \
50+
--num-threads=2 \
51+
$repo/test_wavs/0.wav \
52+
$repo/test_wavs/1.wav \
53+
$repo/test_wavs/8k.wav
54+
55+
rm -rf $repo
56+
1857
log "------------------------------------------------------------"
1958
log "Run LSTM transducer (English)"
2059
log "------------------------------------------------------------"

.github/workflows/aarch64-linux-gnu-shared.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ jobs:
196196
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
197197
198198
cd huggingface
199-
git lfs pull
200199
mkdir -p aarch64
201200
202201
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64

.github/workflows/aarch64-linux-gnu-static.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ jobs:
187187
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
188188
189189
cd huggingface
190-
git lfs pull
191190
mkdir -p aarch64
192191
193192
cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64

.github/workflows/android.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ jobs:
124124
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
125125
126126
cd huggingface
127-
git lfs pull
128127
129128
cp -v ../sherpa-onnx-*-android.tar.bz2 ./
130129

.github/workflows/arm-linux-gnueabihf.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ jobs:
209209
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
210210
211211
cd huggingface
212-
git lfs pull
213212
mkdir -p arm32
214213
215214
cp -v ../sherpa-onnx-*.tar.bz2 ./arm32

.github/workflows/build-xcframework.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ jobs:
138138
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
139139
140140
cd huggingface
141-
git lfs pull
142141
143142
cp -v ../sherpa-onnx-*.tar.bz2 ./
144143

.github/workflows/riscv64-linux.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ jobs:
242242
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
243243
244244
cd huggingface
245-
git lfs pull
246245
mkdir -p riscv64
247246
248247
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64

.github/workflows/windows-x64.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ jobs:
219219
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
220220
221221
cd huggingface
222-
git lfs pull
223222
mkdir -p win64
224223
225224
cp -v ../sherpa-onnx-*.tar.bz2 ./win64

.github/workflows/windows-x86.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ jobs:
221221
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
222222
223223
cd huggingface
224-
git lfs pull
225224
mkdir -p win32
226225
227226
cp -v ../sherpa-onnx-*.tar.bz2 ./win32

sherpa-onnx/csrc/online-recognizer-impl.cc

+10-10
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,18 @@ namespace sherpa_onnx {
1414

1515
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
1616
const OnlineRecognizerConfig &config) {
17-
1817
if (!config.model_config.transducer.encoder.empty()) {
1918
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
20-
19+
2120
auto decoder_model = ReadFile(config.model_config.transducer.decoder);
22-
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
23-
21+
auto sess = std::make_unique<Ort::Session>(
22+
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
23+
2424
size_t node_count = sess->GetOutputCount();
25-
25+
2626
if (node_count == 1) {
2727
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
2828
} else {
29-
SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
3029
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
3130
}
3231
}
@@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
5049
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
5150
if (!config.model_config.transducer.encoder.empty()) {
5251
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
53-
52+
5453
auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
55-
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
56-
54+
auto sess = std::make_unique<Ort::Session>(
55+
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});
56+
5757
size_t node_count = sess->GetOutputCount();
58-
58+
5959
if (node_count == 1) {
6060
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
6161
} else {

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

+4-7
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,15 @@
3535

3636
namespace sherpa_onnx {
3737

38-
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
39-
const SymbolTable &sym_table,
40-
float frame_shift_ms,
41-
int32_t subsampling_factor,
42-
int32_t segment,
43-
int32_t frames_since_start) {
38+
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
39+
const SymbolTable &sym_table,
40+
float frame_shift_ms, int32_t subsampling_factor,
41+
int32_t segment, int32_t frames_since_start) {
4442
OnlineRecognizerResult r;
4543
r.tokens.reserve(src.tokens.size());
4644
r.timestamps.reserve(src.tokens.size());
4745

4846
for (auto i : src.tokens) {
49-
if (i == -1) continue;
5047
auto sym = sym_table[i];
5148

5249
r.text.append(sym);

sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h

+37-57
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
77
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
88

9+
#include <algorithm>
910
#include <fstream>
1011
#include <ios>
1112
#include <memory>
@@ -32,23 +33,20 @@
3233
namespace sherpa_onnx {
3334

3435
// defined in ./online-recognizer-transducer-impl.h
35-
// static may or may not be here? TODDOs
36-
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
37-
const SymbolTable &sym_table,
38-
float frame_shift_ms,
39-
int32_t subsampling_factor,
40-
int32_t segment,
41-
int32_t frames_since_start);
36+
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
37+
const SymbolTable &sym_table,
38+
float frame_shift_ms, int32_t subsampling_factor,
39+
int32_t segment, int32_t frames_since_start);
4240

4341
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
44-
public:
42+
public:
4543
explicit OnlineRecognizerTransducerNeMoImpl(
4644
const OnlineRecognizerConfig &config)
4745
: config_(config),
4846
symbol_table_(config.model_config.tokens),
4947
endpoint_(config_.endpoint_config),
50-
model_(std::make_unique<OnlineTransducerNeMoModel>(
51-
config.model_config)) {
48+
model_(
49+
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
5250
if (config.decoding_method == "greedy_search") {
5351
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
5452
model_.get(), config_.blank_penalty);
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
7371
model_.get(), config_.blank_penalty);
7472
} else {
7573
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
76-
config.decoding_method.c_str());
74+
config.decoding_method.c_str());
7775
exit(-1);
7876
}
7977

@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
8381

8482
std::unique_ptr<OnlineStream> CreateStream() const override {
8583
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
86-
stream->SetStates(model_->GetInitStates());
8784
InitOnlineStream(stream.get());
8885
return stream;
8986
}
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
9491
}
9592

9693
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
97-
OnlineTransducerDecoderResult decoder_result = s->GetResult();
98-
decoder_->StripLeadingBlanks(&decoder_result);
99-
10094
// TODO(fangjun): Remember to change these constants if needed
10195
int32_t frame_shift_ms = 10;
102-
int32_t subsampling_factor = 8;
103-
return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
104-
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
96+
int32_t subsampling_factor = model_->SubsamplingFactor();
97+
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
98+
subsampling_factor, s->GetCurrentSegment(),
99+
s->GetNumFramesSinceStart());
105100
}
106101

107102
bool IsEndpoint(OnlineStream *s) const override {
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
114109
// frame shift is 10 milliseconds
115110
float frame_shift_in_seconds = 0.01;
116111

117-
// subsampling factor is 8
118-
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
112+
int32_t trailing_silence_frames =
113+
s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();
119114

120115
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
121116
frame_shift_in_seconds);
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
126121
// segment is incremented only when the last
127122
// result is not empty
128123
const auto &r = s->GetResult();
129-
if (!r.tokens.empty() && r.tokens.back() != 0) {
124+
if (!r.tokens.empty()) {
130125
s->GetCurrentSegment() += 1;
131126
}
132127
}
133128

134-
// we keep the decoder_out
135-
decoder_->UpdateDecoderOut(&s->GetResult());
136-
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
129+
s->SetResult({});
130+
131+
s->SetStates(model_->GetEncoderInitStates());
137132

138-
auto r = decoder_->GetEmptyResult();
139-
140-
s->SetResult(r);
141-
s->GetResult().decoder_out = std::move(decoder_out);
133+
s->SetNeMoDecoderStates(model_->GetDecoderInitStates());
142134

143135
// Note: We only update counters. The underlying audio samples
144136
// are not discarded.
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
151143

152144
int32_t feature_dim = ss[0]->FeatureDim();
153145

154-
std::vector<OnlineTransducerDecoderResult> result(n);
155146
std::vector<float> features_vec(n * chunk_size * feature_dim);
156147
std::vector<std::vector<Ort::Value>> encoder_states(n);
157-
148+
158149
for (int32_t i = 0; i != n; ++i) {
159150
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
160151
std::vector<float> features =
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
166157
std::copy(features.begin(), features.end(),
167158
features_vec.data() + i * chunk_size * feature_dim);
168159

169-
result[i] = std::move(ss[i]->GetResult());
170160
encoder_states[i] = std::move(ss[i]->GetStates());
171-
172161
}
173162

174163
auto memory_info =
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
180169
features_vec.size(), x_shape.data(),
181170
x_shape.size());
182171

183-
// Batch size is 1
184-
auto states = std::move(encoder_states[0]);
185-
int32_t num_states = states.size(); // num_states = 3
172+
auto states = model_->StackStates(std::move(encoder_states));
173+
int32_t num_states = states.size(); // num_states = 3
186174
auto t = model_->RunEncoder(std::move(x), std::move(states));
187175
// t[0] encoder_out, float tensor, (batch_size, dim, T)
188176
// t[1] next states
189-
177+
190178
std::vector<Ort::Value> out_states;
191179
out_states.reserve(num_states);
192-
180+
193181
for (int32_t k = 1; k != num_states + 1; ++k) {
194182
out_states.push_back(std::move(t[k]));
195183
}
196184

185+
auto unstacked_states = model_->UnStackStates(std::move(out_states));
186+
for (int32_t i = 0; i != n; ++i) {
187+
ss[i]->SetStates(std::move(unstacked_states[i]));
188+
}
189+
197190
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
198-
199-
// defined in online-transducer-greedy-search-nemo-decoder.h
200-
// get intial states of decoder.
201-
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
202-
203-
// Subsequent decoder states (for each chunks) are updated inside the Decode method.
204-
// This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
205-
decoder_states = decoder_->Decode(std::move(encoder_out),
206-
std::move(decoder_states),
207-
&result, ss, n);
208-
209-
ss[0]->SetResult(result[0]);
210-
211-
ss[0]->SetStates(std::move(out_states));
191+
192+
decoder_->Decode(std::move(encoder_out), ss, n);
212193
}
213194

214195
void InitOnlineStream(OnlineStream *stream) const {
215-
auto r = decoder_->GetEmptyResult();
196+
// set encoder states
197+
stream->SetStates(model_->GetEncoderInitStates());
216198

217-
stream->SetResult(r);
218-
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
199+
// set decoder states
200+
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
219201
}
220202

221203
private:
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
250232
symbol_table_.NumSymbols(), vocab_size);
251233
exit(-1);
252234
}
253-
254235
}
255236

256237
private:
@@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
259240
std::unique_ptr<OnlineTransducerNeMoModel> model_;
260241
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
261242
Endpoint endpoint_;
262-
263243
};
264244

265245
} // namespace sherpa_onnx
266246

267-
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
247+
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

sherpa-onnx/csrc/online-stream.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
225225
return impl_->GetStates();
226226
}
227227

228-
void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
228+
void OnlineStream::SetNeMoDecoderStates(
229+
std::vector<Ort::Value> decoder_states) {
229230
return impl_->SetNeMoDecoderStates(std::move(decoder_states));
230231
}
231232

0 commit comments

Comments
 (0)