From cddbccc1d988bf0bad508b8a22bfec93ca3d40bb Mon Sep 17 00:00:00 2001 From: Byran Liu Date: Tue, 14 Jan 2025 21:13:53 -0800 Subject: [PATCH] [tokenizers] Fixes ZeroShotClassificationTranslator bug --- .../translator/ZeroShotClassificationServingTranslator.java | 4 ++-- .../translator/ZeroShotClassificationTranslator.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/ZeroShotClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/ZeroShotClassificationServingTranslator.java index 15d503fcef9..9b9cdb9ad7c 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/ZeroShotClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/ZeroShotClassificationServingTranslator.java @@ -32,9 +32,9 @@ public class ZeroShotClassificationServingTranslator private Translator translator; /** - * Constructs a {@code TokenClassificationServingTranslator} instance. + * Constructs a {@code ZeroShotClassificationServingTranslator} instance. * - * @param translator a {@code Translator} processes token classification input + * @param translator a {@code Translator} processes zero-shot-classification input */ public ZeroShotClassificationServingTranslator( Translator translator) { diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/ZeroShotClassificationTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/ZeroShotClassificationTranslator.java index a8d651848b8..6159cfd1702 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/ZeroShotClassificationTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/ZeroShotClassificationTranslator.java @@ -116,7 +116,7 @@ private String applyTemplate(String template, String arg) { return template + arg; } int len = template.length(); - return template.substring(0, pos) + arg + template.substring(pos + 1, len); + return template.substring(0, pos) + arg + template.substring(pos + 2, len); } /**