diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs
index 3b0d52a7e0f..17bfa79b816 100644
--- a/extensions/tokenizers/rust/src/lib.rs
+++ b/extensions/tokenizers/rust/src/lib.rs
@@ -18,6 +18,8 @@ use std::str::FromStr;
 use tk::tokenizer::{EncodeInput, Encoding};
 use tk::Tokenizer;
 use tk::{FromPretrainedParameters, Offsets};
+use tk::utils::truncation::{TruncationParams, TruncationStrategy};
+use tk::utils::padding::{PaddingParams, PaddingStrategy};
 
 use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode};
 use jni::sys::{
@@ -405,6 +407,81 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
     ret
 }
 
+#[no_mangle]
+pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setPadding(
+    env: JNIEnv,
+    _: JObject,
+    handle: jlong,
+    max_length: jlong,
+    padding_strategy: JString,
+    pad_to_multiple_of: jlong,
+) {
+    let strategy: String = env
+        .get_string(padding_strategy)
+        .expect("Couldn't get java string!")
+        .into();
+    let len = max_length as usize;
+    let res_strategy = match strategy.as_ref() {
+        "batch_longest" => Ok(PaddingStrategy::BatchLongest),
+        "fixed" => Ok(PaddingStrategy::Fixed(len)),
+        _ => Err("strategy must be one of [batch_longest, fixed]"),
+    };
+
+    let mut params = PaddingParams::default();
+    params.strategy = res_strategy.unwrap();
+    params.pad_to_multiple_of = Some(pad_to_multiple_of as usize);
+    let tokenizer = cast_handle::<Tokenizer>(handle);
+    tokenizer.with_padding(Some(params));
+}
+
+#[no_mangle]
+pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disablePadding(
+    _env: JNIEnv,
+    _: JObject,
+    handle: jlong,
+) {
+    let tokenizer = cast_handle::<Tokenizer>(handle);
+    tokenizer.with_padding(None);
+}
+
+#[no_mangle]
+pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setTruncation(
+    env: JNIEnv,
+    _: JObject,
+    handle: jlong,
+    truncation_max_length: jlong,
+    truncation_strategy: JString,
+    truncation_stride: jlong,
+) {
+    let strategy: String = env
+        .get_string(truncation_strategy)
+        .expect("Couldn't get java string!")
+        .into();
+    let res_strategy = match strategy.as_ref() {
+        "longest_first" => Ok(TruncationStrategy::LongestFirst),
+        "only_first" => Ok(TruncationStrategy::OnlyFirst),
+        "only_second" => Ok(TruncationStrategy::OnlySecond),
+        _ => Err("strategy must be one of [longest_first, only_first, only_second]"),
+    };
+    let mut params = TruncationParams::default();
+    params.max_length = truncation_max_length as usize;
+    params.strategy = res_strategy.unwrap();
+    params.stride = truncation_stride as usize;
+
+    let tokenizer = cast_handle::<Tokenizer>(handle);
+    tokenizer.with_truncation(Some(params));
+}
+
+#[no_mangle]
+pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disableTruncation(
+    _env: JNIEnv,
+    _: JObject,
+    handle: jlong,
+) {
+    let tokenizer = cast_handle::<Tokenizer>(handle);
+    tokenizer.with_truncation(None);
+}
+
 fn to_handle<T: 'static>(val: T) -> jlong {
     let handle = Box::into_raw(Box::new(val)) as jlong;
     handle
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
index 123ba7c841f..f80779960e7 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
@@ -302,6 +302,90 @@ public String decode(long[] ids) {
         return decode(ids, !addSpecialTokens);
     }
 
