diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs index 34e532de897..2a826c3f56a 100644 --- a/extensions/tokenizers/rust/src/lib.rs +++ b/extensions/tokenizers/rust/src/lib.rs @@ -16,11 +16,13 @@ extern crate tokenizers as tk; use std::str::FromStr; use tk::tokenizer::{EncodeInput, Encoding}; -use tk::FromPretrainedParameters; use tk::Tokenizer; +use tk::{FromPretrainedParameters, Offsets}; -use jni::objects::{JObject, JString}; -use jni::sys::{jboolean, jlong, jlongArray, jobjectArray, jsize, JNI_TRUE}; +use jni::objects::{JClass, JMethodID, JObject, JString, JValue}; +use jni::sys::{ + jboolean, jint, jlong, jlongArray, jobjectArray, jsize, JNI_TRUE, +}; use jni::JNIEnv; #[no_mangle] @@ -256,7 +258,8 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ .unwrap(); for (i, token) in tokens.iter().enumerate() { let item: JString = env.new_string(&token).unwrap(); - env.set_object_array_element(array, i as jsize, item).unwrap(); + env.set_object_array_element(array, i as jsize, item) + .unwrap(); } array } @@ -299,6 +302,48 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ array } +#[no_mangle] +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenCharSpans( + env: JNIEnv, + _: JObject, + handle: jlong, +) -> jobjectArray { + let encoding = cast_handle::(handle); + let tokens = encoding.get_tokens(); + let len = tokens.len() as jsize; + + let array: jobjectArray = env + .new_object_array( + len, + "ai/djl/huggingface/tokenizers/jni/CharSpan", + JObject::null(), + ) + .unwrap(); + for (i, _) in tokens.iter().enumerate() { + let opt_offsets: Option<(usize, Offsets)> = encoding.token_to_chars(i); + match &opt_offsets { + Some((_, offsets)) => { + let class_id = "ai/djl/huggingface/tokenizers/jni/CharSpan"; + let method_id = ""; + let params = "(II)V"; + let cls: JClass = env.find_class(class_id).unwrap(); + let constructor: JMethodID = env.get_method_id(cls, method_id, params).unwrap(); + let offsets_vec: Vec = vec![ + JValue::Int((*offsets).0 as jint), + JValue::Int((*offsets).1 as jint), + ]; + let obj = env + .new_object_unchecked(cls, constructor, &offsets_vec[..]) + .unwrap(); + env.set_object_array_element(array, i as jsize, obj) + .unwrap(); + } + None => {} + } + } + array +} + fn to_handle(val: T) -> jlong { let handle = Box::into_raw(Box::new(val)) as jlong; handle diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java index 9d894add6fe..dfbee95de0c 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java @@ -12,6 +12,8 @@ */ package ai.djl.huggingface.tokenizers; +import ai.djl.huggingface.tokenizers.jni.CharSpan; + /** A class holds token encoding information. */ public class Encoding { @@ -21,6 +23,7 @@ public class Encoding { private long[] wordIds; private long[] attentionMask; private long[] specialTokenMask; + private CharSpan[] charTokenSpans; Encoding( long[] ids, @@ -28,13 +31,15 @@ public class Encoding { String[] tokens, long[] wordIds, long[] attentionMask, - long[] specialTokenMask) { + long[] specialTokenMask, + CharSpan[] charTokenSpans) { this.ids = ids; this.typeIds = typeIds; this.tokens = tokens; this.wordIds = wordIds; this.attentionMask = attentionMask; this.specialTokenMask = specialTokenMask; + this.charTokenSpans = charTokenSpans; } /** @@ -90,4 +95,13 @@ public long[] getAttentionMask() { public long[] getSpecialTokenMask() { return specialTokenMask; } + + /** + * Returns char token spans. + * + * @return char token spans + */ + public CharSpan[] getCharTokenSpans() { + return charTokenSpans; + } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 3bf130d50cb..1d540faaf70 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -12,6 +12,7 @@ */ package ai.djl.huggingface.tokenizers; +import ai.djl.huggingface.tokenizers.jni.CharSpan; import ai.djl.huggingface.tokenizers.jni.LibUtils; import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary; import ai.djl.modality.nlp.preprocess.Tokenizer; @@ -200,8 +201,10 @@ private Encoding toEncoding(long encoding) { long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding); long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding); long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding); + CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding); TokenizersLibrary.LIB.deleteEncoding(encoding); - return new Encoding(ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask); + return new Encoding( + ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans); } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/CharSpan.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/CharSpan.java new file mode 100644 index 00000000000..7478d0e8b75 --- /dev/null +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/CharSpan.java @@ -0,0 +1,51 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** A class holds character span information. */ +package ai.djl.huggingface.tokenizers.jni; + +/** A class holds character span information. */ +public class CharSpan { + + private final int start; + private final int end; + + /** + * Constructs a new {@code CharSpan} instance. + * + * @param start the start position + * @param end the end position + */ + public CharSpan(int start, int end) { + this.start = start; + this.end = end; + } + + /** + * Returns the start position. + * + * @return the start position + */ + public int getStart() { + return start; + } + + /** + * Returns the end position. + * + * @return the end position + */ + public int getEnd() { + return end; + } +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java index cb8b6830853..7ca5bd284b8 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java @@ -45,4 +45,6 @@ private TokenizersLibrary() {} public native long[] getAttentionMask(long encoding); public native long[] getSpecialTokenMask(long encoding); + + public native CharSpan[] getTokenCharSpans(long encoding); } diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 4f559e430ee..b4cd211bb19 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -13,6 +13,7 @@ package ai.djl.huggingface.tokenizers; +import ai.djl.huggingface.tokenizers.jni.CharSpan; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -49,6 +50,34 @@ public void testTokenizer() { Assert.assertEquals(attentionMask, encoding.getAttentionMask()); Assert.assertEquals(specialTokenMask, encoding.getSpecialTokenMask()); + CharSpan[] charSpansExpected = { + null, + new CharSpan(0, 5), + new CharSpan(5, 6), + new CharSpan(7, 8), + new CharSpan(8, 9), + new CharSpan(9, 12), + new CharSpan(12, 13), + new CharSpan(14, 17), + new CharSpan(18, 21), + new CharSpan(22, 25), + new CharSpan(26, 30), + new CharSpan(31, 32), + null + }; + int expectedLength = charSpansExpected.length; + CharSpan[] charSpansResult = encoding.getCharTokenSpans(); + + Assert.assertEquals(expectedLength, charSpansResult.length); + Assert.assertEquals(charSpansExpected[0], charSpansResult[0]); + Assert.assertEquals( + charSpansExpected[expectedLength - 1], charSpansResult[expectedLength - 1]); + + for (int i = 1; i < expectedLength - 1; i++) { + Assert.assertEquals(charSpansExpected[i].getStart(), charSpansResult[i].getStart()); + Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd()); + } + Encoding[] encodings = tokenizer.batchEncode(Arrays.asList(inputs)); Assert.assertEquals(encodings.length, 2); }