Skip to content

Commit b8a45ac

Browse files
authored
Greedy search and modified beam search for pruned stateless RNN-T. (#975)
* First version of greedy search. * WIP: Implement modified beam search and greedy search for pruned RNN-T. * Implement modified beam search.
1 parent 47587be commit b8a45ac

12 files changed

+959
-3
lines changed

k2/csrc/ragged_ops.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ RaggedShape ComposeRaggedShapes3(const RaggedShape &a, const RaggedShape &b,
12021202
If cached_tot_sizeN is not -1, it must equal the total size on
12031203
that axis which will equal the last element of row_splitsN (if
12041204
provided) and must equal the row_idsN.Dim(), if provided. See
1205-
documentation above for RagggedShape2 for details.
1205+
documentation above for RaggedShape2 for details.
12061206
12071207
We also require that (supposing both row_splitsN and row_idsN are non-NULL):
12081208
row_splits1[row_splits1.Dim() - 1] == row_ids1.Dim()

k2/torch/bin/CMakeLists.txt

+11-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ add_executable(online_decode ${online_decode_srcs})
6464
set_property(TARGET online_decode PROPERTY CXX_STANDARD 14)
6565
target_link_libraries(online_decode ${bin_dep_libs})
6666

67-
6867
#-------------------------------------------
6968
# rnnt demo
7069
#-------------------------------------------
@@ -77,3 +76,14 @@ add_executable(rnnt_demo ${rnnt_demo_srcs})
7776
set_property(TARGET rnnt_demo PROPERTY CXX_STANDARD 14)
7877
target_link_libraries(rnnt_demo ${bin_dep_libs})
7978

79+
#-------------------------------------------
80+
# pruned stateless transducer
81+
#-------------------------------------------
82+
set(pruned_stateless_transducer_srcs pruned_stateless_transducer.cu)
83+
if(NOT K2_WITH_CUDA)
84+
transform(OUTPUT_VARIABLE pruned_stateless_transducer_srcs SRCS ${pruned_stateless_transducer_srcs})
85+
endif()
86+
87+
add_executable(pruned_stateless_transducer ${pruned_stateless_transducer_srcs})
88+
set_property(TARGET pruned_stateless_transducer PROPERTY CXX_STANDARD 14)
89+
target_link_libraries(pruned_stateless_transducer ${bin_dep_libs})
+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/**
2+
* Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
3+
*
4+
* See LICENSE for clarification regarding multiple authors
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#include "k2/csrc/log.h"
20+
#include "k2/torch/csrc/beam_search.h"
21+
#include "k2/torch/csrc/features.h"
22+
#include "k2/torch/csrc/parse_options.h"
23+
#include "k2/torch/csrc/wave_reader.h"
24+
#include "sentencepiece_processor.h" // NOLINT
25+
#include "torch/all.h"
26+
27+
static constexpr const char *kUsageMessage = R"(
28+
This file implements RNN-T decoding for pruned stateless transducer models
29+
that are trained using pruned_transducer_statelessX (X>=2) from icefall.
30+
31+
Usage:
32+
./bin/pruned_stateless_transducer --help
33+
34+
./bin/pruned_stateless_transducer \
35+
--nn-model=/path/to/cpu_jit.pt \
36+
--bpe-model=/path/to/bpe.model \
37+
--use-gpu=true \
38+
--decoding-method=modified_beam_search \
39+
/path/to/foo.wav \
40+
/path/to/bar.wav
41+
)";
42+
43+
static void RegisterFrameExtractionOptions(
44+
k2::ParseOptions *po, kaldifeat::FrameExtractionOptions *opts) {
45+
po->Register("sample-frequency", &opts->samp_freq,
46+
"Waveform data sample frequency (must match the waveform file, "
47+
"if specified there)");
48+
49+
po->Register("frame-length", &opts->frame_length_ms,
50+
"Frame length in milliseconds");
51+
52+
po->Register("frame-shift", &opts->frame_shift_ms,
53+
"Frame shift in milliseconds");
54+
55+
po->Register("dither", &opts->dither,
56+
"Dithering constant (0.0 means no dither).");
57+
}
58+
59+
static void RegisterMelBanksOptions(k2::ParseOptions *po,
60+
kaldifeat::MelBanksOptions *opts) {
61+
po->Register("num-mel-bins", &opts->num_bins,
62+
"Number of triangular mel-frequency bins");
63+
}
64+
65+
int main(int argc, char *argv[]) {
66+
// see
67+
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
68+
torch::set_num_threads(1);
69+
torch::set_num_interop_threads(1);
70+
torch::NoGradGuard no_grad;
71+
72+
k2::ParseOptions po(kUsageMessage);
73+
74+
std::string nn_model; // path to the torch jit model file
75+
std::string bpe_model; // path to the BPE model file
76+
bool use_gpu = false; // true to use GPU for decoding; false to use CPU.
77+
std::string decoding_method = "greedy_search"; // Supported methods are:
78+
// greedy_search,
79+
// modified_beam_search
80+
81+
kaldifeat::FbankOptions fbank_opts;
82+
fbank_opts.frame_opts.dither = 0;
83+
RegisterFrameExtractionOptions(&po, &fbank_opts.frame_opts);
84+
fbank_opts.mel_opts.num_bins = 80;
85+
RegisterMelBanksOptions(&po, &fbank_opts.mel_opts);
86+
87+
po.Register("nn-model", &nn_model, "Path to the torch jit model file");
88+
89+
po.Register("bpe-model", &bpe_model, "Path to the BPE model file");
90+
91+
po.Register("use-gpu", &use_gpu,
92+
"true to use GPU for decoding; false to use CPU. "
93+
"If GPU is enabled, it always uses GPU 0. You can use "
94+
"the environment variable CUDA_VISIBLE_DEVICES to control "
95+
"which GPU device to use.");
96+
97+
po.Register(
98+
"decoding-method", &decoding_method,
99+
"Decoding method to use."
100+
"Currently implemented methods are: greedy_search, modified_beam_search");
101+
102+
po.Read(argc, argv);
103+
104+
K2_CHECK(decoding_method == "greedy_search" ||
105+
decoding_method == "modified_beam_search")
106+
<< "Currently supported decoding methods are: "
107+
"greedy_search, modified_beam_search. "
108+
<< "Given: " << decoding_method;
109+
110+
torch::Device device(torch::kCPU);
111+
if (use_gpu) {
112+
K2_LOG(INFO) << "Use GPU";
113+
device = torch::Device(torch::kCUDA, 0);
114+
}
115+
116+
K2_LOG(INFO) << "Device: " << device;
117+
118+
int32_t num_waves = po.NumArgs();
119+
K2_CHECK_GT(num_waves, 0) << "Please provide at least one wave file";
120+
121+
std::vector<std::string> wave_filenames(num_waves);
122+
for (int32_t i = 0; i < num_waves; ++i) {
123+
wave_filenames[i] = po.GetArg(i + 1);
124+
}
125+
126+
K2_LOG(INFO) << "Loading wave files";
127+
std::vector<torch::Tensor> wave_data =
128+
k2::ReadWave(wave_filenames, fbank_opts.frame_opts.samp_freq);
129+
for (auto &w : wave_data) {
130+
w = w.to(device);
131+
}
132+
133+
fbank_opts.device = device;
134+
135+
kaldifeat::Fbank fbank(fbank_opts);
136+
137+
K2_LOG(INFO) << "Computing features";
138+
std::vector<int64_t> num_frames;
139+
std::vector<torch::Tensor> features_vec =
140+
k2::ComputeFeatures(fbank, wave_data, &num_frames);
141+
142+
// Note: math.log(1e-10) is -23.025850929940457
143+
torch::Tensor features = torch::nn::utils::rnn::pad_sequence(
144+
features_vec, /*batch_first*/ true,
145+
/*padding_value*/ -23.025850929940457f);
146+
torch::Tensor feature_lens = torch::tensor(num_frames, device);
147+
148+
K2_LOG(INFO) << "Loading neural network model from " << nn_model;
149+
torch::jit::Module module = torch::jit::load(nn_model);
150+
module.eval();
151+
module.to(device);
152+
153+
K2_LOG(INFO) << "Computing output of the encoder network";
154+
155+
auto outputs = module.attr("encoder")
156+
.toModule()
157+
.run_method("forward", features, feature_lens)
158+
.toTuple();
159+
assert(outputs->elements().size() == 2u);
160+
161+
auto encoder_out = outputs->elements()[0].toTensor();
162+
auto encoder_out_lens = outputs->elements()[1].toTensor();
163+
164+
K2_LOG(INFO) << "Using " << decoding_method;
165+
166+
std::vector<std::vector<int32_t>> hyp_tokens;
167+
if (decoding_method == "greedy_search") {
168+
hyp_tokens = k2::GreedySearch(module, encoder_out, encoder_out_lens.cpu());
169+
} else {
170+
hyp_tokens =
171+
k2::ModifiedBeamSearch(module, encoder_out, encoder_out_lens.cpu());
172+
}
173+
174+
sentencepiece::SentencePieceProcessor processor;
175+
auto status = processor.Load(bpe_model);
176+
K2_CHECK(status.ok()) << status.ToString();
177+
178+
std::vector<std::string> texts;
179+
for (const auto &ids : hyp_tokens) {
180+
std::string text;
181+
status = processor.Decode(ids, &text);
182+
K2_CHECK(status.ok()) << status.ToString();
183+
texts.emplace_back(std::move(text));
184+
}
185+
186+
std::ostringstream os;
187+
os << "\nDecoding result:\n\n";
188+
for (int32_t i = 0; i != num_waves; ++i) {
189+
os << wave_filenames[i] << "\n";
190+
os << texts[i];
191+
os << "\n\n";
192+
}
193+
K2_LOG(INFO) << os.str();
194+
};

k2/torch/csrc/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ include_directories(${CMAKE_SOURCE_DIR})
33
# it is located in k2/csrc/cmake/transform.cmake
44
include(transform)
55
set(k2_torch_srcs
6+
beam_search.cu
67
decode.cu
78
dense_fsa_vec.cu
89
deserialization.cu
910
features.cu
1011
fsa_algo.cu
1112
fsa_class.cu
13+
hypothesis.cu
1214
nbest.cu
1315
parse_options.cu
1416
symbol_table.cu
@@ -28,6 +30,7 @@ set(k2_torch_test_srcs
2830
dense_fsa_vec_test.cu
2931
deserialization_test.cu
3032
fsa_class_test.cu
33+
hypothesis_test.cu
3134
parse_options_test.cu
3235
wave_reader_test.cu
3336
)

0 commit comments

Comments
 (0)