Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Set past_kv name for corner case. #2722

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

import java.util.Collections;
import java.util.stream.Collectors;

/** The {@link ai.djl.translate.Translator} for PyTorch GPT2 model. */
Expand Down Expand Up @@ -49,16 +50,31 @@ public PtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
// input = [inputIds, posIds, attnMask]
NDManager manager = ctx.getNDManager();
if (input.size() == 3) {
ctx.setAttachment("initialCall", Boolean.TRUE);
// In this case, input has null pastKeyValues. We prefix-append a dummy pastKeyValues,
// which is treated as prefix padding, and set the corresponding attnMask to be zero. No
// need to shift the position ids.
ctx.setAttachment("useDummyPastKeyValues", Boolean.TRUE);

// Pad the null pastKeyValues with dummy values
NDList pastKeyValues = initialDummyPastKeyValues(input.get(0), manager);
for (NDArray pkv : pastKeyValues) {
pkv.setName(tupleName);
input.add(pkv);
}

// Append zero to the attentionMask from left, corresponding to the padding
long batchSize = input.get(0).getShape().get(0);
NDArray attentionMask = input.get(2);
attentionMask =
manager.zeros(new Shape(batchSize, 1), DataType.INT64)
.concat(attentionMask, -1);
NDArray attentionMask =
manager.zeros(new Shape(batchSize, 1), DataType.INT64).concat(input.get(2), -1);
input.set(2, attentionMask);
addInitialDummyPastKeyValues(input, input.get(0), manager);
} else {
for (int i = 3; i < numLayers * 2 + 3; ++i) {
NDArray pkv = input.get(i);
pkv.setName(tupleName);
}
}

return input;
Expand All @@ -74,12 +90,12 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws
if (output.size() > numLayers * 2 + 1) {
hiddenStatesOutput = output.get(numLayers * 2 + 1);
} else {
// If the traced_GPT2 model outputs hiddenStates, then this is not executed. If the
// provided traced model doesn't output hiddenStates, we can throw a warning here.
// Here is reached only if the language model doesn't output hiddenStates, which is
// needed only in contrastive search. We can also throw a warning here.
hiddenStatesOutput = manager.zeros(new Shape(1));
}

if (ctx.getAttachment("initialCall") != null) {
if (ctx.getAttachment("useDummyPastKeyValues") != null) {
NDIndex index2 = new NDIndex(":, :, 1:, ...");
pastKeyValuesOutput =
new NDList(
Expand All @@ -95,12 +111,11 @@ public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws
return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput);
}

private void addInitialDummyPastKeyValues(NDList list, NDArray inputIds, NDManager manager) {
private NDList initialDummyPastKeyValues(NDArray inputIds, NDManager manager) {
long numBatch = inputIds.getShape().get(0);
for (int i = 0; i < numLayers * 2; ++i) {
NDArray array = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
array.setName(tupleName);
list.add(array);
}
NDArray dummyKV = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim));
NDList pastKeyValues = new NDList();
pastKeyValues.addAll(Collections.nCopies(2 * numLayers, dummyKV));
return pastKeyValues;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
import java.nio.file.Path;
import java.nio.file.Paths;

public class TextGenerationTest {
public class GptTranslatorTest {

@Test
public void testGpt2() throws TranslateException, ModelException, IOException {
// This is a fake model that simulates language models like GPT2: NDList(inputIds, posIds,
// attnMask) -> NDList(logits(1), pastKv(12*2)[, hiddenStates(13)])
Block block =
new LambdaBlock(
a -> {
Expand Down