Support rescoring with an n-gram LM during decoding #867

Merged (8 commits) on Nov 8, 2021
4 changes: 2 additions & 2 deletions k2/csrc/fsa_algo.h
@@ -649,7 +649,7 @@ Fsa Closure(Fsa &fsa, Array1<int32_t> *arc_map = nullptr);
@param [in] labels_shape This might correspond to the shape of the
`aux_labels`; it is a shape with
`labels_shape.NumAxes() == 2` and
`arcs.shape.Dim0() == fsas.NumElements()`.
`labels_shape.Dim0() == fsas.NumElements()`.
The i'th arc of the FsaVec will be expanded to a
sequence of `max(1, l)` arcs, where l is the
length of the i'th list in `labels_shape`
@@ -704,7 +704,7 @@ FsaOrVec ExpandArcs(FsaOrVec &fsas, RaggedShape &labels_shape,
to n (which also implies that aux_labels for
final-arc must at least contain -1).
For other arcs that are not final-arcs,
the corresponding aux_labels must contain no
the corresponding aux_labels must not contain
-1.
@param [out] dest Output Fsa or FsaVec, it's the inverted Fsa. At exit
dest.NumAxes() == src.NumAxes() and num-states of it
150 changes: 113 additions & 37 deletions k2/torch/bin/decode.cu
@@ -21,6 +21,7 @@
#include <string>
#include <vector>

#include "k2/csrc/fsa_algo.h"
#include "k2/torch/csrc/decode.h"
#include "k2/torch/csrc/dense_fsa_vec.h"
#include "k2/torch/csrc/deserialization.h"
@@ -35,16 +36,45 @@
#include "torch/script.h"
#include "torch/utils.h"

// TODO(fangjun):
// Refactor this file.
//
// Create a binary for each decoding method.
// Don't put all decoding methods in a single binary.

enum class DecodingMethod {
kInvalid,
kCtcDecoding,
kHLG,
kNgramRescoring,
kAttentionRescoring,
};

C10_DEFINE_bool(use_gpu, false, "True to use GPU. False to use CPU");
C10_DEFINE_string(jit_pt, "", "Path to exported jit file.");
C10_DEFINE_string(
bpe_model, "",
"Path to a pretrained BPE model. Needed if --use_ctc_decoding is true");
C10_DEFINE_bool(use_ctc_decoding, true, "True to use CTC decoding");
"Path to a pretrained BPE model. Needed if --method is 'ctc-decoding'");
C10_DEFINE_string(method, "", R"(Decoding method.
Supported values are:
- ctc-decoding. Use CTC topology for decoding. You have to
provide --bpe_model.
- hlg. Use HLG graph for decoding.
- ngram-rescoring. Use HLG for decoding and an n-gram LM for rescoring.
You have to provide --g.
- attention-rescoring. Use HLG for decoding, an n-gram LM and an
attention decoder for rescoring.
)");
C10_DEFINE_string(hlg, "",
"Path to HLG.pt. Needed if --use_ctc_decoding is false");
C10_DEFINE_string(word_table, "",
"Path to words.txt. Needed if --use_ctc_decoding is false");
"Path to HLG.pt. Needed if --method is not 'ctc-decoding'");
C10_DEFINE_string(g, "",
"Path to an n-gram LM, e.g., G_4gram.pt. Needed "
"if --method is 'ngram-rescoring' or 'attention-rescoring'");
C10_DEFINE_double(ngram_lm_scale, 1.0,
"Used only when method is ngram-rescoring");
C10_DEFINE_string(
word_table, "",
"Path to words.txt. Needed if --method is not 'ctc-decoding'");
// Fsa decoding related
C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned");
C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned");
@@ -60,42 +90,41 @@ C10_DEFINE_double(frame_length_ms, 25.0,
"Frame length in ms for computing Fbank");
C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank");

static void CheckArgs() {
static void CheckArgs(DecodingMethod method) {
#if !defined(K2_WITH_CUDA)
if (FLAGS_use_gpu) {
std::cerr << "k2 was not compiled with CUDA"
<< "\n";
std::cerr << "Please use --use_gpu 0"
<< "\n";
std::cerr << "k2 was not compiled with CUDA. "
"Please use --use_gpu false";
exit(EXIT_FAILURE);
}
#endif

if (FLAGS_jit_pt.empty()) {
std::cerr << "Please provide --jit_pt"
<< "\n";
std::cerr << torch::UsageMessage() << "\n";
std::cerr << "Please provide --jit_pt\n" << torch::UsageMessage();
exit(EXIT_FAILURE);
}

if (method == DecodingMethod::kCtcDecoding && FLAGS_bpe_model.empty()) {
std::cerr << "Please provide --bpe_model\n"
<< torch::UsageMessage() << "\n";
exit(EXIT_FAILURE);
}

if (FLAGS_use_ctc_decoding && FLAGS_bpe_model.empty()) {
std::cout << "Please provide --bpe_model"
<< "\n";
std::cout << torch::UsageMessage() << "\n";
if (method != DecodingMethod::kCtcDecoding && FLAGS_hlg.empty()) {
std::cerr << "Please provide --hlg\n" << torch::UsageMessage() << "\n";
exit(EXIT_FAILURE);
}

if (FLAGS_use_ctc_decoding == false && FLAGS_hlg.empty()) {
std::cerr << "Please provide --hlg"
<< "\n";
std::cerr << torch::UsageMessage() << "\n";
if (method != DecodingMethod::kCtcDecoding && FLAGS_word_table.empty()) {
std::cerr << "Please provide --word_table\n"
<< torch::UsageMessage() << "\n";
exit(EXIT_FAILURE);
}

if (FLAGS_use_ctc_decoding == false && FLAGS_word_table.empty()) {
std::cerr << "Please provide --word_table"
<< "\n";
std::cerr << torch::UsageMessage() << "\n";
if ((method == DecodingMethod::kNgramRescoring ||
method == DecodingMethod::kAttentionRescoring) &&
FLAGS_g.empty()) {
std::cerr << "Please provide --g\n" << torch::UsageMessage() << "\n";
exit(EXIT_FAILURE);
}
}
@@ -109,15 +138,25 @@ int main(int argc, char *argv[]) {
std::string usage = R"(
(1) CTC decoding
./bin/decode \
--use_ctc_decoding true \
--method ctc-decoding \
--jit_pt <path to exported torch script pt file> \
--bpe_model <path to pretrained BPE model> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
(2) HLG decoding
./bin/decode \
--use_ctc_decoding false \
--method hlg \
--jit_pt <path to exported torch script pt file> \
--hlg <path to HLG.pt> \
--word_table <path to words.txt> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
(3) HLG decoding + ngram LM rescoring
./bin/decode \
--method ngram-rescoring \
--g <path to G.pt> \
--jit_pt <path to exported torch script pt file> \
--hlg <path to HLG.pt> \
--word_table <path to words.txt> \
@@ -127,11 +166,29 @@ int main(int argc, char *argv[]) {

--use_gpu false to use CPU
--use_gpu true to use GPU

./bin/decode --help
to view all possible options.
)";
torch::SetUsageMessage(usage);

DecodingMethod method = DecodingMethod::kInvalid;
torch::ParseCommandLineFlags(&argc, &argv);
CheckArgs();
if (FLAGS_method == "ctc-decoding") {
method = DecodingMethod::kCtcDecoding;
} else if (FLAGS_method == "hlg") {
method = DecodingMethod::kHLG;
} else if (FLAGS_method == "ngram-rescoring") {
method = DecodingMethod::kNgramRescoring;
} else if (FLAGS_method == "attention-rescoring") {
// method = DecodingMethod::kAttentionRescoring;
K2_LOG(FATAL) << "Not implemented yet for: " << FLAGS_method;
} else {
K2_LOG(FATAL) << "Unsupported method: " << FLAGS_method << "\n"
<< torch::UsageMessage();
}

CheckArgs(method);

torch::Device device(torch::kCPU);
if (FLAGS_use_gpu) {
@@ -187,13 +244,11 @@ int main(int argc, char *argv[]) {

torch::IValue supervisions(sup);

std::vector<torch::IValue> inputs;
inputs.emplace_back(std::move(features));
inputs.emplace_back(supervisions);

K2_LOG(INFO) << "Compute nnet_output";
// the output for module.forward() is a tuple of 3 tensors
auto outputs = module.forward(inputs).toTuple();
// See the definition of the model in conformer_ctc/transformer.py
// from icefall
auto outputs = module.run_method("forward", features, supervisions).toTuple();
assert(outputs->elements().size() == 3u);

auto nnet_output = outputs->elements()[0].toTensor();
@@ -207,13 +262,20 @@

k2::FsaClass decoding_graph;

if (FLAGS_use_ctc_decoding) {
if (method == DecodingMethod::kCtcDecoding) {
K2_LOG(INFO) << "Build CTC topo";
decoding_graph =
k2::CtcTopo(nnet_output.size(2) - 1, /*modified*/ false, device);
decoding_graph = k2::CtcTopo(nnet_output.size(2) - 1, false, device);
} else {
K2_LOG(INFO) << "Load " << FLAGS_hlg;
decoding_graph = k2::LoadFsa(FLAGS_hlg, device);
K2_CHECK(decoding_graph.HasAttr("aux_labels"));
}

if (method == DecodingMethod::kNgramRescoring ||
method == DecodingMethod::kAttentionRescoring) {
// Add `lm_scores` so that we can separate acoustic scores and lm scores
// later in the rescoring stage.
decoding_graph.SetTensorAttr("lm_scores", decoding_graph.Scores().clone());
}

K2_LOG(INFO) << "Decoding";
@@ -222,14 +284,28 @@
FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
subsampling_factor);

if (method == DecodingMethod::kNgramRescoring) {
// rescore with an n-gram LM
K2_LOG(INFO) << "Load n-gram LM: " << FLAGS_g;
k2::FsaClass G = k2::LoadFsa(FLAGS_g, device);
G.fsa = k2::FsaToFsaVec(G.fsa);

K2_CHECK_EQ(G.NumAttrs(), 0) << "G is expected to be an acceptor.";
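// Intersection in k2 treats epsilon (label 0) as an ordinary symbol, so G
// needs epsilon self-loops to match epsilon arcs in the lattice.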
k2::AddEpsilonSelfLoops(G.fsa, &G.fsa);
k2::ArcSort(&G.fsa);
G.SetTensorAttr("lm_scores", G.Scores().clone());

WholeLatticeRescoring(G, FLAGS_ngram_lm_scale, &lattice);
}
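// After rescoring, the lattice scores combine the acoustic scores with the
// scaled n-gram LM scores, so ShortestPath() below yields the rescored
// 1-best path.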

lattice = k2::ShortestPath(lattice);

auto ragged_aux_labels = k2::GetTexts(lattice);

auto aux_labels_vec = ragged_aux_labels.ToVecVec();

std::vector<std::string> texts;
if (FLAGS_use_ctc_decoding) {
if (method == DecodingMethod::kCtcDecoding) {
sentencepiece::SentencePieceProcessor processor;
auto status = processor.Load(FLAGS_bpe_model);
if (!status.ok()) {
1 change: 1 addition & 0 deletions k2/torch/csrc/CMakeLists.txt
@@ -24,6 +24,7 @@ target_link_libraries(k2_torch PUBLIC ${TORCH_LIBRARIES} context kaldifeat_core)
# Please sort files alphabetically
set(k2_torch_test_srcs
dense_fsa_vec_test.cu
deserialization_test.cu
fsa_class_test.cu
wave_reader_test.cu
)
44 changes: 44 additions & 0 deletions k2/torch/csrc/decode.cu
@@ -59,4 +59,48 @@ Ragged<int32_t> GetTexts(FsaClass &lattice) {
return ragged_aux_labels;
}

void WholeLatticeRescoring(FsaClass &G, float ngram_lm_scale,
FsaClass *lattice) {
K2_CHECK(lattice->HasAttr("lm_scores"));

torch::Tensor am_scores =
lattice->Scores() - lattice->GetTensorAttr("lm_scores");
lattice->SetScores(am_scores);

// Now the lattice contains only acoustic scores; we will attach LM scores
// from the given n-gram LM.
lattice->DeleteAttr("lm_scores");

K2_CHECK_EQ(G.NumAttrs(), 1)
<< "G is expected to contain only 1 attribute: lm_scores.";
K2_CHECK_EQ(G.fsa.NumAxes(), 3);
K2_CHECK_EQ(G.fsa.Dim0(), 1);

k2::Invert(lattice);
// Now lattice has word IDs as labels and token IDs as aux_labels.

// TODO(fangjun): Use Intersect() when device is CPU
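// b_to_a_map maps every FSA in the lattice (the b-FSA of the intersection)
// to FSA 0 in G (the a-FSA), since G contains exactly one FSA.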
auto b_to_a_map =
k2::Array1<int32_t>(G.fsa.Context(), lattice->fsa.Dim0(), 0);
k2::Array1<int32_t> arc_map_a, arc_map_b;

k2::Fsa dest = k2::IntersectDevice(G.fsa, G.Properties(), lattice->fsa,
lattice->Properties(), b_to_a_map,
&arc_map_a, &arc_map_b, true);

lattice->properties = 0;
lattice->fsa = dest;
lattice->CopyAttrs(*lattice, k2::Array1ToTorch(arc_map_b));
lattice->CopyAttrs(G, k2::Array1ToTorch(arc_map_a));
k2::Connect(lattice);
k2::TopSort(lattice);
k2::Invert(lattice);

// Now lattice has token IDs as labels and word IDs as aux_labels
torch::Tensor lm_scores = lattice->GetTensorAttr("lm_scores");
am_scores = lattice->Scores() - lm_scores;
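// Assuming ngram_lm_scale > 0, dividing only am_scores by ngram_lm_scale is,
// up to a global factor of ngram_lm_scale, the same as
// am_scores + ngram_lm_scale * lm_scores; a global positive factor does not
// change which path is best.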
torch::Tensor scores = am_scores / ngram_lm_scale + lm_scores;
lattice->SetScores(scores);
}

} // namespace k2
13 changes: 13 additions & 0 deletions k2/torch/csrc/decode.h
@@ -65,6 +65,19 @@ FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
*/
Ragged<int32_t> GetTexts(FsaClass &lattice);

/** Rescore a lattice with an n-gram LM.

@param G An acceptor. It MUST be an FsaVec containing only one
arc-sorted FSA. It must also have epsilon self-loops
(see AddEpsilonSelfLoops()) and exactly one tensor
attribute: "lm_scores".
@param ngram_lm_scale The scale value for the n-gram LM scores.
@param lattice The input/output lattice. It can be the
return value of `GetLattice()`.
*/
void WholeLatticeRescoring(FsaClass &G, float ngram_lm_scale,
FsaClass *lattice);

} // namespace k2

#endif // K2_TORCH_CSRC_DECODE_H_
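
For reference, below is a minimal sketch of how a caller prepares G and invokes WholeLatticeRescoring(), mirroring the usage added to k2/torch/bin/decode.cu in this PR. The helper name RescoreWithNgram and the path to G_4gram.pt are placeholders, and the include list assumes the same headers that decode.cu already pulls in:

#include "k2/csrc/fsa_algo.h"
#include "k2/torch/csrc/decode.h"
#include "k2/torch/csrc/deserialization.h"
#include "torch/script.h"

// `lattice` is assumed to be the return value of GetLattice(), built from a
// decoding graph that already carries an "lm_scores" tensor attribute.
static void RescoreWithNgram(k2::FsaClass &lattice, torch::Device device) {
  k2::FsaClass G = k2::LoadFsa("/path/to/G_4gram.pt", device);  // placeholder path
  G.fsa = k2::FsaToFsaVec(G.fsa);          // G must be an FsaVec with one FSA
  k2::AddEpsilonSelfLoops(G.fsa, &G.fsa);  // required by WholeLatticeRescoring()
  k2::ArcSort(&G.fsa);                     // G must be arc-sorted
  G.SetTensorAttr("lm_scores", G.Scores().clone());

  k2::WholeLatticeRescoring(G, /*ngram_lm_scale=*/1.0, &lattice);
  lattice = k2::ShortestPath(lattice);     // rescored 1-best path
}

This is the same sequence that decode.cu runs when --method is ngram-rescoring.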
4 changes: 2 additions & 2 deletions k2/torch/csrc/deserialization.cu
@@ -409,8 +409,8 @@ k2::FsaClass LoadFsa(
if (v.isTensor() || IsRaggedInt(v)) {
ans.SetAttr(p.key().toStringRef(), p.value());
} else {
K2_LOG(INFO) << "Ignore non tensor attribute: '" << p.key().toStringRef()
<< "' of type: " << v.tagKind();
K2_LOG(WARNING) << "Ignore non tensor attribute: '"
<< p.key().toStringRef() << "' of type: " << v.tagKind();
}
}
