From 734683bf98a301121814ec1f369d2d4abbe9955c Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Wed, 1 Dec 2021 22:18:05 -0800 Subject: [PATCH] add sentencepiece load from bytes method --- .../java/ai/djl/sentencepiece/SpProcessor.java | 4 ++++ .../java/ai/djl/sentencepiece/SpTokenizer.java | 10 ++++++++++ .../sentencepiece/jni/SentencePieceLibrary.java | 2 ++ ...djl_sentencepiece_jni_SentencePieceLibrary.cc | 10 ++++++++++ .../ai/djl/sentencepiece/SpTokenizerTest.java | 16 ++++++++++++++++ 5 files changed, 42 insertions(+) diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java index d2c6913c705..44d05936e58 100644 --- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java +++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java @@ -44,6 +44,10 @@ void loadModel(String path) { SentencePieceLibrary.LIB.loadModel(getHandle(), path); } + void loadModelFromBytes(byte[] serializedProto) { + SentencePieceLibrary.LIB.loadModelFromBytes(getHandle(), serializedProto); + } + /** * Tokenize a sentence into array of tokens. * diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java index d11c2a903ff..613fcced47f 100644 --- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java +++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java @@ -50,6 +50,16 @@ public SpTokenizer(Path modelPath, String prefix) throws IOException { loadModel(modelPath, prefix); } + /** + * Creates a SentencePiece Tokenizer from byte array. + * + * @param serializedModel the serialized model + */ + public SpTokenizer(byte[] serializedModel) { + this.processor = SpProcessor.newInstance(); + processor.loadModelFromBytes(serializedModel); + } + /** {@inheritDoc} */ @Override public List tokenize(String sentence) { diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/jni/SentencePieceLibrary.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/jni/SentencePieceLibrary.java index 2bf7b09bcc7..35cbd4d3ff1 100644 --- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/jni/SentencePieceLibrary.java +++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/jni/SentencePieceLibrary.java @@ -24,6 +24,8 @@ private SentencePieceLibrary() {} public native void loadModel(long handle, String filePath); + public native void loadModelFromBytes(long handle, byte[] bytes); + public native void deleteSentencePieceProcessor(long handle); public native String[] tokenize(long handle, String text); diff --git a/extensions/sentencepiece/src/main/native/ai_djl_sentencepiece_jni_SentencePieceLibrary.cc b/extensions/sentencepiece/src/main/native/ai_djl_sentencepiece_jni_SentencePieceLibrary.cc index 0a335297811..cf6d1c73ce9 100644 --- a/extensions/sentencepiece/src/main/native/ai_djl_sentencepiece_jni_SentencePieceLibrary.cc +++ b/extensions/sentencepiece/src/main/native/ai_djl_sentencepiece_jni_SentencePieceLibrary.cc @@ -37,6 +37,16 @@ JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_loadMo CheckStatus(env, processor_ptr->Load(path_string)); } +JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_loadModelFromBytes( + JNIEnv* env, jobject jthis, jlong jhandle, jbyteArray jserialized) { + auto* processor_ptr = reinterpret_cast(jhandle); + int length = env->GetArrayLength(jserialized); + std::vector buff(length, 0); + env->GetByteArrayRegion(jserialized, 0, length, reinterpret_cast(buff.data())); + std::string serialized(buff.data(), buff.size()); + CheckStatus(env, processor_ptr->LoadFromSerializedProto(serialized)); +} + JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_deleteSentencePieceProcessor( JNIEnv* env, jobject jthis, jlong jhandle) { auto* processor_ptr = reinterpret_cast(jhandle); diff --git a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java index 1b2f97f2255..e143603906e 100644 --- a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java +++ b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java @@ -37,6 +37,22 @@ public void downloadModel() throws IOException { } } + @Test + public void testLoadFromBytes() throws IOException { + TestRequirements.notWindows(); + + Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + byte[] bytes = Files.readAllBytes(modelPath); + try (SpTokenizer tokenizer = new SpTokenizer(bytes)) { + String original = "Hello World"; + List tokens = tokenizer.tokenize(original); + List expected = Arrays.asList("▁He", "ll", "o", "▁", "W", "or", "l", "d"); + Assert.assertEquals(tokens, expected); + String recovered = tokenizer.buildSentence(tokens); + Assert.assertEquals(original, recovered); + } + } + @Test public void testTokenize() throws IOException { TestRequirements.notWindows();