Skip to content

Commit

Permalink
[api] implements text-generation search algorithm (#2637)
Browse files Browse the repository at this point in the history
Co-authored-by: KexinFeng <fenkexin@amazon.com>
Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
3 people authored Jun 27, 2023
1 parent c6a609f commit 68c7a03
Show file tree
Hide file tree
Showing 22 changed files with 2,701 additions and 0 deletions.
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public static Application of(String path) {
return NLP.TOKEN_CLASSIFICATION;
case "nlp/word_embedding":
return NLP.WORD_EMBEDDING;
case "nlp/text_generation":
return NLP.TEXT_GENERATION;
case "tabular":
return Tabular.ANY;
case "tabular/linear_regression":
Expand Down Expand Up @@ -261,6 +263,8 @@ public interface NLP {
*/
Application WORD_EMBEDDING = new Application("nlp/word_embedding");

Application TEXT_GENERATION = new Application("nlp/text_generation");

/**
* An application that translates text from one language to another.
*
Expand Down
174 changes: 174 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
* BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
* of the autoregressive loop.
*
* <p>It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
* sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher
* batch operations will operate on these two dimensions.
*/
public abstract class BatchTensorList {
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastOutputIds;

// [batch, seq_past]
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastAttentionMask;

// (k, v) * numLayer,
// kv: [batch, heads, seq_past, kvfeature]
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDList pastKeyValues;

// Sequence dimension order among all dimensions for each element in the batch list.
private long[] seqDimOrder;

BatchTensorList() {}

/**
* Constructs a new {@code BatchTensorList} instance.
*
* @param list the NDList that contains the serialized version of the batch tensors
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
*/
BatchTensorList(NDList list, long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
pastOutputIds = list.get(0);
pastAttentionMask = list.get(1);
pastKeyValues = list.subNDList(2);
}

/**
* Constructs a new {@code BatchTensorList} instance.
*
* @param pastOutputIds past output token ids
* @param pastAttentionMask past attention mask
* @param pastKeyValues past kv cache
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
*/
BatchTensorList(
NDArray pastOutputIds,
NDArray pastAttentionMask,
NDList pastKeyValues,
long[] seqDimOrder) {
this.pastKeyValues = pastKeyValues;
this.pastOutputIds = pastOutputIds;
this.pastAttentionMask = pastAttentionMask;
this.seqDimOrder = seqDimOrder;
}

/**
* Constructs a new {@code BatchTensorList} instance from the serialized version of the batch
* tensors.
*
* <p>The pastOutputIds has to be the first in the output list.
*
* @param inputList the serialized version of the batch tensors
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension
* is in a tensor's shape
* @return BatchTensorList
*/
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

/**
* Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first
* in the output list.
*
* @return the NDList that contains the serialized BatchTensorList
*/
public abstract NDList getList();

/**
* Returns the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape.
*
* @return the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape
*/
public long[] getSeqDimOrder() {
return seqDimOrder;
}

/**
* Returns the value of the pastOutputIds.
*
* @return the value of pastOutputIds
*/
public NDArray getPastOutputIds() {
return pastOutputIds;
}

/**
* Sets the past output token ids.
*
* @param pastOutputIds the past output token ids
*/
public void setPastOutputIds(NDArray pastOutputIds) {
this.pastOutputIds = pastOutputIds;
}

/**
* Returns the value of the pastAttentionMask.
*
* @return the value of pastAttentionMask
*/
public NDArray getPastAttentionMask() {
return pastAttentionMask;
}

/**
* Sets the attention mask.
*
* @param pastAttentionMask the attention mask
*/
public void setPastAttentionMask(NDArray pastAttentionMask) {
this.pastAttentionMask = pastAttentionMask;
}

/**
* Returns the value of the pastKeyValues.
*
* @return the value of pastKeyValues
*/
public NDList getPastKeyValues() {
return pastKeyValues;
}

/**
* Sets the kv cache.
*
* @param pastKeyValues the kv cache
*/
public void setPastKeyValues(NDList pastKeyValues) {
this.pastKeyValues = pastKeyValues;
}

/**
* Sets the sequence dimension order which specifies where the sequence dimension is in a
* tensor's shape.
*
* @param seqDimOrder the sequence dimension order which specifies where the sequence dimension
* is in a tensor's shape
*/
public void setSeqDimOrder(long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

class BeamBatchTensorList extends BatchTensorList {

// [batch, beam, seq=1]
private NDArray nextInputIds;

// [batch, beam]
private NDArray lastProbs;

// [batch, beam, seq_past + new_seq]
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastAttentionMask;

/* Variables below are one time step behind the above state variables. Ie, they contain all the past sequence but excludes the time step that corresponds to the above input. */

// [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastOutputIds;

// (k, v) * numLayer,
// kv: [batch, beam, heads, seq_past, kvfeature]
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDList pastKeyValues;

BeamBatchTensorList() {}

BeamBatchTensorList(
NDArray nextInputIds,
NDArray pastOutputIds,
NDList pastKeyValues,
NDArray pastAttentionMask,
NDArray lastProb) {
this.nextInputIds = nextInputIds;
this.pastKeyValues = pastKeyValues;
this.pastOutputIds = pastOutputIds;
this.pastAttentionMask = pastAttentionMask;
this.lastProbs = lastProb;
}

/** {@inheritDoc} */
@Override
public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
return new BeamBatchTensorList();
}

/** {@inheritDoc} */
@Override
public NDList getList() {
return new NDList();
}

/**
* Returns the value of the nextInputIds.
*
* @return the value of nextInputIds
*/
public NDArray getNextInputIds() {
return nextInputIds;
}

public void setNextInputIds(NDArray nextInputIds) {
this.nextInputIds = nextInputIds;
}

/**
* Returns the value of the lastProbs.
*
* @return the value of lastProbs
*/
public NDArray getLastProbs() {
return lastProbs;
}

public void setLastProbs(NDArray lastProbs) {
this.lastProbs = lastProbs;
}

/** {@inheritDoc} */
@Override
public NDArray getPastAttentionMask() {
return pastAttentionMask;
}

/** {@inheritDoc} */
@Override
public void setPastAttentionMask(NDArray pastAttentionMask) {
this.pastAttentionMask = pastAttentionMask;
}

/** {@inheritDoc} */
@Override
public NDArray getPastOutputIds() {
return pastOutputIds;
}

/** {@inheritDoc} */
@Override
public void setPastOutputIds(NDArray pastOutputIds) {
this.pastOutputIds = pastOutputIds;
}

/** {@inheritDoc} */
@Override
public NDList getPastKeyValues() {
return pastKeyValues;
}

/** {@inheritDoc} */
@Override
public void setPastKeyValues(NDList pastKeyValues) {
this.pastKeyValues = pastKeyValues;
}
}
Loading

0 comments on commit 68c7a03

Please sign in to comment.