Skip to content

Commit 5c2cc48

Browse files
authored
Add swift online punctuation (#1661)
1 parent 49154c9 commit 5c2cc48

File tree

5 files changed

+198
-0
lines changed

5 files changed

+198
-0
lines changed

sherpa-onnx/c-api/c-api.cc

+48
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "sherpa-onnx/csrc/macros.h"
2525
#include "sherpa-onnx/csrc/offline-punctuation.h"
2626
#include "sherpa-onnx/csrc/offline-recognizer.h"
27+
#include "sherpa-onnx/csrc/online-punctuation.h"
2728
#include "sherpa-onnx/csrc/online-recognizer.h"
2829
#include "sherpa-onnx/csrc/resample.h"
2930
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct(
17171718

17181719
void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }
17191720

1721+
struct SherpaOnnxOnlinePunctuation {
1722+
std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
1723+
};
1724+
1725+
// Creates an online punctuation processor from a C-level config.
//
// Each config field is passed through SHERPA_ONNX_OR together with the
// fallback shown below (presumably used when the field is null/zero -- see the
// macro definition). `debug` is copied as-is.
//
// Returns nullptr when `config` is null or when construction throws a
// std::exception (the error is logged). On success the caller owns the
// returned pointer and must release it with
// SherpaOnnxDestroyOnlinePunctuation().
const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
    const SherpaOnnxOnlinePunctuationConfig *config) {
  // Guard against a null config instead of dereferencing it below.
  if (!config) {
    SHERPA_ONNX_LOGE("Failed to create online punctuation: %s",
                     "config is null");
    return nullptr;
  }

  try {
    sherpa_onnx::OnlinePunctuationConfig punctuation_config;
    punctuation_config.model.cnn_bilstm =
        SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
    punctuation_config.model.bpe_vocab =
        SHERPA_ONNX_OR(config->model.bpe_vocab, "");
    punctuation_config.model.num_threads =
        SHERPA_ONNX_OR(config->model.num_threads, 1);
    punctuation_config.model.debug = config->model.debug;
    punctuation_config.model.provider =
        SHERPA_ONNX_OR(config->model.provider, "cpu");

    // Allocate inside the try block and hold the handle in a unique_ptr so
    // that neither an allocation failure nor a constructor exception can leak
    // memory or escape across the C boundary.
    auto p = std::make_unique<SherpaOnnxOnlinePunctuation>();
    p->impl =
        std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
    return p.release();
  } catch (const std::exception &e) {
    SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
    return nullptr;
  }
}
1745+
1746+
// Releases a processor obtained from SherpaOnnxCreateOnlinePunctuation().
// Passing a null pointer is harmless because `delete` on null is a no-op.
void SherpaOnnxDestroyOnlinePunctuation(
    const SherpaOnnxOnlinePunctuation *p) {
  delete p;
}
1749+
1750+
const char *SherpaOnnxOnlinePunctuationAddPunct(
1751+
const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
1752+
if (!punctuation || !text) return nullptr;
1753+
1754+
try {
1755+
std::string s = punctuation->impl->AddPunctuationWithCase(text);
1756+
char *p = new char[s.size() + 1];
1757+
std::copy(s.begin(), s.end(), p);
1758+
p[s.size()] = '\0';
1759+
return p;
1760+
} catch (const std::exception &e) {
1761+
SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
1762+
return nullptr;
1763+
}
1764+
}
1765+
1766+
// Frees a string returned by SherpaOnnxOnlinePunctuationAddPunct(). That
// buffer is allocated with new[], hence delete[] here; null is accepted.
void SherpaOnnxOnlinePunctuationFreeText(const char *text) {
  delete[] text;
}
1767+
17201768
struct SherpaOnnxLinearResampler {
17211769
std::unique_ptr<sherpa_onnx::LinearResample> impl;
17221770
};

sherpa-onnx/c-api/c-api.h

+33
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,39 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(
13691369

13701370
SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);
13711371

