forked from k2-fsa/sherpa-onnx
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgenerate_bbpe_table.py
executable file
·67 lines (57 loc) · 1.81 KB
/
generate_bbpe_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_bpe.py#L28
# and
# https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py
#
# Caution: The PRINTABLE_LATIN from fairseq is different from PRINTABLE_BASE_CHARS from icefall
import re
# Unknown-symbol marker: U+2047 DOUBLE QUESTION MARK.
BPE_UNK = chr(8263)

# Code points used as printable stand-ins for the 256 byte values, in byte
# order: byte b is represented by chr(PRINTABLE_BASE_CHARS[b]).  Each pair
# below is an inclusive (first, last) code-point range; together they cover
# exactly 256 code points.  Because the second range starts at index 32,
# bytes 0x20..0x7E are represented by themselves.
PRINTABLE_BASE_CHARS = [
    code_point
    for first, last in (
        (256, 287),
        (32, 126),
        (288, 305),
        (308, 318),
        (321, 328),
        (330, 382),
        (384, 422),
    )
    for code_point in range(first, last + 1)
]

# byte value (0..255) -> its single-character stand-in.
BYTE_TO_BCHAR = dict(zip(range(256), map(chr, PRINTABLE_BASE_CHARS)))

# Inverse of BYTE_TO_BCHAR, with BPE_UNK appended last and decoded as a
# space (byte 32).
BCHAR_TO_BYTE = {
    **dict(zip(BYTE_TO_BCHAR.values(), BYTE_TO_BCHAR.keys())),
    BPE_UNK: 32,
}
def _cpp_string_literal(text):
    """Return ``text`` as a double-quoted C++ string literal.

    Only the backslash and the double quote are escaped.  Every other
    character (including non-ASCII base characters) is emitted verbatim,
    which relies on the generated file being saved as UTF-8.
    """
    # Escape backslashes first so the backslashes added for quotes are not
    # themselves doubled.
    escaped = text.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def main(table=None, filename="bbpe.cc"):
    """Generate a C++ source file embedding the byte-level BPE lookup table.

    The emitted function ``GetByteBpeTable()`` returns a function-local static
    ``std::unordered_map<std::string, uint8_t>`` with one
    ``{"<char>", <byte>}`` initializer per item of ``table``, in the table's
    iteration order.

    Args:
      table: Mapping from single-character strings to byte values.
        Defaults to the module-level ``BCHAR_TO_BYTE``.
      filename: Path of the file to write (UTF-8, overwritten if it exists).
        Defaults to ``"bbpe.cc"`` in the current working directory.
    """
    if table is None:
        table = BCHAR_TO_BYTE

    # Fixed number of initializers per line of generated code, so every row
    # has the same width (the old modulo check made the first row longer).
    entries_per_row = 8
    entries = [f"{{{_cpp_string_literal(k)}, {v}}}" for k, v in table.items()]
    rows = [
        " " + ", ".join(entries[i : i + entries_per_row]) + ","
        for i in range(0, len(entries), entries_per_row)
    ]

    lines = [
        "// sherpa-onnx/csrc/bbpe.cc",
        "//",
        "// Copyright (c) 2024 Xiaomi Corporation",
        "",
        "// Auto-generated! DO NOT EDIT",
        "",
        '#include "sherpa-onnx/csrc/bbpe.h"',
        "",
        "#include <cstdint>",
        "#include <string>",
        "#include <unordered_map>",
        "",
        "const std::unordered_map<std::string, uint8_t> &GetByteBpeTable() {",
        " static const std::unordered_map<std::string, uint8_t> table = {",
        *rows,
        " };",
        "",
        # The semicolon belongs on the return line (it was previously emitted
        # after the newline, producing "return table" followed by ";}").
        " return table;",
        "}",
    ]

    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


if __name__ == "__main__":
    main()