Skip to content

Commit b76cd90

Browse files
authored
Support decoding with byte-level BPE (bbpe) models. (#1633)
1 parent 7192e57 commit b76cd90

11 files changed

+270
-10
lines changed

scripts/bbpe/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bbpe.cc

scripts/bbpe/generate_bbpe_table.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
3+
#
4+
# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28
5+
# and
6+
# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py
7+
#
8+
# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall
9+
10+
import re
11+
12+
# Placeholder character (U+2047) that stands for an unknown byte-level unit.
BPE_UNK = chr(8263)

# Code points of the printable characters that stand in for raw byte values.
# The seven ranges contain exactly 256 entries in total; entry ``b`` is the
# stand-in for byte ``b``.  Bytes 32..126 (printable ASCII) map to themselves
# because the first range contributes exactly 32 entries.
PRINTABLE_BASE_CHARS = (
    list(range(256, 287 + 1))
    + list(range(32, 126 + 1))
    + list(range(288, 305 + 1))
    + list(range(308, 318 + 1))
    + list(range(321, 328 + 1))
    + list(range(330, 382 + 1))
    + list(range(384, 422 + 1))
)


# byte value -> printable stand-in character
BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
# printable stand-in character -> byte value (the inverse mapping)
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
BCHAR_TO_BYTE[BPE_UNK] = 32  # map unk to space


def main():
    """Write ``bbpe.cc`` into the current working directory.

    The generated C++ file defines ``GetByteBpeTable()``, which returns a
    static ``std::unordered_map<std::string, uint8_t>`` holding every entry
    of ``BCHAR_TO_BYTE`` (256 byte stand-ins plus ``BPE_UNK``).
    """
    s = ""
    s += "// sherpa-onnx/csrc/bbpe.cc\n"
    s += "//\n"
    s += "// Copyright (c) 2024 Xiaomi Corporation\n"
    s += "\n"
    s += "// Auto-generated! DO NOT EDIT\n"
    s += "\n"
    s += '#include "sherpa-onnx/csrc/bbpe.h"\n'
    s += "\n"
    s += "#include <cstdint>\n"
    s += "#include <string>\n"
    s += "#include <unordered_map>\n"
    s += "\n"
    s += "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {\n"
    s += " static const std::unordered_map<std::string, uint8_t> table = {\n"

    s += " "
    for i, (k, v) in enumerate(BCHAR_TO_BYTE.items()):
        # A backslash or a double quote must be escaped inside a C++ string
        # literal.  Build the escape explicitly instead of relying on the
        # invalid Python escape sequence "\{" inside an f-string.
        escaped = "\\" + k if k in ("\\", '"') else k
        s += f'{{"{escaped}", {v}}}, '
        # Break the initializer list into rows to keep lines short.
        if i > 0 and i % 7 == 0:
            s += "\n"
            s += " "
    s += "};\n"
    s += "\n"
    # The semicolon belongs on the same line as the return statement.
    s += " return table;\n"
    s += "}\n"

    with open("bbpe.cc", "w", encoding="utf-8") as f:
        f.write(s)


if __name__ == "__main__":
    main()

sherpa-onnx/csrc/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ endif()
1212

1313
set(sources
1414
base64-decode.cc
15+
bbpe.cc
1516
cat.cc
1617
circular-buffer.cc
1718
context-graph.cc
@@ -78,11 +79,11 @@ set(sources
7879
online-stream.cc
7980
online-transducer-decoder.cc
8081
online-transducer-greedy-search-decoder.cc
82+
online-transducer-greedy-search-nemo-decoder.cc
8183
online-transducer-model-config.cc
8284
online-transducer-model.cc
8385
online-transducer-modified-beam-search-decoder.cc
8486
online-transducer-nemo-model.cc
85-
online-transducer-greedy-search-nemo-decoder.cc
8687
online-wenet-ctc-model-config.cc
8788
online-wenet-ctc-model.cc
8889
online-zipformer-transducer-model.cc

sherpa-onnx/csrc/bbpe.cc

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// sherpa-onnx/csrc/bbpe.cc
2+
//
3+
// Copyright (c) 2024 Xiaomi Corporation
4+
5+
// Auto-generated! DO NOT EDIT
6+
7+
#include "sherpa-onnx/csrc/bbpe.h"
8+
9+
#include <cstdint>
10+
#include <string>
11+
#include <unordered_map>
12+
13+
// Returns the table that maps each printable byte-level BPE character
// (encoded as a UTF-8 string) to the raw byte value it stands for.
//
// The table has 257 entries: one stand-in character for every byte value
// 0..255, plus the unknown-unit placeholder U+2047 ("⁇"), which decodes to
// a space (32). The table is built once, on first call, and the same
// instance is returned by reference on every call.
const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {
  static const std::unordered_map<std::string, uint8_t> table = {
      {"Ā", 0},   {"ā", 1},   {"Ă", 2},   {"ă", 3},   {"Ą", 4},   {"ą", 5},
      {"Ć", 6},   {"ć", 7},   {"Ĉ", 8},   {"ĉ", 9},   {"Ċ", 10},  {"ċ", 11},
      {"Č", 12},  {"č", 13},  {"Ď", 14},  {"ď", 15},  {"Đ", 16},  {"đ", 17},
      {"Ē", 18},  {"ē", 19},  {"Ĕ", 20},  {"ĕ", 21},  {"Ė", 22},  {"ė", 23},
      {"Ę", 24},  {"ę", 25},  {"Ě", 26},  {"ě", 27},  {"Ĝ", 28},  {"ĝ", 29},
      {"Ğ", 30},  {"ğ", 31},  {" ", 32},  {"!", 33},  {"\"", 34}, {"#", 35},
      {"$", 36},  {"%", 37},  {"&", 38},  {"'", 39},  {"(", 40},  {")", 41},
      {"*", 42},  {"+", 43},  {",", 44},  {"-", 45},  {".", 46},  {"/", 47},
      {"0", 48},  {"1", 49},  {"2", 50},  {"3", 51},  {"4", 52},  {"5", 53},
      {"6", 54},  {"7", 55},  {"8", 56},  {"9", 57},  {":", 58},  {";", 59},
      {"<", 60},  {"=", 61},  {">", 62},  {"?", 63},  {"@", 64},  {"A", 65},
      {"B", 66},  {"C", 67},  {"D", 68},  {"E", 69},  {"F", 70},  {"G", 71},
      {"H", 72},  {"I", 73},  {"J", 74},  {"K", 75},  {"L", 76},  {"M", 77},
      {"N", 78},  {"O", 79},  {"P", 80},  {"Q", 81},  {"R", 82},  {"S", 83},
      {"T", 84},  {"U", 85},  {"V", 86},  {"W", 87},  {"X", 88},  {"Y", 89},
      {"Z", 90},  {"[", 91},  {"\\", 92}, {"]", 93},  {"^", 94},  {"_", 95},
      {"`", 96},  {"a", 97},  {"b", 98},  {"c", 99},  {"d", 100}, {"e", 101},
      {"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107},
      {"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113},
      {"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119},
      {"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125},
      {"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131},
      {"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137},
      {"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143},
      {"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149},
      {"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155},
      {"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161},
      {"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167},
      {"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173},
      {"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179},
      {"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185},
      {"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191},
      {"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197},
      {"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203},
      {"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209},
      {"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215},
      {"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221},
      {"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227},
      {"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233},
      {"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239},
      {"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245},
      {"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251},
      {"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255},
      // The unknown-unit placeholder U+2047 decodes to a space. The key must
      // be the placeholder itself, never an empty string.
      {"⁇", 32},
  };

  return table;
}

sherpa-onnx/csrc/bbpe.h

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// sherpa-onnx/csrc/bbpe.h
2+
//
3+
// Copyright (c) 2024 Xiaomi Corporation
4+
5+
#ifndef SHERPA_ONNX_CSRC_BBPE_H_
6+
#define SHERPA_ONNX_CSRC_BBPE_H_
7+
#include <cstdint>
8+
#include <string>
9+
#include <unordered_map>
10+
11+
// It is equivalent to the map BCHAR_TO_BYTE
12+
// from
13+
// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280
14+
const std::unordered_map<std::string, uint8_t> &GetByteBpeTable();
15+
16+
#endif // SHERPA_ONNX_CSRC_BBPE_H_

sherpa-onnx/csrc/offline-recognizer-ctc-impl.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
4141
text.append(sym);
4242

4343
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
44-
// for byte bpe models
44+
// for bpe models with byte_fallback
4545
// (but don't rewrite printable characters 0x20..0x7e,
4646
// which collide with standard BPE units)
4747
std::ostringstream os;
@@ -52,6 +52,11 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
5252

5353
r.tokens.push_back(std::move(sym));
5454
}
55+
56+
if (sym_table.IsByteBpe()) {
57+
text = sym_table.DecodeByteBpe(text);
58+
}
59+
5560
r.text = std::move(text);
5661

5762
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;

sherpa-onnx/csrc/offline-recognizer-transducer-impl.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static OfflineRecognitionResult Convert(
4343
text.append(sym);
4444

4545
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
46-
// for byte bpe models,
46+
// for bpe models with byte_fallback,
4747
// (but don't rewrite printable characters 0x20..0x7e,
4848
// which collide with standard BPE units)
4949
std::ostringstream os;
@@ -54,6 +54,10 @@ static OfflineRecognitionResult Convert(
5454

5555
r.tokens.push_back(std::move(sym));
5656
}
57+
if (sym_table.IsByteBpe()) {
58+
text = sym_table.DecodeByteBpe(text);
59+
}
60+
5761
r.text = std::move(text);
5862

5963
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;

sherpa-onnx/csrc/online-recognizer-ctc-impl.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
3434
r.tokens.reserve(src.tokens.size());
3535
r.timestamps.reserve(src.tokens.size());
3636

37+
std::string text;
3738
for (auto i : src.tokens) {
3839
auto sym = sym_table[i];
3940

40-
r.text.append(sym);
41+
text.append(sym);
4142

4243
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
43-
// for byte bpe models
44+
// for bpe models with byte_fallback
4445
// (but don't rewrite printable characters 0x20..0x7e,
4546
// which collide with standard BPE units)
4647
std::ostringstream os;
@@ -52,6 +53,12 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
5253
r.tokens.push_back(std::move(sym));
5354
}
5455

56+
if (sym_table.IsByteBpe()) {
57+
text = sym_table.DecodeByteBpe(text);
58+
}
59+
60+
r.text = std::move(text);
61+
5562
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
5663
for (auto t : src.timestamps) {
5764
float time = frame_shift_s * t;

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
3838
r.tokens.reserve(src.tokens.size());
3939
r.timestamps.reserve(src.tokens.size());
4040

41+
std::string text;
4142
for (auto i : src.tokens) {
4243
auto sym = sym_table[i];
4344

44-
r.text.append(sym);
45+
text.append(sym);
4546

4647
if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
47-
// for byte bpe models
48+
// for bpe models with byte_fallback
4849
// (but don't rewrite printable characters 0x20..0x7e,
4950
// which collide with standard BPE units)
5051
std::ostringstream os;
@@ -56,6 +57,12 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
5657
r.tokens.push_back(std::move(sym));
5758
}
5859

60+
if (sym_table.IsByteBpe()) {
61+
text = sym_table.DecodeByteBpe(text);
62+
}
63+
64+
r.text = std::move(text);
65+
5966
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
6067
for (auto t : src.timestamps) {
6168
float time = frame_shift_s * t;

0 commit comments

Comments
 (0)