1372+
// Model settings for online punctuation.
// The fallbacks mentioned below are the ones the C implementation passes to
// SHERPA_ONNX_OR; presumably they apply when a field is NULL/0 -- confirm
// against the macro definition.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationModelConfig {
  // Path to the model file (the Swift example passes ".../model.onnx").
  const char *cnn_bilstm;
  // Path to the BPE vocabulary (the Swift example passes ".../bpe.vocab").
  const char *bpe_vocab;
  // Number of threads; fallback 1.
  int32_t num_threads;
  // Debug flag, copied into the C++ model config unchanged.
  int32_t debug;
  // Execution provider name; fallback "cpu".
  const char *provider;
} SherpaOnnxOnlinePunctuationModelConfig;
1379+
1380+
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
1381+
SherpaOnnxOnlinePunctuationModelConfig model;
1382+
} SherpaOnnxOnlinePunctuationConfig;
1383+
1384+
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
1385+
1386+
// Create an online punctuation processor. The user has to invoke
1387+
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
1388+
// to avoid a memory leak.
1389+
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
1390+
const SherpaOnnxOnlinePunctuationConfig *config);
1391+
1392+
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
1393+
SHERPA_ONNX_API void SherpaOnnxDestroyOnlinePunctuation(
1394+
const SherpaOnnxOnlinePunctuation *punctuation);
1395+
1396+
// Add punctuation to the input text. The user has to invoke
1397+
// SherpaOnnxOnlinePunctuationFreeText() to free the returned pointer
1398+
// to avoid a memory leak.
1399+
SHERPA_ONNX_API const char *SherpaOnnxOnlinePunctuationAddPunct(
1400+
const SherpaOnnxOnlinePunctuation *punctuation, const char *text);
1401+
1402+
// Free a pointer returned by SherpaOnnxOnlinePunctuationAddPunct()
1403+
SHERPA_ONNX_API void SherpaOnnxOnlinePunctuationFreeText(const char *text);
1404+
13721405
// for resampling
13731406
SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler
13741407
SherpaOnnxLinearResampler;

swift-api-examples/SherpaOnnx.swift

+46
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,52 @@ class SherpaOnnxOfflinePunctuationWrapper {
10951095
}
10961096
}
10971097

1098+
/// Builds the C `SherpaOnnxOnlinePunctuationModelConfig` from Swift values.
///
/// - Parameters:
///   - cnnBiLstm: Value for the C field `cnn_bilstm` (the example passes the
///     path of `model.onnx`).
///   - bpeVocab: Value for the C field `bpe_vocab` (the example passes the
///     path of `bpe.vocab`).
///   - numThreads: Value for `num_threads`; converted to `Int32`.
///   - debug: Value for `debug`; converted to `Int32`.
///   - provider: Value for `provider`.
/// - Returns: The populated C struct. String fields are produced by
///   `toCPointer`.
func sherpaOnnxOnlinePunctuationModelConfig(
  cnnBiLstm: String,
  bpeVocab: String,
  numThreads: Int = 1,
  debug: Int = 0,
  provider: String = "cpu"
) -> SherpaOnnxOnlinePunctuationModelConfig {
  let modelPath = toCPointer(cnnBiLstm)
  let vocabPath = toCPointer(bpeVocab)
  let threadCount = Int32(numThreads)
  let debugFlag = Int32(debug)
  let providerName = toCPointer(provider)

  return SherpaOnnxOnlinePunctuationModelConfig(
    cnn_bilstm: modelPath,
    bpe_vocab: vocabPath,
    num_threads: threadCount,
    debug: debugFlag,
    provider: providerName
  )
}
1112+
1113+
/// Wraps a model config in the top-level C `SherpaOnnxOnlinePunctuationConfig`.
///
/// - Parameter model: The model settings to embed.
/// - Returns: A config whose `model` field is `model`.
func sherpaOnnxOnlinePunctuationConfig(
  model: SherpaOnnxOnlinePunctuationModelConfig
) -> SherpaOnnxOnlinePunctuationConfig {
  let config = SherpaOnnxOnlinePunctuationConfig(model: model)
  return config
}
1118+
1119+
/// Swift wrapper around the C online punctuation processor.
///
/// The instance holds the handle returned by
/// `SherpaOnnxCreateOnlinePunctuation` and destroys it in `deinit`.
class SherpaOnnxOnlinePunctuationWrapper {
  /// A pointer to the underlying counterpart in C.
  /// It is `nil` if the C side failed to create the processor.
  let ptr: OpaquePointer!

