Skip to content

Commit 3422b93

Browse files
authored
Add Kotlin API for Matcha-TTS models. (#1668)
1 parent 0a43e9c commit 3422b93

File tree

9 files changed

+117
-9
lines changed

9 files changed

+117
-9
lines changed

.github/workflows/jni.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,8 @@ jobs:
7575
7676
cd ./kotlin-api-examples
7777
./run.sh
78+
79+
- uses: actions/upload-artifact@v4
80+
with:
81+
name: tts-files-${{ matrix.os }}
82+
path: kotlin-api-examples/test-*.wav

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,4 @@ sherpa-onnx-moonshine-tiny-en-int8
125125
sherpa-onnx-moonshine-base-en-int8
126126
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
127127
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
128+
matcha-icefall-zh-baker

kotlin-api-examples/run.sh

+10
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ function testTts() {
105105
rm vits-piper-en_US-amy-low.tar.bz2
106106
fi
107107

108+
if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then
109+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
110+
tar xvf matcha-icefall-zh-baker.tar.bz2
111+
rm matcha-icefall-zh-baker.tar.bz2
112+
fi
113+
114+
if [ ! -f ./hifigan_v2.onnx ]; then
115+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
116+
fi
117+
108118
out_filename=test_tts.jar
109119
kotlinc-jvm -include-runtime -d $out_filename \
110120
test_tts.kt \

kotlin-api-examples/test_tts.kt

+27-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
11
package com.k2fsa.sherpa.onnx
22

33
fun main() {
4-
testTts()
4+
testVits()
5+
testMatcha()
56
}
67

7-
fun testTts() {
8+
fun testMatcha() {
9+
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
10+
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2
11+
var config = OfflineTtsConfig(
12+
model=OfflineTtsModelConfig(
13+
matcha=OfflineTtsMatchaModelConfig(
14+
acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx",
15+
vocoder="./hifigan_v2.onnx",
16+
tokens="./matcha-icefall-zh-baker/tokens.txt",
17+
lexicon="./matcha-icefall-zh-baker/lexicon.txt",
18+
dictDir="./matcha-icefall-zh-baker/dict",
19+
),
20+
numThreads=1,
21+
debug=true,
22+
),
23+
ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst",
24+
)
25+
val tts = OfflineTts(config=config)
26+
val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback)
27+
audio.save(filename="test-zh.wav")
28+
tts.release()
29+
println("Saved to test-zh.wav")
30+
}
31+
32+
fun testVits() {
833
// see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
934
// https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
1035
var config = OfflineTtsConfig(

sherpa-onnx/c-api/c-api.cc

+8-4
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,15 @@ const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
17271727
auto p = new SherpaOnnxOnlinePunctuation;
17281728
try {
17291729
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
1730-
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
1731-
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
1732-
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
1730+
punctuation_config.model.cnn_bilstm =
1731+
SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
1732+
punctuation_config.model.bpe_vocab =
1733+
SHERPA_ONNX_OR(config->model.bpe_vocab, "");
1734+
punctuation_config.model.num_threads =
1735+
SHERPA_ONNX_OR(config->model.num_threads, 1);
17331736
punctuation_config.model.debug = config->model.debug;
1734-
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
1737+
punctuation_config.model.provider =
1738+
SHERPA_ONNX_OR(config->model.provider, "cpu");
17351739

17361740
p->impl =
17371741
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);

sherpa-onnx/c-api/c-api.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -1381,12 +1381,14 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
13811381
SherpaOnnxOnlinePunctuationModelConfig model;
13821382
} SherpaOnnxOnlinePunctuationConfig;
13831383

1384-
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
1384+
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
1385+
SherpaOnnxOnlinePunctuation;
13851386

13861387
// Create an online punctuation processor. The user has to invoke
13871388
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
13881389
// to avoid memory leak
1389-
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
1390+
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *
1391+
SherpaOnnxCreateOnlinePunctuation(
13901392
const SherpaOnnxOnlinePunctuationConfig *config);
13911393

13921394
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()

sherpa-onnx/csrc/jieba-lexicon.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class JiebaLexicon::Impl {
155155

156156
this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
157157

158-
if (w == "。" || w == "！" || w == "？" || w == "，") {
158+
if (IsPunct(w)) {
159159
ans.emplace_back(std::move(this_sentence));
160160
this_sentence = {};
161161
}

sherpa-onnx/jni/offline-tts.cc

+49
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
2020
jobject model = env->GetObjectField(config, fid);
2121
jclass model_config_cls = env->GetObjectClass(model);
2222

23+
// vits
2324
fid = env->GetFieldID(model_config_cls, "vits",
2425
"Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
2526
jobject vits = env->GetObjectField(model, fid);
@@ -64,6 +65,54 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
6465
fid = env->GetFieldID(vits_cls, "lengthScale", "F");
6566
ans.model.vits.length_scale = env->GetFloatField(vits, fid);
6667

68+
// matcha
69+
fid = env->GetFieldID(model_config_cls, "matcha",
70+
"Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;");
71+
jobject matcha = env->GetObjectField(model, fid);
72+
jclass matcha_cls = env->GetObjectClass(matcha);
73+
74+
fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;");
75+
s = (jstring)env->GetObjectField(matcha, fid);
76+
p = env->GetStringUTFChars(s, nullptr);
77+
ans.model.matcha.acoustic_model = p;
78+
env->ReleaseStringUTFChars(s, p);
79+
80+
fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;");
81+
s = (jstring)env->GetObjectField(matcha, fid);
82+
p = env->GetStringUTFChars(s, nullptr);
83+
ans.model.matcha.vocoder = p;
84+
env->ReleaseStringUTFChars(s, p);
85+
86+
fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;");
87+
s = (jstring)env->GetObjectField(matcha, fid);
88+
p = env->GetStringUTFChars(s, nullptr);
89+
ans.model.matcha.lexicon = p;
90+
env->ReleaseStringUTFChars(s, p);
91+
92+
fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;");
93+
s = (jstring)env->GetObjectField(matcha, fid);
94+
p = env->GetStringUTFChars(s, nullptr);
95+
ans.model.matcha.tokens = p;
96+
env->ReleaseStringUTFChars(s, p);
97+
98+
fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;");
99+
s = (jstring)env->GetObjectField(matcha, fid);
100+
p = env->GetStringUTFChars(s, nullptr);
101+
ans.model.matcha.data_dir = p;
102+
env->ReleaseStringUTFChars(s, p);
103+
104+
fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;");
105+
s = (jstring)env->GetObjectField(matcha, fid);
106+
p = env->GetStringUTFChars(s, nullptr);
107+
ans.model.matcha.dict_dir = p;
108+
env->ReleaseStringUTFChars(s, p);
109+
110+
fid = env->GetFieldID(matcha_cls, "noiseScale", "F");
111+
ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid);
112+
113+
fid = env->GetFieldID(matcha_cls, "lengthScale", "F");
114+
ans.model.matcha.length_scale = env->GetFloatField(matcha, fid);
115+
67116
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
68117
ans.model.num_threads = env->GetIntField(model, fid);
69118

sherpa-onnx/kotlin-api/Tts.kt

+12
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,20 @@ data class OfflineTtsVitsModelConfig(
1414
var lengthScale: Float = 1.0f,
1515
)
1616

17+
data class OfflineTtsMatchaModelConfig(
18+
var acousticModel: String = "",
19+
var vocoder: String = "",
20+
var lexicon: String = "",
21+
var tokens: String = "",
22+
var dataDir: String = "",
23+
var dictDir: String = "",
24+
var noiseScale: Float = 1.0f,
25+
var lengthScale: Float = 1.0f,
26+
)
27+
1728
data class OfflineTtsModelConfig(
1829
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
30+
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
1931
var numThreads: Int = 1,
2032
var debug: Boolean = false,
2133
var provider: String = "cpu",

0 commit comments

Comments
 (0)