|
| 1 | +// scripts/node-addon-api/src/punctuation.cc |
| 2 | +// |
| 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | +#include <sstream> |
| 5 | + |
| 6 | +#include "macros.h" // NOLINT |
| 7 | +#include "napi.h" // NOLINT |
| 8 | +#include "sherpa-onnx/c-api/c-api.h" |
| 9 | + |
| 10 | +static SherpaOnnxOfflinePunctuationModelConfig GetOfflinePunctuationModelConfig( |
| 11 | + Napi::Object obj) { |
| 12 | + SherpaOnnxOfflinePunctuationModelConfig c; |
| 13 | + memset(&c, 0, sizeof(c)); |
| 14 | + |
| 15 | + if (!obj.Has("model") || !obj.Get("model").IsObject()) { |
| 16 | + return c; |
| 17 | + } |
| 18 | + |
| 19 | + Napi::Object o = obj.Get("model").As<Napi::Object>(); |
| 20 | + |
| 21 | + SHERPA_ONNX_ASSIGN_ATTR_STR(ct_transformer, ctTransformer); |
| 22 | + |
| 23 | + SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads); |
| 24 | + |
| 25 | + if (o.Has("debug") && |
| 26 | + (o.Get("debug").IsNumber() || o.Get("debug").IsBoolean())) { |
| 27 | + if (o.Get("debug").IsBoolean()) { |
| 28 | + c.debug = o.Get("debug").As<Napi::Boolean>().Value(); |
| 29 | + } else { |
| 30 | + c.debug = o.Get("debug").As<Napi::Number>().Int32Value(); |
| 31 | + } |
| 32 | + } |
| 33 | + SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider); |
| 34 | + |
| 35 | + return c; |
| 36 | +} |
| 37 | + |
| 38 | +static Napi::External<SherpaOnnxOfflinePunctuation> |
| 39 | +CreateOfflinePunctuationWrapper(const Napi::CallbackInfo &info) { |
| 40 | + Napi::Env env = info.Env(); |
| 41 | + if (info.Length() != 1) { |
| 42 | + std::ostringstream os; |
| 43 | + os << "Expect only 1 argument. Given: " << info.Length(); |
| 44 | + |
| 45 | + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); |
| 46 | + |
| 47 | + return {}; |
| 48 | + } |
| 49 | + |
| 50 | + if (!info[0].IsObject()) { |
| 51 | + Napi::TypeError::New(env, "You should pass an object as the only argument.") |
| 52 | + .ThrowAsJavaScriptException(); |
| 53 | + |
| 54 | + return {}; |
| 55 | + } |
| 56 | + |
| 57 | + Napi::Object o = info[0].As<Napi::Object>(); |
| 58 | + |
| 59 | + SherpaOnnxOfflinePunctuationConfig c; |
| 60 | + memset(&c, 0, sizeof(c)); |
| 61 | + c.model = GetOfflinePunctuationModelConfig(o); |
| 62 | + |
| 63 | + const SherpaOnnxOfflinePunctuation *punct = |
| 64 | + SherpaOnnxCreateOfflinePunctuation(&c); |
| 65 | + |
| 66 | + if (c.model.ct_transformer) { |
| 67 | + delete[] c.model.ct_transformer; |
| 68 | + } |
| 69 | + |
| 70 | + if (c.model.provider) { |
| 71 | + delete[] c.model.provider; |
| 72 | + } |
| 73 | + |
| 74 | + if (!punct) { |
| 75 | + Napi::TypeError::New(env, "Please check your config!") |
| 76 | + .ThrowAsJavaScriptException(); |
| 77 | + |
| 78 | + return {}; |
| 79 | + } |
| 80 | + |
| 81 | + return Napi::External<SherpaOnnxOfflinePunctuation>::New( |
| 82 | + env, const_cast<SherpaOnnxOfflinePunctuation *>(punct), |
| 83 | + [](Napi::Env env, SherpaOnnxOfflinePunctuation *punct) { |
| 84 | + SherpaOnnxDestroyOfflinePunctuation(punct); |
| 85 | + }); |
| 86 | +} |
| 87 | + |
| 88 | +static Napi::String OfflinePunctuationAddPunctWraper( |
| 89 | + const Napi::CallbackInfo &info) { |
| 90 | + Napi::Env env = info.Env(); |
| 91 | + if (info.Length() != 2) { |
| 92 | + std::ostringstream os; |
| 93 | + os << "Expect only 2 arguments. Given: " << info.Length(); |
| 94 | + |
| 95 | + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); |
| 96 | + |
| 97 | + return {}; |
| 98 | + } |
| 99 | + |
| 100 | + if (!info[0].IsExternal()) { |
| 101 | + Napi::TypeError::New( |
| 102 | + env, |
| 103 | + "You should pass an offline punctuation pointer as the first argument") |
| 104 | + .ThrowAsJavaScriptException(); |
| 105 | + |
| 106 | + return {}; |
| 107 | + } |
| 108 | + |
| 109 | + if (!info[1].IsString()) { |
| 110 | + Napi::TypeError::New(env, "You should pass a string as the second argument") |
| 111 | + .ThrowAsJavaScriptException(); |
| 112 | + |
| 113 | + return {}; |
| 114 | + } |
| 115 | + |
| 116 | + SherpaOnnxOfflinePunctuation *punct = |
| 117 | + info[0].As<Napi::External<SherpaOnnxOfflinePunctuation>>().Data(); |
| 118 | + Napi::String js_text = info[1].As<Napi::String>(); |
| 119 | + std::string text = js_text.Utf8Value(); |
| 120 | + |
| 121 | + const char *punct_text = |
| 122 | + SherpaOfflinePunctuationAddPunct(punct, text.c_str()); |
| 123 | + |
| 124 | + Napi::String ans = Napi::String::New(env, punct_text); |
| 125 | + SherpaOfflinePunctuationFreeText(punct_text); |
| 126 | + return ans; |
| 127 | +} |
| 128 | + |
| 129 | +void InitPunctuation(Napi::Env env, Napi::Object exports) { |
| 130 | + exports.Set(Napi::String::New(env, "createOfflinePunctuation"), |
| 131 | + Napi::Function::New(env, CreateOfflinePunctuationWrapper)); |
| 132 | + |
| 133 | + exports.Set(Napi::String::New(env, "offlinePunctuationAddPunct"), |
| 134 | + Napi::Function::New(env, OfflinePunctuationAddPunctWraper)); |
| 135 | +} |
0 commit comments