Add truncation and padding setter APIs to tokenizers #1870

Merged
77 changes: 77 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
@@ -18,6 +18,8 @@ use std::str::FromStr;
use tk::tokenizer::{EncodeInput, Encoding};
use tk::Tokenizer;
use tk::{FromPretrainedParameters, Offsets};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::utils::padding::{PaddingParams, PaddingStrategy};

use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode};
use jni::sys::{
@@ -405,6 +407,81 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
ret
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setPadding(
env: JNIEnv,
_: JObject,
handle: jlong,
max_length: jlong,
padding_strategy: JString,
pad_to_multiple_of: jlong,
) {
let strategy: String = env
.get_string(padding_strategy)
.expect("Couldn't get java string!")
.into();
let len = max_length as usize;
let res_strategy = match strategy.as_ref() {
"batch_longest" => Ok(PaddingStrategy::BatchLongest),
"fixed" => Ok(PaddingStrategy::Fixed(len)),
_ => Err("strategy must be one of [batch_longest, fixed]"),
};

let mut params = PaddingParams::default();
params.strategy = res_strategy.unwrap();
params.pad_to_multiple_of = Some(pad_to_multiple_of as usize);
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_padding(Some(params));
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disablePadding(
_env: JNIEnv,
_: JObject,
handle: jlong,
) {
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_padding(None);
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setTruncation(
env: JNIEnv,
_: JObject,
handle: jlong,
truncation_max_length: jlong,
truncation_strategy: JString,
truncation_stride: jlong,
) {
let strategy: String = env
.get_string(truncation_strategy)
.expect("Couldn't get java string!")
.into();
let res_strategy = match strategy.as_ref() {
"longest_first" => Ok(TruncationStrategy::LongestFirst),
"only_first" => Ok(TruncationStrategy::OnlyFirst),
"only_second" => Ok(TruncationStrategy::OnlySecond),
_ => Err("strategy must be one of [longest_first, only_first, only_second]"),
};
let mut params = TruncationParams::default();
params.max_length = truncation_max_length as usize;
params.strategy = res_strategy.unwrap();
params.stride = truncation_stride as usize;

let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_truncation(Some(params));
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disableTruncation(
_env: JNIEnv,
_: JObject,
handle: jlong,
) {
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_truncation(None);
}

fn to_handle<T: 'static>(val: T) -> jlong {
let handle = Box::into_raw(Box::new(val)) as jlong;
handle
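For reference, here is a minimal standalone sketch of what these JNI entry points configure on the underlying `Tokenizer`, using only the `tokenizers` (`tk`) crate calls that the binding itself makes; the `configure` helper, the hard-coded lengths, and the model id are illustrative and not part of this PR:

```rust
use tk::utils::padding::{PaddingParams, PaddingStrategy};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::Tokenizer;

// Mirrors setPadding/setTruncation above, minus the JNI plumbing.
fn configure(tokenizer: &mut Tokenizer) {
    let mut padding = PaddingParams::default();
    padding.strategy = PaddingStrategy::Fixed(12); // "fixed"; "batch_longest" maps to BatchLongest
    padding.pad_to_multiple_of = Some(4);
    tokenizer.with_padding(Some(padding));

    let mut truncation = TruncationParams::default();
    truncation.strategy = TruncationStrategy::LongestFirst; // or OnlyFirst / OnlySecond
    truncation.max_length = 12;
    truncation.stride = 0;
    tokenizer.with_truncation(Some(truncation));
}

fn main() {
    // from_pretrained needs the crate's `http` feature to download the tokenizer definition.
    let mut tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
    configure(&mut tokenizer);
    let encoding = tokenizer.encode("I am happy", true).unwrap();
    println!("{:?}", encoding.get_ids()); // padded out to the fixed length of 12
}
```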
HuggingFaceTokenizer.java
@@ -302,6 +302,90 @@ public String decode(long[] ids) {
return decode(ids, !addSpecialTokens);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength) {
setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, 0, 0);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
* @param stride the stride to use when handling overflow
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength,
long stride) {
setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, stride, 0);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
* @param stride the stride to use when handling overflow
* @param padToMultipleOf pad the sequence length to a multiple of this value
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength,
long stride,
long padToMultipleOf) {
setTruncation(truncationStrategy, maxLength, stride);
setPadding(paddingStrategy, maxLength, padToMultipleOf);
}

/**
* Sets the truncation behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param maxLength the maximum length to truncate sequences to
* @param stride the stride to use when handling overflow
*/
public void setTruncation(TruncationStrategy truncationStrategy, long maxLength, long stride) {
if (truncationStrategy == TruncationStrategy.DO_NOT_TRUNCATE) {
TokenizersLibrary.LIB.disableTruncation(getHandle());
} else {
TokenizersLibrary.LIB.setTruncation(
getHandle(), maxLength, truncationStrategy.toString().toLowerCase(), stride);
}
}

/**
* Sets the padding behavior for the tokenizer.
*
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad sequences to
* @param padToMultipleOf pad the sequence length to a multiple of this value
*/
public void setPadding(PaddingStrategy paddingStrategy, long maxLength, long padToMultipleOf) {
if (paddingStrategy == PaddingStrategy.DO_NOT_PAD) {
TokenizersLibrary.LIB.disablePadding(getHandle());
} else {
TokenizersLibrary.LIB.setPadding(
getHandle(),
maxLength,
paddingStrategy.toString().toLowerCase(),
padToMultipleOf);
}
}

private Encoding toEncoding(long encoding) {
long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
@@ -315,4 +399,17 @@ private Encoding toEncoding(long encoding) {
return new Encoding(
ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans);
}

enum TruncationStrategy {
DO_NOT_TRUNCATE,
LONGEST_FIRST,
ONLY_FIRST,
ONLY_SECOND
}

enum PaddingStrategy {
DO_NOT_PAD,
BATCH_LONGEST,
FIXED
}
}
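As a usage sketch of the new setters (the class name `TruncationPaddingExample` is illustrative, and it is placed in the `ai.djl.huggingface.tokenizers` package because the strategy enums above are package-private, which the test further down relies on as well): selecting `DO_NOT_TRUNCATE`/`DO_NOT_PAD` routes to the `disableTruncation`/`disablePadding` natives, while every other enum value is lowercased into the strategy string the Rust side matches on.

```java
package ai.djl.huggingface.tokenizers;

import java.util.Arrays;
import java.util.List;

public final class TruncationPaddingExample {

    public static void main(String[] args) {
        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
            // Truncate every sequence to at most 10 tokens, then pad the batch
            // up to its longest remaining sequence.
            tokenizer.setTruncationAndPadding(
                    HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
                    HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
                    10);

            List<String> inputs = Arrays.asList("Hello, y'all! How are you?", "I am happy");
            Encoding[] encodings = tokenizer.batchEncode(inputs);
            for (Encoding encoding : encodings) {
                System.out.println(Arrays.toString(encoding.getIds()));
            }

            // Turn both behaviors back off; the length/stride arguments are ignored here.
            tokenizer.setTruncation(HuggingFaceTokenizer.TruncationStrategy.DO_NOT_TRUNCATE, 0, 0);
            tokenizer.setPadding(HuggingFaceTokenizer.PaddingStrategy.DO_NOT_PAD, 0, 0);
        }
    }
}
```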
TokenizersLibrary.java
@@ -52,4 +52,14 @@ public native long encodeDual(
public native CharSpan[] getTokenCharSpans(long encoding);

public native String decode(long tokenizer, long[] ids, boolean addSpecialTokens);

public native void disablePadding(long tokenizer);

public native void setPadding(
long tokenizer, long maxLength, String paddingStrategy, long padToMultipleOf);

public native void disableTruncation(long tokenizer);

public native void setTruncation(
long tokenizer, long maxLength, String truncationStrategy, long stride);
}
@@ -155,4 +155,77 @@ public void testTokenizerDecoding() {
}
}
}

@Test
public void testTruncationAndPadding() {
TestRequirements.notArm();

try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
List<String> inputs =
Arrays.asList(
"Hello, y'all! How are you?",
"Today is a sunny day. Good weather I'd say",
"I am happy");
tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
10);
List<long[]> expectedIds =
Arrays.asList(
new long[] {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 102},
new long[] {101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 102},
new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0});
List<long[]> expectedAttentionMasks =
Arrays.asList(
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0});
Encoding[] encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
}

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
HuggingFaceTokenizer.PaddingStrategy.FIXED,
12);
expectedIds =
Arrays.asList(
new long[] {
101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 136, 102
},
new long[] {
101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 146, 112, 102
},
new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0, 0, 0});
expectedAttentionMasks =
Arrays.asList(
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0});
encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
}

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.ONLY_FIRST,
HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
8);
Encoding encoding = tokenizer.encode("Hello there my friend", "How are you");
long[] expectedId = new long[] {101, 8667, 1175, 102, 1731, 1132, 1128, 102};
Assert.assertEquals(encoding.getIds(), expectedId);

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.ONLY_SECOND,
HuggingFaceTokenizer.PaddingStrategy.DO_NOT_PAD,
8);

encoding = tokenizer.encode("Hello there my friend", "How are you");
expectedId = new long[] {101, 8667, 1175, 1139, 1910, 102, 1731, 102};
Assert.assertEquals(encoding.getIds(), expectedId);
}
}
}