Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ios): add option to enable/disable Core ML #145

Merged
merged 4 commits
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,9 @@ struct whisper_context {
whisper_state * state = nullptr;

std::string path_model; // populated by whisper_init_from_file()
#ifdef WHISPER_USE_COREML
bool load_coreml = true;
#endif
};

static void whisper_default_log(const char * text) {
Expand Down Expand Up @@ -2854,6 +2857,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
}

#ifdef WHISPER_USE_COREML
if (ctx->load_coreml) { // Not in correct layer for easy patch
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
Expand All @@ -2869,6 +2873,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
} else {
log("%s: Core ML model loaded\n", __func__);
}
}
#endif

state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
Expand Down Expand Up @@ -2989,6 +2994,23 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return state;
}

#ifdef WHISPER_USE_COREML
// Load a whisper model from file without loading the companion Core ML
// encoder, even though the library was built with WHISPER_USE_COREML.
// Returns NULL on failure, mirroring whisper_init_from_file().
struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
// Create the context without its state first, so the load_coreml flag
// can be cleared before whisper_init_state() consults it.
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
if (!ctx) {
return nullptr;
}
ctx->load_coreml = false; // makes whisper_init_state() skip the Core ML encoder
ctx->state = whisper_init_state(ctx);
if (!ctx->state) {
// State initialization failed: release the partially built context.
whisper_free(ctx);
return nullptr;
}

return ctx;
}
#endif

int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
const char * model_path,
Expand Down
3 changes: 3 additions & 0 deletions cpp/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ extern "C" {
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
#ifdef WHISPER_USE_COREML
WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
#endif
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
Expand Down
5 changes: 3 additions & 2 deletions example/ios/Podfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ ENV['RCT_NEW_ARCH_ENABLED'] = '1'

target 'RNWhisperExample' do
# Tip: You can use RNWHISPER_DISABLE_COREML = '1' to disable CoreML support.
ENV['RNWHISPER_DISABLE_COREML'] = '1' # TEMP
ENV['RNWHISPER_DISABLE_METAL'] = '0' # TEMP
ENV['RNWHISPER_DISABLE_COREML'] = '0'

ENV['RNWHISPER_ENABLE_METAL'] = '0' # TODO

config = use_native_modules!

Expand Down
6 changes: 3 additions & 3 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ PODS:
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.9):
- whisper-rn (0.4.0-rc.0):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -1006,10 +1006,10 @@ SPEC CHECKSUMS:
RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: b3c5abf27f09df7c9d5d089ad1275c2ec20a23aa
whisper-rn: a333c75700c2d031cecf12db9255459b01602d56
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

PODFILE CHECKSUM: 37f5c1045c7d04c6e5332174cca5f32f528700cf
PODFILE CHECKSUM: a78cf54fa529c6dc4b44aaf32b861fdf1245919a

COCOAPODS: 1.11.3
2 changes: 2 additions & 0 deletions ios/RNWhisper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ - (NSDictionary *)constantsToExport

NSString *modelPath = [modelOptions objectForKey:@"filePath"];
BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue];
BOOL useCoreMLIos = [[modelOptions objectForKey:@"useCoreMLIos"] boolValue];

