Skip to content

Commit 9548541

Browse files
authored
Support English for MeloTTS models. (#1134)
1 parent fa07bbc commit 9548541

File tree

5 files changed

+99
-39
lines changed

5 files changed

+99
-39
lines changed

.github/workflows/windows-x64-jni.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
os: [windows-latest]
23+
os: [windows-2019]
2424

2525
steps:
2626
- uses: actions/checkout@v4

scripts/melo-tts/export-onnx.py

+22-21
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
from melo.api import TTS
77
from melo.text import language_id_map, language_tone_start_map
88
from melo.text.chinese import pinyin_to_symbol_map
9+
from melo.text.english import eng_dict, refine_syllables
910
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
11+
from melo.text.symbols import language_tone_start_map
1012

1113
for k, v in pinyin_to_symbol_map.items():
14+
if isinstance(v, list):
15+
break
1216
pinyin_to_symbol_map[k] = v.split()
1317

1418

@@ -79,6 +83,16 @@ def generate_lexicon():
7983
word_dict = pinyin_dict.pinyin_dict
8084
phrases = phrases_dict.phrases_dict
8185
with open("lexicon.txt", "w", encoding="utf-8") as f:
86+
for word in eng_dict:
87+
phones, tones = refine_syllables(eng_dict[word])
88+
tones = [t + language_tone_start_map["EN"] for t in tones]
89+
tones = [str(t) for t in tones]
90+
91+
phones = " ".join(phones)
92+
tones = " ".join(tones)
93+
94+
f.write(f"{word.lower()} {phones} {tones}\n")
95+
8296
for key in word_dict:
8397
if not (0x4E00 <= key <= 0x9FA5):
8498
continue
@@ -125,15 +139,13 @@ class ModelWrapper(torch.nn.Module):
125139
def __init__(self, model: "SynthesizerTrn"):
126140
super().__init__()
127141
self.model = model
142+
self.lang_id = language_id_map[model.language]
128143

129144
def forward(
130145
self,
131146
x,
132147
x_lengths,
133148
tones,
134-
lang_id,
135-
bert,
136-
ja_bert,
137149
sid,
138150
noise_scale,
139151
length_scale,
@@ -147,7 +159,11 @@ def forward(
147159
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
148160
sid: an integer
149161
"""
150-
return self.model.infer(
162+
bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
163+
ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
164+
lang_id = torch.zeros_like(x)
165+
lang_id[:, 1::2] = self.lang_id
166+
return self.model.model.infer(
151167
x=x,
152168
x_lengths=x_lengths,
153169
sid=sid,
@@ -169,27 +185,21 @@ def main():
169185

170186
generate_tokens(model.hps["symbols"])
171187

172-
torch_model = ModelWrapper(model.model)
188+
torch_model = ModelWrapper(model)
173189

174190
opset_version = 13
175191
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
176192
print(x.shape)
177193
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
178194
sid = torch.tensor([1], dtype=torch.int64)
179195
tones = torch.zeros_like(x)
180-
lang_id = torch.ones_like(x)
196+
181197
noise_scale = torch.tensor([1.0], dtype=torch.float32)
182198
length_scale = torch.tensor([1.0], dtype=torch.float32)
183199
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)
184200

185-
bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
186-
ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)
187-
188201
x = x.unsqueeze(0)
189202
tones = tones.unsqueeze(0)
190-
lang_id = lang_id.unsqueeze(0)
191-
bert = bert.unsqueeze(0)
192-
ja_bert = ja_bert.unsqueeze(0)
193203

194204
filename = "model.onnx"
195205

@@ -199,9 +209,6 @@ def main():
199209
x,
200210
x_lengths,
201211
tones,
202-
lang_id,
203-
bert,
204-
ja_bert,
205212
sid,
206213
noise_scale,
207214
length_scale,
@@ -213,9 +220,6 @@ def main():
213220
"x",
214221
"x_lengths",
215222
"tones",
216-
"lang_id",
217-
"bert",
218-
"ja_bert",
219223
"sid",
220224
"noise_scale",
221225
"length_scale",
@@ -226,9 +230,6 @@ def main():
226230
"x": {0: "N", 1: "L"},
227231
"x_lengths": {0: "N"},
228232
"tones": {0: "N", 1: "L"},
229-
"lang_id": {0: "N", 1: "L"},
230-
"bert": {0: "N", 2: "L"},
231-
"ja_bert": {0: "N", 2: "L"},
232233
"y": {0: "N", 1: "S", 2: "T"},
233234
},
234235
)

scripts/melo-tts/run.sh

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ echo "pwd: $PWD"
2828

2929
ls -lh
3030

31+
./show-info.py
32+
3133
head lexicon.txt
3234
echo "---"
3335
tail lexicon.txt

scripts/melo-tts/show-info.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
3+
4+
import onnxruntime
5+
6+
7+
def show(filename):
8+
session_opts = onnxruntime.SessionOptions()
9+
session_opts.log_severity_level = 3
10+
sess = onnxruntime.InferenceSession(filename, session_opts)
11+
for i in sess.get_inputs():
12+
print(i)
13+
14+
print("-----")
15+
16+
for i in sess.get_outputs():
17+
print(i)
18+
19+
meta = sess.get_modelmeta().custom_metadata_map
20+
print("*****************************************")
21+
print("meta\n", meta)
22+
23+
24+
def main():
25+
print("=========model==========")
26+
show("./model.onnx")
27+
28+
29+
if __name__ == "__main__":
30+
main()
31+
32+
"""
33+
=========model==========
34+
NodeArg(name='x', type='tensor(int64)', shape=['N', 'L'])
35+
NodeArg(name='x_lengths', type='tensor(int64)', shape=['N'])
36+
NodeArg(name='tones', type='tensor(int64)', shape=['N', 'L'])
37+
NodeArg(name='sid', type='tensor(int64)', shape=[1])
38+
NodeArg(name='noise_scale', type='tensor(float)', shape=[1])
39+
NodeArg(name='length_scale', type='tensor(float)', shape=[1])
40+
NodeArg(name='noise_scale_w', type='tensor(float)', shape=[1])
41+
-----
42+
NodeArg(name='y', type='tensor(float)', shape=['N', 'S', 'T'])
43+
*****************************************
44+
meta
45+
{'description': 'MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai',
46+
'model_type': 'melo-vits', 'license': 'MIT license', 'sample_rate': '44100', 'add_blank': '1',
47+
'n_speakers': '1', 'bert_dim': '1024', 'language': 'Chinese + English',
48+
'ja_bert_dim': '768', 'speaker_id': '1', 'comment': 'melo', 'lang_id': '3',
49+
'tone_start': '0', 'url': 'https://github.com/myshell-ai/MeloTTS'}
50+
"""

scripts/melo-tts/test.py

+24-17
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(self, lexion_filename: str, tokens_filename: str):
3030
tones = [int(t) for t in tones]
3131

3232
lexicon[word_or_phrase] = (phones, tones)
33+
lexicon["呣"] = lexicon["母"]
34+
lexicon["嗯"] = lexicon["恩"]
3335
self.lexicon = lexicon
3436

3537
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
@@ -98,20 +100,16 @@ def __init__(self, filename):
98100
self.lang_id = int(meta["lang_id"])
99101
self.sample_rate = int(meta["sample_rate"])
100102

101-
def __call__(self, x, tones, lang):
103+
def __call__(self, x, tones):
102104
"""
103105
Args:
104106
x: 1-D int64 torch tensor
105107
tones: 1-D int64 torch tensor
106-
lang: 1-D int64 torch tensor
107108
"""
108109
x = x.unsqueeze(0)
109110
tones = tones.unsqueeze(0)
110-
lang = lang.unsqueeze(0)
111111

112-
print(x.shape, tones.shape, lang.shape)
113-
bert = torch.zeros(1, self.bert_dim, x.shape[-1])
114-
ja_bert = torch.zeros(1, self.ja_bert_dim, x.shape[-1])
112+
print(x.shape, tones.shape)
115113
sid = torch.tensor([self.speaker_id], dtype=torch.int64)
116114
noise_scale = torch.tensor([0.6], dtype=torch.float32)
117115
length_scale = torch.tensor([1.0], dtype=torch.float32)
@@ -125,9 +123,6 @@ def __call__(self, x, tones, lang):
125123
"x": x.numpy(),
126124
"x_lengths": x_lengths.numpy(),
127125
"tones": tones.numpy(),
128-
"lang_id": lang.numpy(),
129-
"bert": bert.numpy(),
130-
"ja_bert": ja_bert.numpy(),
131126
"sid": sid.numpy(),
132127
"noise_scale": noise_scale.numpy(),
133128
"noise_scale_w": noise_scale_w.numpy(),
@@ -140,34 +135,46 @@ def __call__(self, x, tones, lang):
140135
def main():
141136
lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")
142137

143-
text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大"
138+
text = "永远相信,美好的事情即将发生。"
144139
s = jieba.cut(text, HMM=True)
145140

146141
phones, tones = lexicon.convert(s)
147142

143+
en_text = "how are you ?".split()
144+
145+
phones_en, tones_en = lexicon.convert(en_text)
146+
phones += [0]
147+
tones += [0]
148+
149+
phones += phones_en
150+
tones += tones_en
151+
152+
text = "多音字测试, 银行,行不行?长沙长大"
153+
s = jieba.cut(text, HMM=True)
154+
155+
phones2, tones2 = lexicon.convert(s)
156+
157+
phones += phones2
158+
tones += tones2
159+
148160
model = OnnxModel("./model.onnx")
149-
langs = [model.lang_id] * len(phones)
150161

151162
if model.add_blank:
152163
new_phones = [0] * (2 * len(phones) + 1)
153164
new_tones = [0] * (2 * len(tones) + 1)
154-
new_langs = [0] * (2 * len(langs) + 1)
155165

156166
new_phones[1::2] = phones
157167
new_tones[1::2] = tones
158-
new_langs[1::2] = langs
159168

160169
phones = new_phones
161170
tones = new_tones
162-
langs = new_langs
163171

164172
phones = torch.tensor(phones, dtype=torch.int64)
165173
tones = torch.tensor(tones, dtype=torch.int64)
166-
langs = torch.tensor(langs, dtype=torch.int64)
167174

168-
print(phones.shape, tones.shape, langs.shape)
175+
print(phones.shape, tones.shape)
169176

170-
y = model(x=phones, tones=tones, lang=langs)
177+
y = model(x=phones, tones=tones)
171178
sf.write("./test.wav", y, model.sample_rate)
172179

173180

0 commit comments

Comments
 (0)