Skip to content

Commit acf0975

Browse files
authored
Support whisper language/task in various language bindings. (#679)
1 parent 842d04d commit acf0975

File tree

15 files changed

+117
-62
lines changed

15 files changed

+117
-62
lines changed

dotnet-examples/offline-decode-files/Program.cs

+8
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,12 @@ class Options
4040
[Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")]
4141
public string WhisperDecoder { get; set; }
4242

43+
[Option("whisper-language", Required = false, Default = "", HelpText = "Language of the input file. Can be empty")]
44+
public string WhisperLanguage{ get; set; }
45+
46+
[Option("whisper-task", Required = false, Default = "transcribe", HelpText = "transcribe or translate")]
47+
public string WhisperTask{ get; set; }
48+
4349
[Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")]
4450
public string TdnnModel { get; set; }
4551

@@ -193,6 +199,8 @@ private static void Run(Options options)
193199
{
194200
config.ModelConfig.Whisper.Encoder = options.WhisperEncoder;
195201
config.ModelConfig.Whisper.Decoder = options.WhisperDecoder;
202+
config.ModelConfig.Whisper.Language = options.WhisperLanguage;
203+
config.ModelConfig.Whisper.Task = options.WhisperTask;
196204
}
197205
else if (!String.IsNullOrEmpty(options.TdnnModel))
198206
{

go-api-examples/non-streaming-decode-files/main.go

+2
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ func main() {
2929

3030
flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model")
3131
flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model")
32+
flag.StringVar(&config.ModelConfig.Whisper.Language, "whisper-language", "", "Language of the input wave. You can leave it empty ")
33+
flag.StringVar(&config.ModelConfig.Whisper.Task, "whisper-task", "transcribe", "transcribe or translate")
3234

3335
flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model")
3436

nodejs-examples/test-offline-nemo-ctc.js

+2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ function createOfflineRecognizer() {
2727
whisper: {
2828
encoder: '',
2929
decoder: '',
30+
language: '',
31+
task: '',
3032
},
3133
tdnn: {
3234
model: '',

nodejs-examples/test-offline-paraformer.js

+2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ function createOfflineRecognizer() {
2727
whisper: {
2828
encoder: '',
2929
decoder: '',
30+
language: '',
31+
task: '',
3032
},
3133
tdnn: {
3234
model: '',

nodejs-examples/test-offline-transducer.js

+2
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,8 @@ function createOfflineRecognizer() {
3030
whisper: {
3131
encoder: '',
3232
decoder: '',
33+
language: '',
34+
task: '',
3335
},
3436
tdnn: {
3537
model: '',

nodejs-examples/test-offline-whisper.js

+2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ function createOfflineRecognizer() {
2727
whisper: {
2828
encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx',
2929
decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
30+
language: '',
31+
task: 'transcribe',
3032
},
3133
tdnn: {
3234
model: '',

scripts/dotnet/offline.cs

+8
Original file line number | Diff line number | Diff line change
@@ -279,12 +279,20 @@ public OfflineWhisperModelConfig()
279279
{
280280
Encoder = "";
281281
Decoder = "";
282+
Language = "";
283+
Task = "transcribe";
282284
}
283285
[MarshalAs(UnmanagedType.LPStr)]
284286
public string Encoder;
285287

286288
[MarshalAs(UnmanagedType.LPStr)]
287289
public string Decoder;
290+
291+
[MarshalAs(UnmanagedType.LPStr)]
292+
public string Language;
293+
294+
[MarshalAs(UnmanagedType.LPStr)]
295+
public string Task;
288296
}
289297

290298
[StructLayout(LayoutKind.Sequential)]

scripts/go/sherpa_onnx.go

+10-2
Original file line number | Diff line number | Diff line change
@@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct {
326326
}
327327

328328
type OfflineWhisperModelConfig struct {
329-
Encoder string
330-
Decoder string
329+
Encoder string
330+
Decoder string
331+
Language string
332+
Task string
331333
}
332334

333335
type OfflineTdnnModelConfig struct {
@@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
423425
c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder)
424426
defer C.free(unsafe.Pointer(c.model_config.whisper.decoder))
425427

428+
c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language)
429+
defer C.free(unsafe.Pointer(c.model_config.whisper.language))
430+
431+
c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
432+
defer C.free(unsafe.Pointer(c.model_config.whisper.task))
433+
426434
c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
427435
defer C.free(unsafe.Pointer(c.model_config.tdnn.model))
428436

sherpa-onnx/c-api/c-api.cc

+24-26
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,13 @@
1111

1212
#include "sherpa-onnx/csrc/circular-buffer.h"
1313
#include "sherpa-onnx/csrc/display.h"
14+
#include "sherpa-onnx/csrc/keyword-spotter.h"
1415
#include "sherpa-onnx/csrc/macros.h"
1516
#include "sherpa-onnx/csrc/offline-recognizer.h"
1617
#include "sherpa-onnx/csrc/offline-tts.h"
1718
#include "sherpa-onnx/csrc/online-recognizer.h"
1819
#include "sherpa-onnx/csrc/voice-activity-detector.h"
1920
#include "sherpa-onnx/csrc/wave-writer.h"
20-
#include "sherpa-onnx/csrc/keyword-spotter.h"
2121

2222
struct SherpaOnnxOnlineRecognizer {
2323
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
@@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
301301
recognizer_config.model_config.whisper.language =
302302
SHERPA_ONNX_OR(config->model_config.whisper.language, "");
303303

304+
recognizer_config.model_config.whisper.task =
305+
SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe");
306+
304307
recognizer_config.model_config.tdnn.model =
305308
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
306309

@@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter {
422425
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
423426
};
424427

425-
SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
426-
const SherpaOnnxKeywordSpotterConfig* config) {
428+
SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
429+
const SherpaOnnxKeywordSpotterConfig *config) {
427430
sherpa_onnx::KeywordSpotterConfig spotter_config;
428431

429432
spotter_config.feat_config.sampling_rate =
@@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
457460
spotter_config.model_config.debug =
458461
SHERPA_ONNX_OR(config->model_config.debug, 0);
459462

460-
spotter_config.max_active_paths =
461-
SHERPA_ONNX_OR(config->max_active_paths, 4);
463+
spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
462464

463465
spotter_config.num_trailing_blanks =
464-
SHERPA_ONNX_OR(config->num_trailing_blanks , 1);
466+
SHERPA_ONNX_OR(config->num_trailing_blanks, 1);
465467

466-
spotter_config.keywords_score =
467-
SHERPA_ONNX_OR(config->keywords_score, 1.0);
468+
spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0);
468469

469470
spotter_config.keywords_threshold =
470471
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);
471472

472-
spotter_config.keywords_file =
473-
SHERPA_ONNX_OR(config->keywords_file, "");
473+
spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, "");
474474

475475
if (config->model_config.debug) {
476476
SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
@@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
481481
return nullptr;
482482
}
483483

484-
SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter;
484+
SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter;
485485

486-
spotter->impl =
487-
std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);
486+
spotter->impl = std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);
488487

489488
return spotter;
490489
}
491490

492-
void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) {
491+
void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
493492
delete spotter;
494493
}
495494

496-
SherpaOnnxOnlineStream* CreateKeywordStream(
497-
const SherpaOnnxKeywordSpotter* spotter) {
498-
SherpaOnnxOnlineStream* stream =
495+
SherpaOnnxOnlineStream *CreateKeywordStream(
496+
const SherpaOnnxKeywordSpotter *spotter) {
497+
SherpaOnnxOnlineStream *stream =
499498
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
500499
return stream;
501500
}
502501

503-
int32_t IsKeywordStreamReady(
504-
SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) {
502+
int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
503+
SherpaOnnxOnlineStream *stream) {
505504
return spotter->impl->IsReady(stream->impl.get());
506505
}
507506

508-
void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
509-
SherpaOnnxOnlineStream* stream) {
507+
void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
508+
SherpaOnnxOnlineStream *stream) {
510509
return spotter->impl->DecodeStream(stream->impl.get());
511510
}
512511

513-
void DecodeMultipleKeywordStreams(
514-
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
515-
int32_t n) {
516-
std::vector<sherpa_onnx::OnlineStream*> ss(n);
512+
void DecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
513+
SherpaOnnxOnlineStream **streams, int32_t n) {
514+
std::vector<sherpa_onnx::OnlineStream *> ss(n);
517515
for (int32_t i = 0; i != n; ++i) {
518516
ss[i] = streams[i]->impl.get();
519517
}
@@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams(
522520

523521
const SherpaOnnxKeywordResult *GetKeywordResult(
524522
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
525-
const sherpa_onnx::KeywordResult& result =
523+
const sherpa_onnx::KeywordResult &result =
526524
spotter->impl->GetResult(stream->impl.get());
527525
const auto &keyword = result.keyword;
528526

sherpa-onnx/c-api/c-api.h

+17-18
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
333333
const char *encoder;
334334
const char *decoder;
335335
const char *language;
336+
const char *task;
336337
} SherpaOnnxOfflineWhisperModelConfig;
337338

338339
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
@@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
483484
/// For Chinese, it consists of Chinese words without spaces.
484485
/// Example 1: "hello world"
485486
/// Example 2: "你好世界"
486-
const char* keyword;
487+
const char *keyword;
487488

488489
/// Decoded results at the token level.
489490
/// For instance, for BPE-based models it consists of a list of BPE tokens.
490-
const char* tokens;
491+
const char *tokens;
491492

492-
const char* const* tokens_arr;
493+
const char *const *tokens_arr;
493494

494495
int32_t count;
495496

496497
/// timestamps.size() == tokens.size()
497498
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
498-
float* timestamps;
499+
float *timestamps;
499500

500501
/// Starting time of this segment.
501502
/// When an endpoint is detected, it will change
@@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
511512
* "start_time": x,
512513
* }
513514
*/
514-
const char* json;
515+
const char *json;
515516
} SherpaOnnxKeywordResult;
516517

517518
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
@@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
521522
int32_t num_trailing_blanks;
522523
float keywords_score;
523524
float keywords_threshold;
524-
const char* keywords_file;
525+
const char *keywords_file;
525526
} SherpaOnnxKeywordSpotterConfig;
526527

527528
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
@@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
530531
/// @param config Config for the keyword spotter.
531532
/// @return Return a pointer to the spotter. The user has to invoke
532533
/// DestroyKeywordSpotter() to free it to avoid memory leak.
533-
SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
534-
const SherpaOnnxKeywordSpotterConfig* config);
534+
SHERPA_ONNX_API SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
535+
const SherpaOnnxKeywordSpotterConfig *config);
535536

536537
/// Free a pointer returned by CreateKeywordSpotter()
537538
///
538539
/// @param p A pointer returned by CreateKeywordSpotter()
539-
SHERPA_ONNX_API void DestroyKeywordSpotter(
540-
SherpaOnnxKeywordSpotter* spotter);
540+
SHERPA_ONNX_API void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter);
541541

542542
/// Create an online stream for accepting wave samples.
543543
///
544544
/// @param spotter A pointer returned by CreateKeywordSpotter()
545545
/// @return Return a pointer to an OnlineStream. The user has to invoke
546546
/// DestroyOnlineStream() to free it to avoid memory leak.
547-
SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream(
548-
const SherpaOnnxKeywordSpotter* spotter);
547+
SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateKeywordStream(
548+
const SherpaOnnxKeywordSpotter *spotter);
549549

550550
/// Return 1 if there are enough number of feature frames for decoding.
551551
/// Return 0 otherwise.
552552
///
553553
/// @param spotter A pointer returned by CreateKeywordSpotter
554554
/// @param stream A pointer returned by CreateKeywordStream
555-
SHERPA_ONNX_API int32_t IsKeywordStreamReady(
556-
SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream);
555+
SHERPA_ONNX_API int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
556+
SherpaOnnxOnlineStream *stream);
557557

