diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java index 9c82fa6ff52..544f6503f50 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslator.java @@ -22,6 +22,7 @@ import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslatorContext; +import java.util.Collections; import java.util.stream.Collectors; /** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */ @@ -49,16 +50,31 @@ public PtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) { /** {@inheritDoc} */ @Override public NDList processInput(TranslatorContext ctx, NDList input) throws Exception { + // input = [inputIds, posIds, attnMask] NDManager manager = ctx.getNDManager(); if (input.size() == 3) { - ctx.setAttachment("initialCall", Boolean.TRUE); + // In this case, input has null pastKeyValues. We prefix-append a dummy pastKeyValues, + // which is treated as prefix padding, and set the corresponding attnMask to be zero. No + // need to shift the position ids. + ctx.setAttachment("useDummyPastKeyValues", Boolean.TRUE); + + // Pad the null pastKeyValues with dummy values + NDList pastKeyValues = initialDummyPastKeyValues(input.get(0), manager); + for (NDArray pkv : pastKeyValues) { + pkv.setName(tupleName); + input.add(pkv); + } + + // Append zero to the attentionMask from left, corresponding to the padding long batchSize = input.get(0).getShape().get(0); - NDArray attentionMask = input.get(2); - attentionMask = - manager.zeros(new Shape(batchSize, 1), DataType.INT64) - .concat(attentionMask, -1); + NDArray attentionMask = + manager.zeros(new Shape(batchSize, 1), DataType.INT64).concat(input.get(2), -1); input.set(2, attentionMask); - addInitialDummyPastKeyValues(input, input.get(0), manager); + } else { + for (int i = 3; i < numLayers * 2 + 3; ++i) { + NDArray pkv = input.get(i); + pkv.setName(tupleName); + } } return input; @@ -74,12 +90,12 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws if (output.size() > numLayers * 2 + 1) { hiddenStatesOutput = output.get(numLayers * 2 + 1); } else { - // If the traced_GPT2 model outputs hiddenStates, then this is not executed. If the - // provided traced model doesn't output hiddenStates, we can throw a warning here. + // Here is reached only if the language model doesn't output hiddenStates, which is + // needed only in contrastive search. We can also throw a warning here. hiddenStatesOutput = manager.zeros(new Shape(1)); } - if (ctx.getAttachment("initialCall") != null) { + if (ctx.getAttachment("useDummyPastKeyValues") != null) { NDIndex index2 = new NDIndex(":, :, 1:, ..."); pastKeyValuesOutput = new NDList( @@ -95,12 +111,11 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput); } - private void addInitialDummyPastKeyValues(NDList list, NDArray inputIds, NDManager manager) { + private NDList initialDummyPastKeyValues(NDArray inputIds, NDManager manager) { long numBatch = inputIds.getShape().get(0); - for (int i = 0; i < numLayers * 2; ++i) { - NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim)); - array.setName(tupleName); - list.add(array); - } + NDArray dummyKV = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim)); + NDList pastKeyValues = new NDList(); + pastKeyValues.addAll(Collections.nCopies(2 * numLayers, dummyKV)); + return pastKeyValues; } } diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java similarity index 94% rename from engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java rename to engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java index 800843e9567..68523be7802 100644 --- a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/TextGenerationTest.java +++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java @@ -34,10 +34,12 @@ import java.nio.file.Path; import java.nio.file.Paths; -public class TextGenerationTest { +public class GptTranslatorTest { @Test public void testGpt2() throws TranslateException, ModelException, IOException { + // This is a fake model that simulates language models like GPT2: NDList(inputIds, posIds, + // attnMask) -> NDList(logits(1), pastKv(12*2)[, hiddenStates(13)]) Block block = new LambdaBlock( a -> {