Skip to content

Commit a7d8e01

Browse files
authored
Wrap offline ASR APIs to dart (k2-fsa#961)
1 parent 6c08789 commit a7d8e01

8 files changed

+550
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2024 Xiaomi Corporation
2+
import 'package:path/path.dart';
3+
import 'package:path_provider/path_provider.dart';
4+
import 'package:flutter/services.dart' show rootBundle;
5+
import 'dart:typed_data';
6+
import "dart:io";
7+
8+
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
9+
import './utils.dart';
10+
11+
/// Runs one non-streaming Paraformer recognition on a bundled test wave and
/// prints the recognizer configuration, the result and the native handle.
///
/// The model, tokens and wave assets are first passed through
/// `copyAssetFile` (from `utils.dart`), and the paths it returns are the ones
/// handed to sherpa-onnx.
///
/// The native stream and recognizer are released in `finally` blocks so they
/// are not leaked when reading the wave, decoding or parsing the result
/// throws.
Future<void> testNonStreamingParaformerAsr() async {
  var model = 'assets/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx';
  var tokens = 'assets/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt';
  var testWave = 'assets/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav';

  model = await copyAssetFile(src: model, dst: 'model.int8.onnx');
  tokens = await copyAssetFile(src: tokens, dst: 'tokens.txt');
  testWave = await copyAssetFile(src: testWave, dst: '0.wav');

  final paraformer = sherpa_onnx.OfflineParaformerModelConfig(
    model: model,
  );

  final modelConfig = sherpa_onnx.OfflineModelConfig(
    paraformer: paraformer,
    tokens: tokens,
    modelType: 'paraformer',
  );

  final config = sherpa_onnx.OfflineRecognizerConfig(model: modelConfig);
  print(config);

  final recognizer = sherpa_onnx.OfflineRecognizer(config);
  try {
    final waveData = sherpa_onnx.readWave(testWave);
    final stream = recognizer.createStream();
    try {
      stream.acceptWaveform(
          samples: waveData.samples, sampleRate: waveData.sampleRate);
      recognizer.decode(stream);

      final result = recognizer.getResult(stream);
      print('result is: $result');

      print('recognizer: ${recognizer.ptr}');
    } finally {
      // Release the stream before the recognizer that created it.
      stream.free();
    }
  } finally {
    recognizer.free();
  }
}

sherpa-onnx/flutter/example/pubspec.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ flutter:
7474
assets:
7575
- assets/
7676
- assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/
77+
# - assets/sherpa-onnx-paraformer-zh-2023-03-28/
78+
# - assets/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/
7779
# - assets/sr-data/enroll/
7880
# - assets/sr-data/test/
7981
# - images/a_dot_ham.jpeg

sherpa-onnx/flutter/lib/sherpa_onnx.dart

+4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
import 'dart:io';
33
import 'dart:ffi';
44

5+
export 'src/feature_config.dart';
6+
export 'src/offline_recognizer.dart';
7+
export 'src/offline_stream.dart';
58
export 'src/online_recognizer.dart';
69
export 'src/online_stream.dart';
710
export 'src/speaker_identification.dart';
811
export 'src/vad.dart';
912
export 'src/wave_reader.dart';
1013
export 'src/wave_writer.dart';
14+
1115
import 'src/sherpa_onnx_bindings.dart';
1216

1317
final DynamicLibrary _dylib = () {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (c) 2024 Xiaomi Corporation
2+
3+
/// Immutable feature settings: a sample rate and a feature dimension.
class FeatureConfig {
  /// Sample rate; defaults to 16000.
  final int sampleRate;

  /// Feature dimension; defaults to 80.
  final int featureDim;

  /// Creates a configuration, falling back to 16000 / 80 when omitted.
  const FeatureConfig({
    this.sampleRate = 16000,
    this.featureDim = 80,
  });

  /// Returns a readable description listing both fields.
  @override
  String toString() =>
      'FeatureConfig(sampleRate: $sampleRate, featureDim: $featureDim)';
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
// Copyright (c) 2024 Xiaomi Corporation
2+
import 'dart:convert';
3+
import 'dart:ffi';
4+
import 'dart:typed_data';
5+
6+
import 'package:ffi/ffi.dart';
7+
8+
import './feature_config.dart';
9+
import './offline_stream.dart';
10+
import './sherpa_onnx_bindings.dart';
11+
12+
/// Immutable settings naming the three parts of an offline transducer model.
///
/// Every field defaults to the empty string.
/// NOTE(review): the values are presumably model file paths - confirm.
class OfflineTransducerModelConfig {
  /// Encoder model.
  final String encoder;

  /// Decoder model.
  final String decoder;

  /// Joiner model.
  final String joiner;

  /// Creates a configuration; omitted fields are empty strings.
  const OfflineTransducerModelConfig({
    this.encoder = '',
    this.decoder = '',
    this.joiner = '',
  });

  /// Returns a readable description listing all three fields.
  @override
  String toString() =>
      'OfflineTransducerModelConfig(encoder: $encoder, decoder: $decoder, joiner: $joiner)';
}
28+
29+
/// Immutable settings for an offline Paraformer model.
class OfflineParaformerModelConfig {
  /// Model value; empty by default.
  /// NOTE(review): presumably a path to the model file - confirm.
  final String model;

  /// Creates a configuration; [model] is the empty string when omitted.
  const OfflineParaformerModelConfig({this.model = ''});

  /// Returns a readable description containing [model].
  @override
  String toString() => 'OfflineParaformerModelConfig(model: $model)';
}
39+
40+
/// Immutable settings for an offline NeMo EncDec CTC model.
class OfflineNemoEncDecCtcModelConfig {
  /// Model value; empty by default.
  /// NOTE(review): presumably a path to the model file - confirm.
  final String model;

  /// Creates a configuration; [model] is the empty string when omitted.
  const OfflineNemoEncDecCtcModelConfig({this.model = ''});

  /// Returns a readable description containing [model].
  @override
  String toString() => 'OfflineNemoEncDecCtcModelConfig(model: $model)';
}
50+
51+
/// Immutable settings for an offline Whisper model.
///
/// String fields default to the empty string and [tailPaddings] to -1.
class OfflineWhisperModelConfig {
  /// Encoder model.
  final String encoder;

  /// Decoder model.
  final String decoder;

  /// Language setting.
  final String language;

  /// Task setting.
  final String task;

  /// Tail paddings; defaults to -1.
  final int tailPaddings;

  /// Creates a configuration with the defaults described on the class.
  const OfflineWhisperModelConfig({
    this.encoder = '',
    this.decoder = '',
    this.language = '',
    this.task = '',
    this.tailPaddings = -1,
  });

  /// Returns a readable description listing every field.
  @override
  String toString() =>
      'OfflineWhisperModelConfig(encoder: $encoder, decoder: $decoder, language: $language, task: $task, tailPaddings: $tailPaddings)';
}
70+
71+
/// Immutable settings for an offline TDNN model.
class OfflineTdnnModelConfig {
  /// Model value; empty by default.
  /// NOTE(review): presumably a path to the model file - confirm.
  final String model;

  /// Creates a configuration; [model] is the empty string when omitted.
  const OfflineTdnnModelConfig({this.model = ''});

  /// Returns a readable description containing [model].
  @override
  String toString() => 'OfflineTdnnModelConfig(model: $model)';
}
81+
82+
/// Immutable language-model settings: a model and a scale.
class OfflineLMConfig {
  /// Model value; empty by default.
  final String model;

  /// Scale; defaults to 1.0.
  final double scale;

  /// Creates a configuration, falling back to `''` / 1.0 when omitted.
  const OfflineLMConfig({
    this.model = '',
    this.scale = 1.0,
  });

  /// Returns a readable description listing both fields.
  @override
  String toString() => 'OfflineLMConfig(model: $model, scale: $scale)';
}
93+
94+
/// Immutable aggregate of the per-architecture model settings plus the
/// options shared by all of them.
///
/// Only [tokens] is required. Each per-architecture entry defaults to that
/// class's own default-constructed value.
class OfflineModelConfig {
  /// Transducer settings.
  final OfflineTransducerModelConfig transducer;

  /// Paraformer settings.
  final OfflineParaformerModelConfig paraformer;

  /// NeMo EncDec CTC settings.
  final OfflineNemoEncDecCtcModelConfig nemoCtc;

  /// Whisper settings.
  final OfflineWhisperModelConfig whisper;

  /// TDNN settings.
  final OfflineTdnnModelConfig tdnn;

  /// Tokens value (required).
  /// NOTE(review): presumably a path to a tokens file - confirm.
  final String tokens;

  /// Number of threads; defaults to 1.
  final int numThreads;

  /// Debug flag; defaults to `true`.
  final bool debug;

  /// Provider name; defaults to `'cpu'`.
  final String provider;

  /// Model type; empty by default.
  final String modelType;

  /// Creates a configuration with the defaults described on each field.
  const OfflineModelConfig({
    this.transducer = const OfflineTransducerModelConfig(),
    this.paraformer = const OfflineParaformerModelConfig(),
    this.nemoCtc = const OfflineNemoEncDecCtcModelConfig(),
    this.whisper = const OfflineWhisperModelConfig(),
    this.tdnn = const OfflineTdnnModelConfig(),
    required this.tokens,
    this.numThreads = 1,
    this.debug = true,
    this.provider = 'cpu',
    this.modelType = '',
  });

  /// Returns a readable description listing every field.
  @override
  String toString() =>
      'OfflineModelConfig(transducer: $transducer, paraformer: $paraformer, nemoCtc: $nemoCtc, whisper: $whisper, tdnn: $tdnn, tokens: $tokens, numThreads: $numThreads, debug: $debug, provider: $provider, modelType: $modelType)';
}
125+
126+
/// Immutable top-level settings used to construct an offline recognizer.
///
/// Only [model] is required; every other field has a default.
class OfflineRecognizerConfig {
  /// Feature settings; defaults to `const FeatureConfig()`.
  final FeatureConfig feat;

  /// Model settings (required).
  final OfflineModelConfig model;

  /// Language-model settings; defaults to `const OfflineLMConfig()`.
  final OfflineLMConfig lm;

  /// Decoding method; defaults to `'greedy_search'`.
  final String decodingMethod;

  /// Maximum number of active paths; defaults to 4.
  final int maxActivePaths;

  /// Hotwords file; empty by default.
  final String hotwordsFile;

  /// Hotwords score; defaults to 1.5.
  final double hotwordsScore;

  /// Creates a configuration with the defaults described on each field.
  const OfflineRecognizerConfig({
    this.feat = const FeatureConfig(),
    required this.model,
    this.lm = const OfflineLMConfig(),
    this.decodingMethod = 'greedy_search',
    this.maxActivePaths = 4,
    this.hotwordsFile = '',
    this.hotwordsScore = 1.5,
  });

  /// Returns a readable description listing every field.
  @override
  String toString() =>
      'OfflineRecognizerConfig(feat: $feat, model: $model, lm: $lm, decodingMethod: $decodingMethod, maxActivePaths: $maxActivePaths, hotwordsFile: $hotwordsFile, hotwordsScore: $hotwordsScore)';
}
153+
154+
/// Recognition output: the text together with its tokens and timestamps.
class OfflineRecognizerResult {
  /// Recognized text.
  final String text;

  /// Tokens of the result.
  final List<String> tokens;

  /// Timestamps of the result.
  final List<double> timestamps;

  /// Creates a result; all three fields are required.
  OfflineRecognizerResult({
    required this.text,
    required this.tokens,
    required this.timestamps,
  });

  /// Returns a readable description listing every field.
  @override
  String toString() =>
      'OfflineRecognizerResult(text: $text, tokens: $tokens, timestamps: $timestamps)';
}
167+
168+
/// Non-streaming (offline) speech recognizer wrapping a native sherpa-onnx
/// recognizer handle.
///
/// Instances own the native handle stored in [ptr]; call [free] when done.
class OfflineRecognizer {
  OfflineRecognizer._({required this.ptr, required this.config});

  /// Destroys the native recognizer and resets [ptr] to [nullptr].
  ///
  /// Calling it again after the handle has been released is a no-op, so the
  /// native destroy function is never invoked twice for the same handle.
  void free() {
    if (ptr == nullptr) {
      return;
    }

    SherpaOnnxBindings.destroyOfflineRecognizer?.call(ptr);
    ptr = nullptr;
  }

  /// Creates a recognizer from [config].
  ///
  /// [config] is copied field by field into a temporary native struct, which
  /// is passed to the native constructor and then released together with
  /// every native string allocated for it. If the binding is unavailable,
  /// [ptr] of the returned instance is [nullptr].
  ///
  /// The user is responsible to call the OfflineRecognizer.free()
  /// method of the returned instance to avoid memory leak.
  factory OfflineRecognizer(OfflineRecognizerConfig config) {
    final c = calloc<SherpaOnnxOfflineRecognizerConfig>();

    c.ref.feat.sampleRate = config.feat.sampleRate;
    c.ref.feat.featureDim = config.feat.featureDim;

    // transducer
    c.ref.model.transducer.encoder =
        config.model.transducer.encoder.toNativeUtf8();
    c.ref.model.transducer.decoder =
        config.model.transducer.decoder.toNativeUtf8();
    c.ref.model.transducer.joiner =
        config.model.transducer.joiner.toNativeUtf8();

    // paraformer
    c.ref.model.paraformer.model = config.model.paraformer.model.toNativeUtf8();

    // nemoCtc
    c.ref.model.nemoCtc.model = config.model.nemoCtc.model.toNativeUtf8();

    // whisper
    c.ref.model.whisper.encoder = config.model.whisper.encoder.toNativeUtf8();

    c.ref.model.whisper.decoder = config.model.whisper.decoder.toNativeUtf8();

    c.ref.model.whisper.language = config.model.whisper.language.toNativeUtf8();

    c.ref.model.whisper.task = config.model.whisper.task.toNativeUtf8();

    c.ref.model.whisper.tailPaddings = config.model.whisper.tailPaddings;

    // tdnn
    c.ref.model.tdnn.model = config.model.tdnn.model.toNativeUtf8();

    c.ref.model.tokens = config.model.tokens.toNativeUtf8();

    c.ref.model.numThreads = config.model.numThreads;
    c.ref.model.debug = config.model.debug ? 1 : 0;
    c.ref.model.provider = config.model.provider.toNativeUtf8();
    c.ref.model.modelType = config.model.modelType.toNativeUtf8();

    c.ref.lm.model = config.lm.model.toNativeUtf8();
    c.ref.lm.scale = config.lm.scale;

    c.ref.decodingMethod = config.decodingMethod.toNativeUtf8();
    c.ref.maxActivePaths = config.maxActivePaths;

    c.ref.hotwordsFile = config.hotwordsFile.toNativeUtf8();
    c.ref.hotwordsScore = config.hotwordsScore;

    final ptr = SherpaOnnxBindings.createOfflineRecognizer?.call(c) ?? nullptr;

    // Release every native string first, then the struct that holds them.
    // NOTE(review): this assumes the native side does not keep references to
    // these buffers after the create call returns - confirm against the C API.
    calloc.free(c.ref.hotwordsFile);
    calloc.free(c.ref.decodingMethod);
    calloc.free(c.ref.lm.model);
    calloc.free(c.ref.model.modelType);
    calloc.free(c.ref.model.provider);
    calloc.free(c.ref.model.tokens);
    calloc.free(c.ref.model.tdnn.model);
    calloc.free(c.ref.model.whisper.task);
    calloc.free(c.ref.model.whisper.language);
    calloc.free(c.ref.model.whisper.decoder);
    calloc.free(c.ref.model.whisper.encoder);
    calloc.free(c.ref.model.nemoCtc.model);
    calloc.free(c.ref.model.paraformer.model);
    calloc.free(c.ref.model.transducer.encoder);
    calloc.free(c.ref.model.transducer.decoder);
    calloc.free(c.ref.model.transducer.joiner);
    calloc.free(c);

    return OfflineRecognizer._(ptr: ptr, config: config);
  }

  /// Creates a new native stream bound to this recognizer.
  ///
  /// The user has to invoke stream.free() on the returned instance
  /// to avoid memory leak
  OfflineStream createStream() {
    final p = SherpaOnnxBindings.createOfflineStream?.call(ptr) ?? nullptr;
    return OfflineStream(ptr: p);
  }

  /// Runs the native decoder on [stream].
  void decode(OfflineStream stream) {
    SherpaOnnxBindings.decodeOfflineStream?.call(ptr, stream.ptr);
  }

  /// Returns the recognition result of [stream].
  ///
  /// The native side reports the result as a JSON string with `text`,
  /// `tokens` and `timestamps` entries. An empty result is returned when the
  /// binding is unavailable or the native call yields [nullptr].
  OfflineRecognizerResult getResult(OfflineStream stream) {
    final json =
        SherpaOnnxBindings.getOfflineStreamResultAsJson?.call(stream.ptr) ??
            nullptr;
    // `json` is a non-nullable pointer, so "no result" is `nullptr`; comparing
    // against Dart `null` could never be true.
    if (json == nullptr) {
      return OfflineRecognizerResult(text: '', tokens: [], timestamps: []);
    }

    try {
      final parsedJson = jsonDecode(json.toDartString());

      return OfflineRecognizerResult(
          text: parsedJson['text'],
          tokens: List<String>.from(parsedJson['tokens']),
          // JSON numbers without a fractional part decode as `int`, which
          // cannot be cast to `double`; convert through `num` instead.
          timestamps: [
            for (final t in parsedJson['timestamps'] as List)
              (t as num).toDouble()
          ]);
    } finally {
      // Always hand the native JSON buffer back, even if parsing throws.
      SherpaOnnxBindings.destroyOfflineStreamResultJson?.call(json);
    }
  }

  /// Native recognizer handle; [nullptr] after [free] or if creation failed.
  Pointer<SherpaOnnxOfflineRecognizer> ptr;

  /// The configuration this recognizer was created from.
  OfflineRecognizerConfig config;
}

0 commit comments

Comments
 (0)