558558
/// Call this function to run the neural network model and decoding.
559559
//
560560
/// Precondition for this function: IsKeywordStreamReady() MUST return 1.
561-
SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
562-
SherpaOnnxOnlineStream* stream);
561+
SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
562+
SherpaOnnxOnlineStream *stream);
563563

564564
/// This function is similar to DecodeKeywordStream(). It decodes multiple
565565
/// OnlineStream in parallel.
@@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult(
588588
/// Destroy the pointer returned by GetKeywordResult().
589589
///
590590
/// @param r A pointer returned by GetKeywordResult()
591-
SHERPA_ONNX_API void DestroyKeywordResult(
592-
const SherpaOnnxKeywordResult *r);
591+
SHERPA_ONNX_API void DestroyKeywordResult(const SherpaOnnxKeywordResult *r);
593592

594593
// ============================================================
595594
// For VAD

sherpa-onnx/csrc/offline-tts-vits-model.cc

+2-1
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl {
223223
inputs.push_back(std::move(length_scale_tensor));
224224
inputs.push_back(std::move(noise_scale_w_tensor));
225225

226-
if (input_names_.size() == 6 && input_names_.back() == "sid") {
226+
if (input_names_.size() == 6 &&
227+
(input_names_.back() == "sid" || input_names_.back() == "speaker")) {
227228
inputs.push_back(std::move(sid_tensor));
228229
}
229230

sherpa-onnx/csrc/transducer-keyword-decoder.cc

+3-1
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,16 @@
22
//
33
// Copyright (c) 2023-2024 Xiaomi Corporation
44

5+
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
6+
57
#include <algorithm>
68
#include <cmath>
9+
#include <cstring>
710
#include <utility>
811
#include <vector>
912

1013
#include "sherpa-onnx/csrc/log.h"
1114
#include "sherpa-onnx/csrc/onnx-utils.h"
12-
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
1315

1416
namespace sherpa_onnx {
1517

0 commit comments

Comments
 (0)