// For support debug assets in development mode
BOOL downloadCoreMLAssets = [[modelOptions objectForKey:@"downloadCoreMLAssets"] boolValue];
Expand Down Expand Up @@ -75,6 +76,7 @@ - (NSDictionary *)constantsToExport
RNWhisperContext *context = [RNWhisperContext
initWithModelPath:path
contextId:contextId
noCoreML:!useCoreMLIos
];
if ([context getContext] == NULL) {
reject(@"whisper_cpp_error", @"Failed to load the model", nil);
Expand Down
2 changes: 1 addition & 1 deletion ios/RNWhisperContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ typedef struct {
RNWhisperContextRecordState recordState;
}

+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId;
+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML;
- (struct whisper_context *)getContext;
- (dispatch_queue_t)getDispatchQueue;
- (OSStatus)transcribeRealtime:(int)jobId
Expand Down
10 changes: 9 additions & 1 deletion ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@

@implementation RNWhisperContext

+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId {
+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML {
RNWhisperContext *context = [[RNWhisperContext alloc] init];
context->contextId = contextId;
#ifdef WHISPER_USE_COREML
if (noCoreML) {
context->ctx = whisper_init_from_file_no_coreml([modelPath UTF8String]);
} else {
context->ctx = whisper_init_from_file([modelPath UTF8String]);
}
#else
context->ctx = whisper_init_from_file([modelPath UTF8String]);
#endif
context->dQueue = dispatch_queue_create(
[[NSString stringWithFormat:@"RNWhisperContext-%d", contextId] UTF8String],
DISPATCH_QUEUE_SERIAL
Expand Down
2 changes: 2 additions & 0 deletions scripts/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ yarn example

# Apply patch
patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
patch -p0 -d ./cpp < ./scripts/whisper.h.patch
patch -p0 -d ./cpp < ./scripts/whisper.cpp.patch
patch -p0 -d ./cpp/coreml < ./scripts/whisper-encoder.mm.patch

# Download model for example
Expand Down
53 changes: 53 additions & 0 deletions scripts/whisper.cpp.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
--- whisper.cpp.orig 2023-10-12 11:44:51
+++ whisper.cpp 2023-10-12 11:43:31
@@ -770,6 +770,9 @@
whisper_state * state = nullptr;

std::string path_model; // populated by whisper_init_from_file()
+#ifdef WHISPER_USE_COREML
+ bool load_coreml = true;
+#endif
};

static void whisper_default_log(const char * text) {
@@ -2854,6 +2857,7 @@
}

#ifdef WHISPER_USE_COREML
+if (ctx->load_coreml) { // Not in correct layer for easy patch
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
@@ -2869,6 +2873,7 @@
} else {
log("%s: Core ML model loaded\n", __func__);
}
+}
#endif

state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2987,7 +2992,24 @@
state->rng = std::mt19937(0);

return state;
+}
+
+#ifdef WHISPER_USE_COREML
+struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) {
+ whisper_context * ctx = whisper_init_from_file_no_state(path_model);
+ if (!ctx) {
+ return nullptr;
+ }
+ ctx->load_coreml = false;
+ ctx->state = whisper_init_state(ctx);
+ if (!ctx->state) {
+ whisper_free(ctx);
+ return nullptr;
+ }
+
+ return ctx;
}
+#endif

int whisper_ctx_init_openvino_encoder(
struct whisper_context * ctx,
12 changes: 12 additions & 0 deletions scripts/whisper.h.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
--- whisper.h.orig 2023-10-12 10:41:41
+++ whisper.h 2023-10-12 10:38:11
@@ -99,6 +99,9 @@
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
+#ifdef WHISPER_USE_COREML
+ WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model);
+#endif
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
1 change: 1 addition & 0 deletions src/NativeRNWhisper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export type CoreMLAsset = {
type NativeContextOptions = {
filePath: string,
isBundleAsset: boolean,
useCoreMLIos?: boolean,
downloadCoreMLAssets?: boolean,
coreMLAssets?: CoreMLAsset[],
}
Expand Down
4 changes: 4 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ export type ContextOptions = {
}
/** Is the file path a bundle asset for pure string filePath */
isBundleAsset?: boolean
/** Prefer to use the Core ML model if it exists. If set to false, the Core ML model will not be used even when it exists. */
useCoreMLIos?: boolean
}

const coreMLModelAssetPaths = [
Expand All @@ -451,6 +453,7 @@ export async function initWhisper({
filePath,
coreMLModelAsset,
isBundleAsset,
useCoreMLIos = true,
}: ContextOptions): Promise<WhisperContext> {
let path = ''
let coreMLAssets: CoreMLAsset[] | undefined
Expand Down Expand Up @@ -499,6 +502,7 @@ export async function initWhisper({
const id = await RNWhisper.initContext({
filePath: path,
isBundleAsset: !!isBundleAsset,
useCoreMLIos,
// Only development mode need download Core ML model assets (from packager server)
downloadCoreMLAssets: __DEV__ && !!coreMLAssets,
coreMLAssets,
Expand Down