  /// Constructor taking a model config
  init(
    config: UnsafePointer<SherpaOnnxOnlinePunctuationConfig>!
  ) {
    ptr = SherpaOnnxCreateOnlinePunctuation(config)
  }

  deinit {
    if let ptr {
      SherpaOnnxDestroyOnlinePunctuation(ptr)
    }
  }

  /// Adds punctuation to `text`.
  ///
  /// - Parameter text: The input text.
  /// - Returns: The text returned by the C API, or an empty string when the
  ///   C call returns a null pointer (it does so for a null handle or when
  ///   processing fails).
  func addPunct(text: String) -> String {
    // The C function can return NULL, so bind the pointer instead of
    // force-unwrapping it.
    guard let cText = SherpaOnnxOnlinePunctuationAddPunct(ptr, toCPointer(text)) else {
      return ""
    }
    // Release the C-allocated buffer on every exit path once it is copied.
    defer { SherpaOnnxOnlinePunctuationFreeText(cText) }
    return String(cString: cText)
  }
}
1143+
10981144
func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String)
10991145
-> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
11001146
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/// Runs the online punctuation example: builds the config from the model files
/// under `./sherpa-onnx-online-punct-en-2024-08-06`, creates the processor and
/// prints the punctuated form of each sample sentence.
func run() {
  // Model and vocabulary paths are relative to the working directory.
  let modelConfig = sherpaOnnxOnlinePunctuationModelConfig(
    cnnBiLstm: "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx",
    bpeVocab: "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
  )

  // `var` because the wrapper takes the config by pointer (`&config`).
  var config = sherpaOnnxOnlinePunctuationConfig(model: modelConfig)
  let punct = SherpaOnnxOnlinePunctuationWrapper(config: &config)

  let samples = [
    "how are you i am fine thank you",
    "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
  ]

  for sample in samples {
    let punctuated = punct.addPunct(text: sample)
    print("\nresult is:\n\(punctuated)")
  }
}
29+
30+
/// Program entry point; it only delegates to `run()`.
@main
struct App {
  static func main() { run() }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env bash
#
# Builds (when the binary is missing) and runs the Swift online punctuation
# example. Expects to be run from the directory containing this script, with
# ../build-swift-macos already produced by ../build-swift-macos.sh.

set -ex

if [ ! -d ../build-swift-macos ]; then
  echo "Please run ../build-swift-macos.sh first!"
  exit 1
fi

# Download and extract the online punctuation model if it does not exist
if [ ! -d ./sherpa-onnx-online-punct-en-2024-08-06 ]; then
  # -f makes curl exit non-zero on an HTTP error instead of saving the error
  # page as the archive, so `set -e` stops here rather than later in tar.
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
fi

if [ ! -e ./add-punctuation-online ]; then
  # Note: We use -lc++ to link against libc++ instead of libstdc++
  swiftc \
    -lc++ \
    -I ../build-swift-macos/install/include \
    -import-objc-header ./SherpaOnnx-Bridging-Header.h \
    ./add-punctuation-online.swift ./SherpaOnnx.swift \
    -L ../build-swift-macos/install/lib/ \
    -l sherpa-onnx \
    -l onnxruntime \
    -o ./add-punctuation-online

  strip ./add-punctuation-online
else
  echo "./add-punctuation-online exists - skip building"
fi

# Set library path and run the executable.
# ${VAR:+:$VAR} appends the previous value only when it is non-empty, which
# avoids a dangling ':' (an empty search-path entry) when it was unset.
export DYLD_LIBRARY_PATH="$PWD/../build-swift-macos/install/lib${DYLD_LIBRARY_PATH:+:$DYLD_LIBRARY_PATH}"
./add-punctuation-online

0 commit comments

Comments
 (0)