Skip to content

Commit 9dd0e03

Browse files
authored
Enable to stop TTS generation (#1041)
1 parent 96ab843 commit 9dd0e03

File tree

32 files changed

+248
-69
lines changed

32 files changed

+248
-69
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ project(sherpa-onnx)
88
# ./nodejs-addon-examples
99
# ./dart-api-examples/
1010
# ./sherpa-onnx/flutter/CHANGELOG.md
11-
set(SHERPA_ONNX_VERSION "1.10.0")
11+
set(SHERPA_ONNX_VERSION "1.10.1")
1212

1313
# Disable warning about
1414
#

android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt

+30-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class MainActivity : AppCompatActivity() {
2626
private lateinit var speed: EditText
2727
private lateinit var generate: Button
2828
private lateinit var play: Button
29+
private lateinit var stop: Button
30+
private var stopped: Boolean = false
31+
private var mediaPlayer: MediaPlayer? = null
2932

3033
// see
3134
// https://developer.android.com/reference/kotlin/android/media/AudioTrack
@@ -49,9 +52,11 @@ class MainActivity : AppCompatActivity() {
4952

5053
generate = findViewById(R.id.generate)
5154
play = findViewById(R.id.play)
55+
stop = findViewById(R.id.stop)
5256

5357
generate.setOnClickListener { onClickGenerate() }
5458
play.setOnClickListener { onClickPlay() }
59+
stop.setOnClickListener { onClickStop() }
5560

5661
sid.setText("0")
5762
speed.setText("1.0")
@@ -70,7 +75,7 @@ class MainActivity : AppCompatActivity() {
7075
AudioFormat.CHANNEL_OUT_MONO,
7176
AudioFormat.ENCODING_PCM_FLOAT
7277
)
73-
Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")
78+
Log.i(TAG, "sampleRate: $sampleRate, buffLength: $bufLength")
7479

7580
val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
7681
.setUsage(AudioAttributes.USAGE_MEDIA)
@@ -90,8 +95,14 @@ class MainActivity : AppCompatActivity() {
9095
}
9196

9297
// this function is called from C++
93-
private fun callback(samples: FloatArray) {
94-
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
98+
private fun callback(samples: FloatArray): Int {
99+
if (!stopped) {
100+
track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
101+
return 1
102+
} else {
103+
track.stop()
104+
return 0
105+
}
95106
}
96107

97108
private fun onClickGenerate() {
@@ -127,6 +138,8 @@ class MainActivity : AppCompatActivity() {
127138
track.play()
128139

129140
play.isEnabled = false
141+
generate.isEnabled = false
142+
stopped = false
130143
Thread {
131144
val audio = tts.generateWithCallback(
132145
text = textStr,
@@ -140,6 +153,7 @@ class MainActivity : AppCompatActivity() {
140153
if (ok) {
141154
runOnUiThread {
142155
play.isEnabled = true
156+
generate.isEnabled = true
143157
track.stop()
144158
}
145159
}
@@ -148,11 +162,22 @@ class MainActivity : AppCompatActivity() {
148162

149163
private fun onClickPlay() {
150164
val filename = application.filesDir.absolutePath + "/generated.wav"
151-
val mediaPlayer = MediaPlayer.create(
165+
mediaPlayer?.stop()
166+
mediaPlayer = MediaPlayer.create(
152167
applicationContext,
153168
Uri.fromFile(File(filename))
154169
)
155-
mediaPlayer.start()
170+
mediaPlayer?.start()
171+
}
172+
173+
private fun onClickStop() {
174+
stopped = true
175+
play.isEnabled = true
176+
generate.isEnabled = true
177+
track.pause()
178+
track.flush()
179+
mediaPlayer?.stop()
180+
mediaPlayer = null
156181
}
157182

158183
private fun initTts() {

android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/Tts.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class OfflineTts(
7676
text: String,
7777
sid: Int = 0,
7878
speed: Float = 1.0f,
79-
callback: (samples: FloatArray) -> Unit
79+
callback: (samples: FloatArray) -> Int
8080
): GeneratedAudio {
8181
val objArray = generateWithCallbackImpl(
8282
ptr,
@@ -146,7 +146,7 @@ class OfflineTts(
146146
text: String,
147147
sid: Int = 0,
148148
speed: Float = 1.0f,
149-
callback: (samples: FloatArray) -> Unit
149+
callback: (samples: FloatArray) -> Int
150150
): Array<Any>
151151

152152
companion object {

android/SherpaOnnxTts/app/src/main/res/layout/activity_main.xml

+12
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,16 @@
8484
app:layout_constraintLeft_toLeftOf="parent"
8585
app:layout_constraintRight_toRightOf="parent"
8686
app:layout_constraintTop_toBottomOf="@id/generate" />
87+
88+
<Button
89+
android:id="@+id/stop"
90+
android:textAllCaps="false"
91+
android:layout_width="match_parent"
92+
android:layout_height="50dp"
93+
android:layout_marginTop="4dp"
94+
android:text="@string/stop"
95+
app:layout_constraintLeft_toLeftOf="parent"
96+
app:layout_constraintRight_toRightOf="parent"
97+
app:layout_constraintTop_toBottomOf="@id/play" />
98+
8799
</androidx.constraintlayout.widget.ConstraintLayout>

android/SherpaOnnxTts/app/src/main/res/values/strings.xml

+1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
<string name="text_hint">Please input your text here</string>
88
<string name="generate">Generate</string>
99
<string name="play">Play</string>
10+
<string name="stop">Stop</string>
1011
</resources>

android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsService.kt

+5-2
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class TtsService : TextToSpeechService() {
126126
return
127127
}
128128

129-
val ttsCallback = { floatSamples: FloatArray ->
129+
val ttsCallback: (FloatArray) -> Int = fun(floatSamples): Int {
130130
// convert FloatArray to ByteArray
131131
val samples = floatArrayToByteArray(floatSamples)
132132
val maxBufferSize: Int = callback.maxBufferSize
@@ -137,6 +137,9 @@ class TtsService : TextToSpeechService() {
137137
offset += bytesToWrite
138138
}
139139

140+
// 1 means to continue
141+
// 0 means to stop
142+
return 1
140143
}
141144

142145
Log.i(TAG, "text: $text")
@@ -160,4 +163,4 @@ class TtsService : TextToSpeechService() {
160163
}
161164
return byteArray
162165
}
163-
}
166+
}

dart-api-examples/non-streaming-asr/pubspec.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ environment:
1010

1111
# Add regular dependencies here.
1212
dependencies:
13-
sherpa_onnx: ^1.10.0
13+
sherpa_onnx: ^1.10.1
1414
path: ^1.9.0
1515
args: ^2.5.0
1616

dart-api-examples/streaming-asr/pubspec.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ environment:
1111

1212
# Add regular dependencies here.
1313
dependencies:
14-
sherpa_onnx: ^1.10.0
14+
sherpa_onnx: ^1.10.1
1515
path: ^1.9.0
1616
args: ^2.5.0
1717

dart-api-examples/tts/bin/piper.dart

+4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ void main(List<String> arguments) async {
6868
callback: (Float32List samples) {
6969
print('${samples.length} samples received');
7070
// You can play samples in a separate thread/isolate
71+
72+
// 1 means to continue
73+
// 0 means to stop
74+
return 1;
7175
});
7276
tts.free();
7377

dart-api-examples/tts/pubspec.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ environment:
88

99
# Add regular dependencies here.
1010
dependencies:
11-
sherpa_onnx: ^1.10.0
11+
sherpa_onnx: ^1.10.1
1212
path: ^1.9.0
1313
args: ^2.5.0
1414

dart-api-examples/vad/pubspec.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ environment:
99
sdk: ^3.4.0
1010

1111
dependencies:
12-
sherpa_onnx: ^1.10.0
12+
sherpa_onnx: ^1.10.1
1313
path: ^1.9.0
1414
args: ^2.5.0
1515

dotnet-examples/offline-tts-play/Program.cs

+4
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ private static void Run(Options options)
187187
Marshal.Copy(samples, data, 0, n);
188188

189189
dataItems.Add(data);
190+
191+
// 1 means to keep generating
192+
// 0 means to stop generating
193+
return 1;
190194
};
191195

192196
bool playFinished = false;

kotlin-api-examples/test_tts.kt

+41-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,46 @@ fun testTts() {
2525
println("Saved to test-en.wav")
2626
}
2727

28-
fun callback(samples: FloatArray): Unit {
28+
/*
29+
1. Unzip test_tts.jar
30+
2.
31+
javap ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
32+
33+
3. It prints:
34+
Compiled from "test_tts.kt"
35+
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
36+
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
37+
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
38+
public final java.lang.Integer invoke(float[]);
39+
public java.lang.Object invoke(java.lang.Object);
40+
static {};
41+
}
42+
43+
4.
44+
javap -s ./com/k2fsa/sherpa/onnx/Test_ttsKt\$testTts\$audio\$1.class
45+
46+
5. It prints
47+
Compiled from "test_tts.kt"
48+
final class com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1<float[], java.lang.Integer> {
49+
public static final com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1 INSTANCE;
50+
descriptor: Lcom/k2fsa/sherpa/onnx/Test_ttsKt$testTts$audio$1;
51+
com.k2fsa.sherpa.onnx.Test_ttsKt$testTts$audio$1();
52+
descriptor: ()V
53+
54+
public final java.lang.Integer invoke(float[]);
55+
descriptor: ([F)Ljava/lang/Integer;
56+
57+
public java.lang.Object invoke(java.lang.Object);
58+
descriptor: (Ljava/lang/Object;)Ljava/lang/Object;
59+
60+
static {};
61+
descriptor: ()V
62+
}
63+
*/
64+
fun callback(samples: FloatArray): Int {
2965
println("callback got called with ${samples.size} samples");
66+
67+
// 1 means to continue
68+
// 0 means to stop
69+
return 1
3070
}
Binary file not shown.

mfc-examples/NonStreamingTextToSpeech/NonStreamingTextToSpeechDlg.cpp

+43-19
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ static bool g_started = false;
5757
static bool g_stopped = false;
5858
static bool g_killed = false;
5959

60-
static void AudioGeneratedCallback(const float *s, int32_t n) {
60+
static int32_t AudioGeneratedCallback(const float *s, int32_t n) {
6161
if (n > 0) {
6262
Samples samples;
6363
samples.data = std::vector<float>{s, s + n};
@@ -66,6 +66,10 @@ static void AudioGeneratedCallback(const float *s, int32_t n) {
6666
g_buffer.samples.push(std::move(samples));
6767
g_started = true;
6868
}
69+
if (g_killed) {
70+
return 0;
71+
}
72+
return 1;
6973
}
7074

7175
static int PlayCallback(const void * /*in*/, void *out,
@@ -324,6 +328,7 @@ BEGIN_MESSAGE_MAP(CNonStreamingTextToSpeechDlg, CDialogEx)
324328
ON_WM_PAINT()
325329
ON_WM_QUERYDRAGICON()
326330
ON_BN_CLICKED(IDOK, &CNonStreamingTextToSpeechDlg::OnBnClickedOk)
331+
ON_BN_CLICKED(IDC_STOP, &CNonStreamingTextToSpeechDlg::OnBnClickedStop)
327332
END_MESSAGE_MAP()
328333

329334

@@ -492,11 +497,18 @@ void CNonStreamingTextToSpeechDlg::Init() {
492497
if (tts_) {
493498
SherpaOnnxDestroyOfflineTts(tts_);
494499
}
500+
if (generate_thread_ && generate_thread_->joinable()) {
501+
generate_thread_->join();
502+
}
503+
504+
if (play_thread_ && play_thread_->joinable()) {
505+
play_thread_->join();
506+
}
495507
}
496508

497509

498510
static std::string ToString(const CString &s) {
499-
CT2CA pszConvertedAnsiString( s);
511+
CT2CA pszConvertedAnsiString(s);
500512
return std::string(pszConvertedAnsiString);
501513
}
502514

@@ -510,7 +522,7 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
510522
}
511523

512524
speed_.GetWindowText(s);
513-
float speed = static_cast<float>(_ttof(s));
525+
float speed = static_cast<float>(_ttof(s));
514526
if (speed < 0) {
515527
AfxMessageBox(Utf8ToUtf16("Please input a valid speed").c_str(), MB_OK);
516528
return;
@@ -541,28 +553,40 @@ void CNonStreamingTextToSpeechDlg::OnBnClickedOk() {
541553
// for simplicity
542554
play_thread_ = std::make_unique<std::thread>(StartPlayback, SherpaOnnxOfflineTtsSampleRate(tts_));
543555

544-
generate_btn_.EnableWindow(FALSE);
545-
546-
const SherpaOnnxGeneratedAudio *audio =
547-
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, ss.c_str(), speaker_id, speed, &AudioGeneratedCallback);
548-
549-
generate_btn_.EnableWindow(TRUE);
556+
if (generate_thread_ && generate_thread_->joinable()) {
557+
generate_thread_->join();
558+
}
550559

551560
output_filename_.GetWindowText(s);
552561
std::string filename = ToString(s);
553562

554-
int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
555-
filename.c_str());
563+
generate_thread_ = std::make_unique<std::thread>([ss, this,filename, speaker_id, speed]() {
564+
std::string text = ss;
556565

557-
SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
566+
// generate_btn_.EnableWindow(FALSE);
558567

559-
if (ok) {
560-
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);
561-
AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
562-
} else {
563-
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);
564-
AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
565-
}
568+
const SherpaOnnxGeneratedAudio *audio =
569+
SherpaOnnxOfflineTtsGenerateWithCallback(tts_, text.c_str(), speaker_id, speed, &AudioGeneratedCallback);
570+
// generate_btn_.EnableWindow(TRUE);
571+
g_stopped = true;
572+
573+
int ok = SherpaOnnxWriteWave(audio->samples, audio->n, audio->sample_rate,
574+
filename.c_str());
575+
576+
SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
577+
578+
if (ok) {
579+
// AfxMessageBox(Utf8ToUtf16(std::string("Saved to ") + filename + " successfully").c_str(), MB_OK);
580+
581+
// AppendLineToMultilineEditCtrl(my_hint_, std::string("Saved to ") + filename + " successfully");
582+
} else {
583+
// AfxMessageBox(Utf8ToUtf16(std::string("Failed to save to ") + filename).c_str(), MB_OK);
584+
585+
// AppendLineToMultilineEditCtrl(my_hint_, std::string("Failed to saved to ") + filename);
586+
}
587+
});
566588

567589
//CDialogEx::OnOK();
568590
}
591+
592+
void CNonStreamingTextToSpeechDlg::OnBnClickedStop() { g_killed = true; }

mfc-examples/NonStreamingTextToSpeech/NonStreamingTextToSpeechDlg.h

+3
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,8 @@ class CNonStreamingTextToSpeechDlg : public CDialogEx
6060
private:
6161
Microphone mic_;
6262
std::unique_ptr<std::thread> play_thread_;
63+
std::unique_ptr<std::thread> generate_thread_;
6364

65+
public:
66+
afx_msg void OnBnClickedStop();
6467
};

0 commit comments

Comments
 (0)