Skip to content

Commit

Permalink
add sentencepiece load from bytes method (#1399)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lanking authored Dec 2, 2021
1 parent 11d9974 commit be5289d
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void loadModel(String path) {
SentencePieceLibrary.LIB.loadModel(getHandle(), path);
}

/**
 * Loads a model into this processor from its serialized bytes by delegating to the native
 * library.
 *
 * @param serializedProto the serialized model bytes
 */
void loadModelFromBytes(byte[] serializedProto) {
    SentencePieceLibrary.LIB.loadModelFromBytes(this.getHandle(), serializedProto);
}

/**
* Tokenize a sentence into array of tokens.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ public SpTokenizer(Path modelPath, String prefix) throws IOException {
loadModel(modelPath, prefix);
}

/**
 * Constructs a SentencePiece tokenizer whose model is supplied as an in-memory byte array.
 *
 * <p>A new processor instance is created and the given bytes are loaded into it.
 *
 * @param serializedModel the serialized model
 */
public SpTokenizer(byte[] serializedModel) {
    processor = SpProcessor.newInstance();
    this.processor.loadModelFromBytes(serializedModel);
}

/** {@inheritDoc} */
@Override
public List<String> tokenize(String sentence) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ private SentencePieceLibrary() {}

public native void loadModel(long handle, String filePath);

/**
 * Loads a serialized model into the native processor identified by {@code handle}.
 *
 * <p>Implemented in native code by {@code
 * Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_loadModelFromBytes}.
 *
 * @param handle the native processor handle
 * @param bytes the serialized model bytes
 */
public native void loadModelFromBytes(long handle, byte[] bytes);

public native void deleteSentencePieceProcessor(long handle);

public native String[] tokenize(long handle, String text);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_loadMo
CheckStatus(env, processor_ptr->Load(path_string));
}

// Loads a SentencePiece model into the processor behind `jhandle` from a serialized model
// passed in as a Java byte[].
//
//   env         - JNI environment of the calling thread
//   jthis       - the calling Java object (unused)
//   jhandle     - native pointer to a sentencepiece::SentencePieceProcessor
//   jserialized - the serialized model bytes
//
// When `jserialized` is null a NullPointerException is raised and nothing is loaded. When
// copying the array contents leaves a Java exception pending, the load is skipped so the
// exception propagates to the caller. Otherwise the status returned by
// LoadFromSerializedProto is handed to CheckStatus.
JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_loadModelFromBytes(
    JNIEnv* env, jobject jthis, jlong jhandle, jbyteArray jserialized) {
  // Calling GetArrayLength on a null array reference is not allowed by JNI; fail in Java instead.
  if (jserialized == nullptr) {
    jclass npe_class = env->FindClass("java/lang/NullPointerException");
    if (npe_class != nullptr) {
      env->ThrowNew(npe_class, "serialized model is null");
    }
    return;
  }

  auto* processor_ptr = reinterpret_cast<sentencepiece::SentencePieceProcessor*>(jhandle);
  const jsize length = env->GetArrayLength(jserialized);

  // Copy the Java bytes straight into the string that is handed to SentencePiece, instead of
  // staging them in a temporary vector and copying a second time.
  std::string serialized(static_cast<std::string::size_type>(length), '\0');
  if (length > 0) {
    env->GetByteArrayRegion(jserialized, 0, length, reinterpret_cast<jbyte*>(&serialized[0]));
    if (env->ExceptionCheck()) {
      return;
    }
  }

  CheckStatus(env, processor_ptr->LoadFromSerializedProto(serialized));
}

JNIEXPORT void JNICALL Java_ai_djl_sentencepiece_jni_SentencePieceLibrary_deleteSentencePieceProcessor(
JNIEnv* env, jobject jthis, jlong jhandle) {
auto* processor_ptr = reinterpret_cast<sentencepiece::SentencePieceProcessor*>(jhandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ public void downloadModel() throws IOException {
}
}

/**
 * Checks that a tokenizer built from the raw bytes of the model file tokenizes "Hello World"
 * into the expected pieces and can rebuild the original sentence from them.
 *
 * @throws IOException if the model file cannot be read
 */
@Test
public void testLoadFromBytes() throws IOException {
    TestRequirements.notWindows();

    Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model");
    byte[] bytes = Files.readAllBytes(modelPath);
    try (SpTokenizer tokenizer = new SpTokenizer(bytes)) {
        String original = "Hello World";
        List<String> tokens = tokenizer.tokenize(original);
        List<String> expected = Arrays.asList("▁He", "ll", "o", "▁", "W", "or", "l", "d");
        Assert.assertEquals(tokens, expected);
        String recovered = tokenizer.buildSentence(tokens);
        // Actual value first, expected second, matching the token assertion above, so a
        // failure message reports the two values the right way round.
        Assert.assertEquals(recovered, original);
    }
}

@Test
public void testTokenize() throws IOException {
TestRequirements.notWindows();
Expand Down

0 comments on commit be5289d

Please sign in to comment.