Add truncation and padding setter APIs to tokenizers (deepjavalibrary…
siddvenk authored and patins1 committed Aug 26, 2022
1 parent 3e76f8d commit 8d4383c
Showing 4 changed files with 257 additions and 0 deletions.
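
For orientation before the diff, here is a minimal sketch of how the new Java API is meant to be used, pieced together from the method signatures and the test added below. The model id, inputs, and max length are illustrative only, imports are omitted, and the code is assumed to live in the same package as HuggingFaceTokenizer (the strategy enums are declared package-private):

    try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
        // Truncate each input to at most 10 tokens, then pad the batch to its longest sequence.
        tokenizer.setTruncationAndPadding(
                HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
                HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
                10);
        Encoding[] encodings = tokenizer.batchEncode(Arrays.asList("I am happy", "How are you?"));
    }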
77 changes: 77 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
@@ -18,6 +18,8 @@ use std::str::FromStr;
use tk::tokenizer::{EncodeInput, Encoding};
use tk::Tokenizer;
use tk::{FromPretrainedParameters, Offsets};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::utils::padding::{PaddingParams, PaddingStrategy};

use jni::objects::{JClass, JMethodID, JObject, JString, JValue, ReleaseMode};
use jni::sys::{
@@ -405,6 +407,81 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
ret
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setPadding(
env: JNIEnv,
_: JObject,
handle: jlong,
max_length: jlong,
padding_strategy: JString,
pad_to_multiple_of: jlong,
) {
let strategy: String = env
.get_string(padding_strategy)
.expect("Couldn't get java string!")
.into();
let len = max_length as usize;
let res_strategy = match strategy.as_ref() {
"batch_longest" => Ok(PaddingStrategy::BatchLongest),
"fixed" => Ok(PaddingStrategy::Fixed(len)),
_ => Err("strategy must be one of [batch_longest, fixed]"),
};

let mut params = PaddingParams::default();
params.strategy = res_strategy.unwrap();
params.pad_to_multiple_of = Some(pad_to_multiple_of as usize);
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_padding(Some(params));
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disablePadding(
_env: JNIEnv,
_: JObject,
handle: jlong,
) {
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_padding(None);
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_setTruncation(
env: JNIEnv,
_: JObject,
handle: jlong,
truncation_max_length: jlong,
truncation_strategy: JString,
truncation_stride: jlong,
) {
let strategy: String = env
.get_string(truncation_strategy)
.expect("Couldn't get java string!")
.into();
let res_strategy = match strategy.as_ref() {
"longest_first" => Ok(TruncationStrategy::LongestFirst),
"only_first" => Ok(TruncationStrategy::OnlyFirst),
"only_second" => Ok(TruncationStrategy::OnlySecond),
_ => Err("strategy must be one of [longest_first, only_first, only_second]"),
};
let mut params = TruncationParams::default();
params.max_length = truncation_max_length as usize;
params.strategy = res_strategy.unwrap();
params.stride = truncation_stride as usize;

let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_truncation(Some(params));
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_disableTruncation(
_env: JNIEnv,
_: JObject,
handle: jlong,
) {
let tokenizer = cast_handle::<Tokenizer>(handle);
tokenizer.with_truncation(None);
}

fn to_handle<T: 'static>(val: T) -> jlong {
let handle = Box::into_raw(Box::new(val)) as jlong;
handle
97 changes: 97 additions & 0 deletions extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
@@ -302,6 +302,90 @@ public String decode(long[] ids) {
return decode(ids, !addSpecialTokens);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength) {
setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, 0, 0);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
* @param stride value to use when handling overflow
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength,
long stride) {
setTruncationAndPadding(truncationStrategy, paddingStrategy, maxLength, stride, 0);
}

/**
* Sets the truncation and padding behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad/truncate sequences to
* @param stride value to use when handling overflow
* @param padToMultipleOf pad sequence length to multiple of value
*/
public void setTruncationAndPadding(
TruncationStrategy truncationStrategy,
PaddingStrategy paddingStrategy,
long maxLength,
long stride,
long padToMultipleOf) {
setTruncation(truncationStrategy, maxLength, stride);
setPadding(paddingStrategy, maxLength, padToMultipleOf);
}

/**
* Sets the truncation behavior for the tokenizer.
*
* @param truncationStrategy the {@code TruncationStrategy} to use
* @param maxLength the maximum length to truncate sequences to
* @param stride value to use when handling overflow
*/
public void setTruncation(TruncationStrategy truncationStrategy, long maxLength, long stride) {
if (truncationStrategy == TruncationStrategy.DO_NOT_TRUNCATE) {
TokenizersLibrary.LIB.disableTruncation(getHandle());
} else {
TokenizersLibrary.LIB.setTruncation(
getHandle(), maxLength, truncationStrategy.toString().toLowerCase(), stride);
}
}

/**
* Sets the padding behavior for the tokenizer.
*
* @param paddingStrategy the {@code PaddingStrategy} to use
* @param maxLength the maximum length to pad sequences to
* @param padToMultipleOf pad sequence length to multiple of value
*/
public void setPadding(PaddingStrategy paddingStrategy, long maxLength, long padToMultipleOf) {
if (paddingStrategy == PaddingStrategy.DO_NOT_PAD) {
TokenizersLibrary.LIB.disablePadding(getHandle());
} else {
TokenizersLibrary.LIB.setPadding(
getHandle(),
maxLength,
paddingStrategy.toString().toLowerCase(),
padToMultipleOf);
}
}

private Encoding toEncoding(long encoding) {
long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
@@ -315,4 +399,17 @@ private Encoding toEncoding(long encoding) {
return new Encoding(
ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans);
}

enum TruncationStrategy {
DO_NOT_TRUNCATE,
LONGEST_FIRST,
ONLY_FIRST,
ONLY_SECOND
}

enum PaddingStrategy {
DO_NOT_PAD,
BATCH_LONGEST,
FIXED
}
}
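
The setters above can also be used individually. Passing TruncationStrategy.DO_NOT_TRUNCATE or PaddingStrategy.DO_NOT_PAD routes to the disableTruncation/disablePadding JNI calls, while the other enum values are lower-cased into the strings the Rust side matches on (LONGEST_FIRST becomes "longest_first", BATCH_LONGEST becomes "batch_longest", and so on). A short sketch with a hypothetical helper and illustrative lengths, again assuming same-package code:

    void configurePairTruncation(HuggingFaceTokenizer tokenizer) {
        // Keep sentence pairs to at most 8 tokens by trimming only the second sequence.
        tokenizer.setTruncation(HuggingFaceTokenizer.TruncationStrategy.ONLY_SECOND, 8, 0);
        // DO_NOT_PAD disables padding entirely, so the remaining arguments are ignored.
        tokenizer.setPadding(HuggingFaceTokenizer.PaddingStrategy.DO_NOT_PAD, 0, 0);
    }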
10 changes: 10 additions & 0 deletions extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java
@@ -52,4 +52,14 @@ public native long encodeDual(
public native CharSpan[] getTokenCharSpans(long encoding);

public native String decode(long tokenizer, long[] ids, boolean addSpecialTokens);

public native void disablePadding(long tokenizer);

public native void setPadding(
long tokenizer, long maxLength, String paddingStrategy, long padToMultipleOf);

public native void disableTruncation(long tokenizer);

public native void setTruncation(
long tokenizer, long maxLength, String truncationStrategy, long stride);
}
73 changes: 73 additions & 0 deletions extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java
@@ -155,4 +155,77 @@ public void testTokenizerDecoding() {
}
}
}

@Test
public void testTruncationAndPadding() {
TestRequirements.notArm();

try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
List<String> inputs =
Arrays.asList(
"Hello, y'all! How are you?",
"Today is a sunny day. Good weather I'd say",
"I am happy");
tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
10);
List<long[]> expectedIds =
Arrays.asList(
new long[] {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 102},
new long[] {101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 102},
new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0});
List<long[]> expectedAttentionMasks =
Arrays.asList(
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0});
Encoding[] encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
}

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.LONGEST_FIRST,
HuggingFaceTokenizer.PaddingStrategy.FIXED,
12);
expectedIds =
Arrays.asList(
new long[] {
101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 136, 102
},
new long[] {
101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 146, 112, 102
},
new long[] {101, 146, 1821, 2816, 102, 0, 0, 0, 0, 0, 0, 0});
expectedAttentionMasks =
Arrays.asList(
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
new long[] {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0});
encodings = tokenizer.batchEncode(inputs);
for (int i = 0; i < encodings.length; ++i) {
Assert.assertEquals(encodings[i].getIds(), expectedIds.get(i));
Assert.assertEquals(encodings[i].getAttentionMask(), expectedAttentionMasks.get(i));
}

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.ONLY_FIRST,
HuggingFaceTokenizer.PaddingStrategy.BATCH_LONGEST,
8);
Encoding encoding = tokenizer.encode("Hello there my friend", "How are you");
long[] expectedId = new long[] {101, 8667, 1175, 102, 1731, 1132, 1128, 102};
Assert.assertEquals(encoding.getIds(), expectedId);

tokenizer.setTruncationAndPadding(
HuggingFaceTokenizer.TruncationStrategy.ONLY_SECOND,
HuggingFaceTokenizer.PaddingStrategy.DO_NOT_PAD,
8);

encoding = tokenizer.encode("Hello there my friend", "How are you");
expectedId = new long[] {101, 8667, 1175, 1139, 1910, 102, 1731, 102};
Assert.assertEquals(encoding.getIds(), expectedId);
}
}
}
