|
| 1 | +/** |
| 2 | + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) |
| 3 | + * |
| 4 | + * See LICENSE for clarification regarding multiple authors |
| 5 | + * |
| 6 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | + * you may not use this file except in compliance with the License. |
| 8 | + * You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, software |
| 13 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | + * See the License for the specific language governing permissions and |
| 16 | + * limitations under the License. |
| 17 | + */ |
| 18 | + |
#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "k2/csrc/log.h"
#include "k2/torch/csrc/beam_search.h"
#include "k2/torch/csrc/features.h"
#include "k2/torch/csrc/parse_options.h"
#include "k2/torch/csrc/wave_reader.h"
#include "sentencepiece_processor.h"  // NOLINT
#include "torch/all.h"
| 26 | + |
// Help text shown by k2::ParseOptions (e.g. when --help is given or
// validation fails). Keep the example command line in sync with the
// options registered in main().
static constexpr const char *kUsageMessage = R"(
This file implements RNN-T decoding for pruned stateless transducer models
that are trained using pruned_transducer_statelessX (X>=2) from icefall.

Usage:
  ./bin/pruned_stateless_transducer --help

  ./bin/pruned_stateless_transducer \
    --nn-model=/path/to/cpu_jit.pt \
    --bpe-model=/path/to/bpe.model \
    --use-gpu=true \
    --decoding-method=modified_beam_search \
    /path/to/foo.wav \
    /path/to/bar.wav
)";
| 42 | + |
| 43 | +static void RegisterFrameExtractionOptions( |
| 44 | + k2::ParseOptions *po, kaldifeat::FrameExtractionOptions *opts) { |
| 45 | + po->Register("sample-frequency", &opts->samp_freq, |
| 46 | + "Waveform data sample frequency (must match the waveform file, " |
| 47 | + "if specified there)"); |
| 48 | + |
| 49 | + po->Register("frame-length", &opts->frame_length_ms, |
| 50 | + "Frame length in milliseconds"); |
| 51 | + |
| 52 | + po->Register("frame-shift", &opts->frame_shift_ms, |
| 53 | + "Frame shift in milliseconds"); |
| 54 | + |
| 55 | + po->Register("dither", &opts->dither, |
| 56 | + "Dithering constant (0.0 means no dither)."); |
| 57 | +} |
| 58 | + |
| 59 | +static void RegisterMelBanksOptions(k2::ParseOptions *po, |
| 60 | + kaldifeat::MelBanksOptions *opts) { |
| 61 | + po->Register("num-mel-bins", &opts->num_bins, |
| 62 | + "Number of triangular mel-frequency bins"); |
| 63 | +} |
| 64 | + |
| 65 | +int main(int argc, char *argv[]) { |
| 66 | + // see |
| 67 | + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html |
| 68 | + torch::set_num_threads(1); |
| 69 | + torch::set_num_interop_threads(1); |
| 70 | + torch::NoGradGuard no_grad; |
| 71 | + |
| 72 | + k2::ParseOptions po(kUsageMessage); |
| 73 | + |
| 74 | + std::string nn_model; // path to the torch jit model file |
| 75 | + std::string bpe_model; // path to the BPE model file |
| 76 | + bool use_gpu = false; // true to use GPU for decoding; false to use CPU. |
| 77 | + std::string decoding_method = "greedy_search"; // Supported methods are: |
| 78 | + // greedy_search, |
| 79 | + // modified_beam_search |
| 80 | + |
| 81 | + kaldifeat::FbankOptions fbank_opts; |
| 82 | + fbank_opts.frame_opts.dither = 0; |
| 83 | + RegisterFrameExtractionOptions(&po, &fbank_opts.frame_opts); |
| 84 | + fbank_opts.mel_opts.num_bins = 80; |
| 85 | + RegisterMelBanksOptions(&po, &fbank_opts.mel_opts); |
| 86 | + |
| 87 | + po.Register("nn-model", &nn_model, "Path to the torch jit model file"); |
| 88 | + |
| 89 | + po.Register("bpe-model", &bpe_model, "Path to the BPE model file"); |
| 90 | + |
| 91 | + po.Register("use-gpu", &use_gpu, |
| 92 | + "true to use GPU for decoding; false to use CPU. " |
| 93 | + "If GPU is enabled, it always uses GPU 0. You can use " |
| 94 | + "the environment variable CUDA_VISIBLE_DEVICES to control " |
| 95 | + "which GPU device to use."); |
| 96 | + |
| 97 | + po.Register( |
| 98 | + "decoding-method", &decoding_method, |
| 99 | + "Decoding method to use." |
| 100 | + "Currently implemented methods are: greedy_search, modified_beam_search"); |
| 101 | + |
| 102 | + po.Read(argc, argv); |
| 103 | + |
| 104 | + K2_CHECK(decoding_method == "greedy_search" || |
| 105 | + decoding_method == "modified_beam_search") |
| 106 | + << "Currently supported decoding methods are: " |
| 107 | + "greedy_search, modified_beam_search. " |
| 108 | + << "Given: " << decoding_method; |
| 109 | + |
| 110 | + torch::Device device(torch::kCPU); |
| 111 | + if (use_gpu) { |
| 112 | + K2_LOG(INFO) << "Use GPU"; |
| 113 | + device = torch::Device(torch::kCUDA, 0); |
| 114 | + } |
| 115 | + |
| 116 | + K2_LOG(INFO) << "Device: " << device; |
| 117 | + |
| 118 | + int32_t num_waves = po.NumArgs(); |
| 119 | + K2_CHECK_GT(num_waves, 0) << "Please provide at least one wave file"; |
| 120 | + |
| 121 | + std::vector<std::string> wave_filenames(num_waves); |
| 122 | + for (int32_t i = 0; i < num_waves; ++i) { |
| 123 | + wave_filenames[i] = po.GetArg(i + 1); |
| 124 | + } |
| 125 | + |
| 126 | + K2_LOG(INFO) << "Loading wave files"; |
| 127 | + std::vector<torch::Tensor> wave_data = |
| 128 | + k2::ReadWave(wave_filenames, fbank_opts.frame_opts.samp_freq); |
| 129 | + for (auto &w : wave_data) { |
| 130 | + w = w.to(device); |
| 131 | + } |
| 132 | + |
| 133 | + fbank_opts.device = device; |
| 134 | + |
| 135 | + kaldifeat::Fbank fbank(fbank_opts); |
| 136 | + |
| 137 | + K2_LOG(INFO) << "Computing features"; |
| 138 | + std::vector<int64_t> num_frames; |
| 139 | + std::vector<torch::Tensor> features_vec = |
| 140 | + k2::ComputeFeatures(fbank, wave_data, &num_frames); |
| 141 | + |
| 142 | + // Note: math.log(1e-10) is -23.025850929940457 |
| 143 | + torch::Tensor features = torch::nn::utils::rnn::pad_sequence( |
| 144 | + features_vec, /*batch_first*/ true, |
| 145 | + /*padding_value*/ -23.025850929940457f); |
| 146 | + torch::Tensor feature_lens = torch::tensor(num_frames, device); |
| 147 | + |
| 148 | + K2_LOG(INFO) << "Loading neural network model from " << nn_model; |
| 149 | + torch::jit::Module module = torch::jit::load(nn_model); |
| 150 | + module.eval(); |
| 151 | + module.to(device); |
| 152 | + |
| 153 | + K2_LOG(INFO) << "Computing output of the encoder network"; |
| 154 | + |
| 155 | + auto outputs = module.attr("encoder") |
| 156 | + .toModule() |
| 157 | + .run_method("forward", features, feature_lens) |
| 158 | + .toTuple(); |
| 159 | + assert(outputs->elements().size() == 2u); |
| 160 | + |
| 161 | + auto encoder_out = outputs->elements()[0].toTensor(); |
| 162 | + auto encoder_out_lens = outputs->elements()[1].toTensor(); |
| 163 | + |
| 164 | + K2_LOG(INFO) << "Using " << decoding_method; |
| 165 | + |
| 166 | + std::vector<std::vector<int32_t>> hyp_tokens; |
| 167 | + if (decoding_method == "greedy_search") { |
| 168 | + hyp_tokens = k2::GreedySearch(module, encoder_out, encoder_out_lens.cpu()); |
| 169 | + } else { |
| 170 | + hyp_tokens = |
| 171 | + k2::ModifiedBeamSearch(module, encoder_out, encoder_out_lens.cpu()); |
| 172 | + } |
| 173 | + |
| 174 | + sentencepiece::SentencePieceProcessor processor; |
| 175 | + auto status = processor.Load(bpe_model); |
| 176 | + K2_CHECK(status.ok()) << status.ToString(); |
| 177 | + |
| 178 | + std::vector<std::string> texts; |
| 179 | + for (const auto &ids : hyp_tokens) { |
| 180 | + std::string text; |
| 181 | + status = processor.Decode(ids, &text); |
| 182 | + K2_CHECK(status.ok()) << status.ToString(); |
| 183 | + texts.emplace_back(std::move(text)); |
| 184 | + } |
| 185 | + |
| 186 | + std::ostringstream os; |
| 187 | + os << "\nDecoding result:\n\n"; |
| 188 | + for (int32_t i = 0; i != num_waves; ++i) { |
| 189 | + os << wave_filenames[i] << "\n"; |
| 190 | + os << texts[i]; |
| 191 | + os << "\n\n"; |
| 192 | + } |
| 193 | + K2_LOG(INFO) << os.str(); |
| 194 | +}; |
0 commit comments