+    /**
+     * Sets the truncation and padding behavior for the tokenizer.
+     *
+     * @param truncationStrategy the {@code TruncationStrategy} to use
+     * @param paddingStrategy the {@code PaddingStrategy} to use
+     * @param maxLength the maximum length to pad/truncate sequences to
+     */
+    public void setTruncationAndPadding(
+            TruncationStrategy truncationStrategy,
+            PaddingStrategy paddingStrategy,
+            long maxLength) {
+        setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, 0, 0);
+    }
+
+    /**
+     * Sets the truncation and padding behavior for the tokenizer.
+     *
+     * @param truncationStrategy the {@code TruncationStrategy} to use
+     * @param paddingStrategy the {@code PaddingStrategy} to use
+     * @param maxLength the maximum length to pad/truncate sequences to
+     * @param stride value to use when handling overflow
+     */
+    public void setTruncationAndPadding(
+            TruncationStrategy truncationStrategy,
+            PaddingStrategy paddingStrategy,
+            long maxLength,
+            long stride) {
+        setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, stride, 0);
+    }
+
+    /**
+     * Sets the truncation and padding behavior for the tokenizer.
+     *
+     * @param truncationStrategy the {@code TruncationStrategy} to use
+     * @param paddingStrategy the {@code PaddingStrategy} to use
+     * @param maxLength the maximum length to pad/truncate sequences to
+     * @param stride value to use when handling overflow
+     * @param padToMultipleOf pad sequence length to multiple of value
+     */
+    public void setTruncationAndPadding(
+            TruncationStrategy truncationStrategy,
+            PaddingStrategy paddingStrategy,
+            long maxLength,
+            long stride,
+            long padToMultipleOf) {
+        setTruncation(truncationStrategy, maxLength, stride);
+        setPadding(paddingStrategy, maxLength, padToMultipleOf);
+    }
+
+    /**
+     * Sets the truncation behavior for the tokenizer.
+     *
+     * @param truncationStrategy the {@code TruncationStrategy} to use
+     * @param maxLength the maximum length to truncate sequences to
+     * @param stride value to use when handling overflow
+     */
+    public void setTruncation(TruncationStrategy truncationStrategy, long maxLength, long stride) {
+        if (truncationStrategy == TruncationStrategy.DO_NOT_TRUNCATE) {
+            TokenizersLibrary.LIB.disableTruncation(getHandle());
+        } else {
+            TokenizersLibrary.LIB.setTruncation(
+                    getHandle(), maxLength, truncationStrategy.toString().toLowerCase(), stride);
+        }
+    }
+
+    /**
+     * Sets the padding behavior for the tokenizer.
+     *
+     * @param paddingStrategy the {@code PaddingStrategy} to use
+     * @param maxLength the maximum length to pad sequences to
+     * @param padToMultipleOf pad sequence length to multiple of value
+     */
+    public void setPadding(PaddingStrategy paddingStrategy, long maxLength, long padToMultipleOf) {
+        if (paddingStrategy == PaddingStrategy.DO_NOT_PAD) {
+            TokenizersLibrary.LIB.disablePadding(getHandle());
+        } else {
+            TokenizersLibrary.LIB.setPadding(
+                    getHandle(),
+                    maxLength,
+                    paddingStrategy.toString().toLowerCase(),
+                    padToMultipleOf);
+        }
+    }
+
     private Encoding toEncoding(long encoding) {
         long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
         long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
@@ -315,4 +399,17 @@ private Encoding toEncoding(long encoding) {
         return new Encoding(
                 ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans);
     }
+
+    enum TruncationStrategy {
+        DO_NOT_TRUNCATE,
+        LONGEST_FIRST,
+        ONLY_FIRST,
+        ONLY_SECOND
+    }
+
+    enum PaddingStrategy {
+        DO_NOT_PAD,
+        BATCH_LONGEST,
+        FIXED
+    }
 }
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java
index b532886bf16..b0e9095acc8 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java
@@ -52,4 +52,14 @@ public native long encodeDual(
     public native CharSpan[] getTokenCharSpans(long encoding);
 
     public native String decode(long tokenizer, long[] ids, boolean addSpecialTokens);
+
+    public native void disablePadding(long tokenizer);
+
+    public native void setPadding(
+            long tokenizer, long maxLength, String paddingStrategy, long padToMultipleOf);
+
+    public native void disableTruncation(long tokenizer);
+
+    public native void setTruncation(
+            long tokenizer, long maxLength, String truncationStrategy, long stride);
 }
diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
index 591df6b2244..22cf0782d00 100644
--- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
+++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
@@ -155,4 +155,77 @@ public void testTokenizerDecoding() {
             }
         }
     }
+
+    @Test
+    public void testTruncationAndPadding() {
+        TestRequirements.notArm();
+
+        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
+            List<String> inputs =
+                    Arrays.asList(
+                            "Hello, y'all! How are you?",
+                            "Today is a sunny day. Good weather I'd say",
+                            "I am happy");
+            tokenizer.setTruncationAndPadding(
+                    HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
+                    HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
+                    10);
+            List<long[]> expectedIds =
+                    Arrays.asList(
+                            new long[] {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 102},
+                            new long[] {101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 102},
+                            new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0});
+            List<long[]> expectedAttentionMasks =
+                    Arrays.asList(
+                            new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                            new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                            new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0});
+            Encoding[] encodings = tokenizer.batchEncode(inputs);
+            for (int i = 0; i < encodings.length; ++i) {
+                Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
+                Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
+            }
+
+            tokenizer.setTruncationAndPadding(
+                    HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
+                    HuggingFaceTokenizer.PaddingStrategy.FIXED,
+                    12);
+            expectedIds =
+                    Arrays.asList(
+                            new long[] {
+                                101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 136, 102
+                            },
+                            new long[] {
+                                101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 146, 112, 102
+                            },
+                            new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0, 0, 0});
+            expectedAttentionMasks =
+                    Arrays.asList(
+                            new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                            new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+                            new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0});
+            encodings = tokenizer.batchEncode(inputs);
+            for (int i = 0; i < encodings.length; ++i) {
+                Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
+                Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
+            }
+
+            tokenizer.setTruncationAndPadding(
+                    HuggingFaceTokenizer.TruncationStrategy.ONLY_FIRST,
+                    HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
+                    8);
+            Encoding encoding = tokenizer.encode("Hello there my friend", "How are you");
+            long[] expectedId = new long[] {101, 8667, 1175, 102, 1731, 1132, 1128, 102};
+            Assert.assertEquals(encoding.getIds(), expectedId);
+
+            tokenizer.setTruncationAndPadding(
+                    HuggingFaceTokenizer.TruncationStrategy.ONLY_SECOND,
+                    HuggingFaceTokenizer.PaddingStrategy.DO_NOT_PAD,
+                    8);
+
+            encoding = tokenizer.encode("Hello there my friend", "How are you");
+            expectedId = new long[] {101, 8667, 1175, 1139, 1910, 102, 1731, 102};
+            Assert.assertEquals(encoding.getIds(), expectedId);
+        }
+    }
 }