-
Notifications
You must be signed in to change notification settings - Fork 688
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[api] implements text-generation search algorithm (#2637)
Co-authored-by: KexinFeng <fenkexin@amazon.com> Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
- Loading branch information
1 parent
c6a609f
commit 68c7a03
Showing
22 changed files
with
2,701 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
174 changes: 174 additions & 0 deletions
174
api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
/** | ||
* BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration | ||
* of the autoregressive loop. | ||
* | ||
* <p>It is a struct consisting of NDArrays, whose first dimension is batch, and also contains | ||
* sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher | ||
* batch operations will operate on these two dimensions. | ||
*/ | ||
public abstract class BatchTensorList { | ||
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastOutputIds; | ||
|
||
// [batch, seq_past] | ||
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastAttentionMask; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, heads, seq_past, kvfeature] | ||
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDList pastKeyValues; | ||
|
||
// Sequence dimension order among all dimensions for each element in the batch list. | ||
private long[] seqDimOrder; | ||
|
||
BatchTensorList() {} | ||
|
||
/** | ||
* Constructs a new {@code BatchTensorList} instance. | ||
* | ||
* @param list the NDList that contains the serialized version of the batch tensors | ||
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension | ||
* is in a tensor's shape | ||
*/ | ||
BatchTensorList(NDList list, long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
pastOutputIds = list.get(0); | ||
pastAttentionMask = list.get(1); | ||
pastKeyValues = list.subNDList(2); | ||
} | ||
|
||
/** | ||
* Constructs a new {@code BatchTensorList} instance. | ||
* | ||
* @param pastOutputIds past output token ids | ||
* @param pastAttentionMask past attention mask | ||
* @param pastKeyValues past kv cache | ||
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension | ||
* is in a tensor's shape | ||
*/ | ||
BatchTensorList( | ||
NDArray pastOutputIds, | ||
NDArray pastAttentionMask, | ||
NDList pastKeyValues, | ||
long[] seqDimOrder) { | ||
this.pastKeyValues = pastKeyValues; | ||
this.pastOutputIds = pastOutputIds; | ||
this.pastAttentionMask = pastAttentionMask; | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
|
||
/** | ||
* Constructs a new {@code BatchTensorList} instance from the serialized version of the batch | ||
* tensors. | ||
* | ||
* <p>The pastOutputIds has to be the first in the output list. | ||
* | ||
* @param inputList the serialized version of the batch tensors | ||
* @param seqDimOrder the sequence dimension order that specifies where the sequence dimension | ||
* is in a tensor's shape | ||
* @return BatchTensorList | ||
*/ | ||
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); | ||
|
||
/** | ||
* Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first | ||
* in the output list. | ||
* | ||
* @return the NDList that contains the serialized BatchTensorList | ||
*/ | ||
public abstract NDList getList(); | ||
|
||
/** | ||
* Returns the sequence dimension order which specifies where the sequence dimension is in a | ||
* tensor's shape. | ||
* | ||
* @return the sequence dimension order which specifies where the sequence dimension is in a | ||
* tensor's shape | ||
*/ | ||
public long[] getSeqDimOrder() { | ||
return seqDimOrder; | ||
} | ||
|
||
/** | ||
* Returns the value of the pastOutputIds. | ||
* | ||
* @return the value of pastOutputIds | ||
*/ | ||
public NDArray getPastOutputIds() { | ||
return pastOutputIds; | ||
} | ||
|
||
/** | ||
* Sets the past output token ids. | ||
* | ||
* @param pastOutputIds the past output token ids | ||
*/ | ||
public void setPastOutputIds(NDArray pastOutputIds) { | ||
this.pastOutputIds = pastOutputIds; | ||
} | ||
|
||
/** | ||
* Returns the value of the pastAttentionMask. | ||
* | ||
* @return the value of pastAttentionMask | ||
*/ | ||
public NDArray getPastAttentionMask() { | ||
return pastAttentionMask; | ||
} | ||
|
||
/** | ||
* Sets the attention mask. | ||
* | ||
* @param pastAttentionMask the attention mask | ||
*/ | ||
public void setPastAttentionMask(NDArray pastAttentionMask) { | ||
this.pastAttentionMask = pastAttentionMask; | ||
} | ||
|
||
/** | ||
* Returns the value of the pastKeyValues. | ||
* | ||
* @return the value of pastKeyValues | ||
*/ | ||
public NDList getPastKeyValues() { | ||
return pastKeyValues; | ||
} | ||
|
||
/** | ||
* Sets the kv cache. | ||
* | ||
* @param pastKeyValues the kv cache | ||
*/ | ||
public void setPastKeyValues(NDList pastKeyValues) { | ||
this.pastKeyValues = pastKeyValues; | ||
} | ||
|
||
/** | ||
* Sets the sequence dimension order which specifies where the sequence dimension is in a | ||
* tensor's shape. | ||
* | ||
* @param seqDimOrder the sequence dimension order which specifies where the sequence dimension | ||
* is in a tensor's shape | ||
*/ | ||
public void setSeqDimOrder(long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
} |
128 changes: 128 additions & 0 deletions
128
api/src/main/java/ai/djl/modality/nlp/generate/BeamBatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
class BeamBatchTensorList extends BatchTensorList { | ||
|
||
// [batch, beam, seq=1] | ||
private NDArray nextInputIds; | ||
|
||
// [batch, beam] | ||
private NDArray lastProbs; | ||
|
||
// [batch, beam, seq_past + new_seq] | ||
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastAttentionMask; | ||
|
||
/* Variables below are one time step behind the above state variables. Ie, they contain all the past sequence but excludes the time step that corresponds to the above input. */ | ||
|
||
// [batch, beam, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastOutputIds; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, beam, heads, seq_past, kvfeature] | ||
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDList pastKeyValues; | ||
|
||
BeamBatchTensorList() {} | ||
|
||
BeamBatchTensorList( | ||
NDArray nextInputIds, | ||
NDArray pastOutputIds, | ||
NDList pastKeyValues, | ||
NDArray pastAttentionMask, | ||
NDArray lastProb) { | ||
this.nextInputIds = nextInputIds; | ||
this.pastKeyValues = pastKeyValues; | ||
this.pastOutputIds = pastOutputIds; | ||
this.pastAttentionMask = pastAttentionMask; | ||
this.lastProbs = lastProb; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) { | ||
return new BeamBatchTensorList(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDList getList() { | ||
return new NDList(); | ||
} | ||
|
||
/** | ||
* Returns the value of the nextInputIds. | ||
* | ||
* @return the value of nextInputIds | ||
*/ | ||
public NDArray getNextInputIds() { | ||
return nextInputIds; | ||
} | ||
|
||
public void setNextInputIds(NDArray nextInputIds) { | ||
this.nextInputIds = nextInputIds; | ||
} | ||
|
||
/** | ||
* Returns the value of the lastProbs. | ||
* | ||
* @return the value of lastProbs | ||
*/ | ||
public NDArray getLastProbs() { | ||
return lastProbs; | ||
} | ||
|
||
public void setLastProbs(NDArray lastProbs) { | ||
this.lastProbs = lastProbs; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPastAttentionMask() { | ||
return pastAttentionMask; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void setPastAttentionMask(NDArray pastAttentionMask) { | ||
this.pastAttentionMask = pastAttentionMask; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPastOutputIds() { | ||
return pastOutputIds; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void setPastOutputIds(NDArray pastOutputIds) { | ||
this.pastOutputIds = pastOutputIds; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDList getPastKeyValues() { | ||
return pastKeyValues; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void setPastKeyValues(NDList pastKeyValues) { | ||
this.pastKeyValues = pastKeyValues; | ||
} | ||
} |
Oops, something went wrong.