diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java b/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java index f7eedaf5533..df5520907cf 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java @@ -14,13 +14,12 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; - -// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration -// of the -// autoregressive loop. -// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains -// sequence dimension (whose position in tensor's shape is specified by seqDimOrder). -// The SeqBatcher batch operations will operate on these two dimensions. +/** + * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration + * of the autoregressive loop It is a struct consisting of NDArrays, whose first dimension is batch, + * and also contains sequence dimension (whose position in tensor's shape is specified by seqDimOrder). + * The SeqBatcher batch operations will operate on these two dimensions. + */ public abstract class BatchTensorList { // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. private NDArray pastOutputIds; @@ -39,6 +38,13 @@ public abstract class BatchTensorList { BatchTensorList() {} + /** + * Constructs a BatchTensorList. + * + * @param list the NDList that contains the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ BatchTensorList(NDList list, long[] seqDimOrder) { this.seqDimOrder = seqDimOrder; pastOutputIds = list.get(0); @@ -46,6 +52,15 @@ public abstract class BatchTensorList { pastKeyValues = list.subNDList(2); } + /** + * Constructs a BatchTensorList. + * + * @param pastOutputIds past output token ids + * @param pastAttentionMask past attention mask + * @param pastKeyValues past kv cache + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ BatchTensorList( NDArray pastOutputIds, NDArray pastAttentionMask, @@ -57,11 +72,32 @@ public abstract class BatchTensorList { this.seqDimOrder = seqDimOrder; } + /** + * Construct a BatchTensorList from the serialized version of the batch tensors. The + * pastOutputIds has to be the first in the output list. + * + * @param inputList the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + * @return BatchTensorList + */ public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); - // The pastOutputIds has to be the first in the output list + /** + * Gets the serialized version of the BatchTensorList. The pastOutputIds has to be the first in + * the output list. + * + * @return the NDList that contains the serialized BatchTensorList + */ public abstract NDList getList(); + /** + * Gets the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @return the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape + */ public long[] getSeqDimOrder() { return seqDimOrder; } @@ -75,6 +111,11 @@ public NDArray getPastOutputIds() { return pastOutputIds; } + /** + * Sets the past output token ids. + * + * @param pastOutputIds the past output token ids + */ public void setPastOutputIds(NDArray pastOutputIds) { this.pastOutputIds = pastOutputIds; } @@ -88,6 +129,11 @@ public NDArray getPastAttentionMask() { return pastAttentionMask; } + /** + * Sets the attention mask. + * + * @param pastAttentionMask the attention mask + */ public void setPastAttentionMask(NDArray pastAttentionMask) { this.pastAttentionMask = pastAttentionMask; } @@ -101,10 +147,22 @@ public NDList getPastKeyValues() { return pastKeyValues; } + /** + * Sets the kv cache. + * + * @param pastKeyValues the kv cache + */ public void setPastKeyValues(NDList pastKeyValues) { this.pastKeyValues = pastKeyValues; } + /** + * Sets the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension + * is in a tensor's shape + */ public void setSeqDimOrder(long[] seqDimOrder) { this.seqDimOrder = seqDimOrder; } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java b/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java index 1e72e9293f4..bcca99ecb41 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java @@ -32,11 +32,24 @@ public class CausalLMOutput { // The cache of past sequence. seq-dim-size == |seq_past| + |inputIds| private NDList pastKeyValuesList; + /** + * Construct the CausalLMOutput. + * + * @param logits the logits NDArray + * @param pastKeyValues the key-value cache + */ public CausalLMOutput(NDArray logits, NDList pastKeyValues) { this.logits = logits; this.pastKeyValuesList = pastKeyValues; } + /** + * Construct the CausalLMOutput. + * + * @param logits the logits NDArray + * @param hiddenState the first layer hiddenStates used as word embedding + * @param pastKeyValueList the key-value cache + */ public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) { this.logits = logits; this.pastKeyValuesList = pastKeyValueList; @@ -52,6 +65,11 @@ public NDArray getLogits() { return logits; } + /** + * Sets the value of the logits. + * + * @param logits value of logits NDArray + */ public void setLogits(NDArray logits) { this.logits = logits; } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveSeqBatchScheduler.java b/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveSeqBatchScheduler.java index 291f72c07d6..3a6709dc356 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveSeqBatchScheduler.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveSeqBatchScheduler.java @@ -25,13 +25,24 @@ import java.util.function.Function; import java.util.stream.Collectors; +/** + * {@code ContrastiveSeqBatchScheduler} is a class which implements the contrastive search algorithm + * used in SeqBatchScheduler. + */ public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler { + /** + * Construct a ContrastiveSeqBatchScheduler. + * + * @param lmBlock the predictor containing language model + * @param config the autoregressive search configuration + */ public ContrastiveSeqBatchScheduler( Predictor lmBlock, SearchConfig config) { super(lmBlock, config); } + /** {@inheritDoc} */ @Override public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException { try (NDScope scope = new NDScope()) { @@ -72,6 +83,7 @@ public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws Transl } } + /** {@inheritDoc} */ @Override public NDArray inferenceCall() throws TranslateException { NDArray outputIds; diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java b/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java index 8169f223c63..4f65b87a59b 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java @@ -12,6 +12,10 @@ */ package ai.djl.modality.nlp.generate; +/** + * {@code SearchConfig} is a class whose fields are parameters used for autoregressive search / text + * generation. + */ public class SearchConfig { private int k; @@ -30,7 +34,7 @@ public SearchConfig() { this.maxSeqLength = 30; this.eosTokenId = 50256; this.padTokenId = 50256; - this.suffixPadding = true; + this.suffixPadding = false; } /** @@ -42,6 +46,11 @@ public int getK() { return k; } + /** + * Sets the value for the topk choice. + * + * @param k the value for topk choice + */ public void setK(int k) { this.k = k; } @@ -55,6 +64,11 @@ public float getAlpha() { return alpha; } + /** + * Sets the value of alpha the penalty for repetition. + * + * @param alpha the value of the penalty for repetition + */ public void setAlpha(float alpha) { this.alpha = alpha; } @@ -68,6 +82,11 @@ public int getBeam() { return beam; } + /** + * Sets the value of beam size. + * + * @param beam the value of beam size + */ public void setBeam(int beam) { this.beam = beam; } @@ -81,6 +100,11 @@ public int getMaxSeqLength() { return maxSeqLength; } + /** + * Sets the value of max sequence length. + * + * @param maxSeqLength the value max sequence length + */ public void setMaxSeqLength(int maxSeqLength) { this.maxSeqLength = maxSeqLength; } @@ -94,6 +118,11 @@ public long getPadTokenId() { return padTokenId; } + /** + * Sets the value of padTokenId. + * + * @param padTokenId the token id for padding + */ public void setPadTokenId(long padTokenId) { this.padTokenId = padTokenId; } @@ -116,6 +145,11 @@ public boolean isSuffixPadding() { return suffixPadding; } + /** + * Sets the value of suffixPadding or rightPadding. + * + * @param suffixPadding whether the padding is from right + */ public void setSuffixPadding(boolean suffixPadding) { this.suffixPadding = suffixPadding; } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java index f87805ad34e..a6c3ad036c0 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatchScheduler.java @@ -27,11 +27,13 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -// This is a scheduler, serving as an API to the consumer of the system, allowing for three major -// actions: initForward, addBatch, fastForward, collectResults. -// An optimal control sequence should be solved, after considering the time consumption of each -// action, the batch size and sequence length of queueing requests. Such optimal control solver -// needs additional effort. Primitive policy is setting several thresholds. +/** + * This is a scheduler, serving as an API to the consumer of the system, allowing for three major + * actions: initForward, addBatch, fastForward, collectResults. An optimal control sequence should + * be solved, after considering the time consumption of each action, the batch size and sequence + * length of queueing requests. Such optimal control solver needs additional effort. Primitive + * policy is setting several thresholds. + */ public abstract class SeqBatchScheduler { private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class); @@ -44,6 +46,12 @@ public abstract class SeqBatchScheduler { Map results; + /** + * Constructor of seqBatchScheduler. + * + * @param lmBlock the predictor that cont + * @param config the search parameter configuration + */ public SeqBatchScheduler(Predictor lmBlock, SearchConfig config) { this.predictor = lmBlock; this.config = config; @@ -51,9 +59,12 @@ public SeqBatchScheduler(Predictor lmBlock, SearchConfig } /** - * Initialize the iteration and SeqBatcher + * Initialize the iteration and SeqBatcher. * - * @return SeqBatcher. Stores the search state and operate on the BatchTensorList. + * @param inputIds the input token ids. + * @param batchUids the request uid identifying a sequence + * @return SeqBatcher Stores the search state and operate on the BatchTensorList + * @throws TranslateException if forward fails */ public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException; @@ -61,7 +72,9 @@ public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) /** * Go forward for a given number of iterations. * - * @return boolean. Indicate whether the Batch is empty. + * @param count the time of forward calls + * @return boolean Indicate whether the Batch is empty + * @throws TranslateException if forward fails */ public boolean incrementForward(int count) throws TranslateException { int i = 0; @@ -82,9 +95,21 @@ public boolean incrementForward(int count) throws TranslateException { return false; } + /** + * An inference call in an iteration. + * + * @return the output token ids + * @throws TranslateException if forward fails + */ abstract NDArray inferenceCall() throws TranslateException; - /** Add new batch. */ + /** + * Add new batch. + * + * @param inputIds the input token ids. + * @param batchUids the request uid identifying a sequence + * @throws TranslateException if forward fails + */ public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException { SeqBatcher seqBatcherNew = initForward(inputIds, batchUids); if (seqBatcher == null) { @@ -94,13 +119,24 @@ public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateExce } } - /** Collect finished results. */ + /** + * Collect finished results. + * + * @return the outputs stored as a map from requestUid to output token ids + */ public Map collectResults() { Map output = results; results = new ConcurrentHashMap<>(); return output; } + /** + * Compute the offSets by linear search from the left. + * + * @param inputIds input token ids + * @param config search configuration + * @return the offsets NDArray + */ static NDArray computeOffSets(NDArray inputIds, SearchConfig config) { int numBatch = Math.toIntExact(inputIds.getShape().get(0)); int initSeqSize = Math.toIntExact(inputIds.getShape().get(1)); @@ -123,6 +159,13 @@ static NDArray computeOffSets(NDArray inputIds, SearchConfig config) { return manager.create(offSetsArray).reshape(-1, 1); } + /** + * Compute the attention mask by linear search from the left. + * + * @param inputIds input token ids + * @param config search configuration + * @return the attention mask NDArray + */ static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) { int numBatch = Math.toIntExact(inputIds.getShape().get(0)); int initSeqSize = Math.toIntExact(inputIds.getShape().get(1)); @@ -150,6 +193,16 @@ static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) { return attentionMask; } + /** + * Compute the position ids by linear search from the left. + * + * @param inputIds input token ids + * @param offSets the offset + * @param pastSeqLength past sequence length + * @param repeat the number of repeats used in interleave-repeating the position_ids to multiple + * rows + * @return the position ids NDArray + */ static NDArray computePositionIds( NDArray inputIds, NDArray offSets, long pastSeqLength, int repeat) { NDManager manager = inputIds.getManager(); diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java index 04d8ae0e947..747c81ea9c5 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java @@ -24,8 +24,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets, -// etc), and batch operations (merge, trim, exitCriteria, etc) on BatchTensorList. +/** + * {@code SeqBatcher} stores the search state (BatchTensorList), the control variables (e.g. + * seqLength, offSets, etc), and batch operations (merge, trim, exitCriteria, etc) on + * BatchTensorList. + */ public class SeqBatcher { NDManager manager; @@ -56,17 +59,32 @@ public class SeqBatcher { exitIndexEndPosition = new ConcurrentHashMap<>(); } + /** + * Get the batch data which is stored as a {@code BatchTensorList}. + * + * @return the batch data stored as BatchTensorList + */ public BatchTensorList getData() { return data; } - /** Add new batch. Modify the batch dimension and the left padding. */ + /** + * Add new batch. Modify the batch dimension and the left padding. + * + * @param seqBatcherNew the seqBatcher to add. + */ public void addBatch(SeqBatcher seqBatcherNew) { merge(this, seqBatcherNew, seqLength - seqBatcherNew.seqLength); // manager and finishedSequences stay the same; } - /** Merge two batchers together. Modify the batch dimension and the left padding. */ + /** + * Merge two batchers together. Modify the batch dimension and the left padding. + * + * @param seqBatcher1 the first seqBatcher + * @param seqBatcher2 the second seqBatcher + * @param seqDelta the sequence length difference + */ private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) { if (seqDelta < 0) { SeqBatcher swapTmp = seqBatcher1; @@ -149,6 +167,10 @@ private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta /** * Check which batch needs to exit, according certain criteria like EOS or maxLength. It is an * iteration over batch and is thus also considered as batch operation. + * + * @param outputIds output token ids in an incremental forward call + * @param maxLength max total sequence length + * @param eosTokenId end of sentence token id */ public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) { long[] outputIdsArray = outputIds.toLongArray(); @@ -162,7 +184,11 @@ public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) { } } - /** Collect the finished sequences and trim the left padding. */ + /** + * Collect the finished sequences and trim the left padding. + * + * @return a map that stores request id to output token ids + */ public Map collectAndTrim() { if (exitIndexEndPosition.isEmpty()) { return new ConcurrentHashMap<>(); @@ -261,6 +287,11 @@ public Map collectAndTrim() { } } + /** + * Compute the position ids by linear search from the left. + * + * @return the boolean indicating whether all sequences are empty + */ public boolean sequenceComplete() { return !exitIndexEndPosition.isEmpty(); } diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java b/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java index bfafa822eb5..7c8f4062fc9 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java @@ -18,10 +18,25 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +/** + * {@code StepGeneration} is a utility class containing the step generation utility functions used + * in autoregressive search. + */ public final class StepGeneration { private StepGeneration() {} + /** + * Generate the output token id and selecting indices used in contrastive search. + * + * @param topKIds the topk candidate token ids + * @param logits the logits from the language model + * @param contextHiddenStates the embedding of the past generated token ids + * @param topkHiddenStates the embedding of the topk candidate token ids + * @param offSets the offsets + * @param alpha the repetition penalty + * @return the output token ids and selecting indices + */ public static NDList constrastiveStepGeneration( NDArray topKIds, NDArray logits, @@ -80,6 +95,12 @@ public static NDList constrastiveStepGeneration( // b = torch.randn(batch, topK, dim) // result = torch.einsum('bik,bjk->bij', a, b) + /** + * Generate the output token id for greedy search. + * + * @param logits the logits from the language model + * @return the output token ids + */ public static NDArray greedyStepGen(NDArray logits) { // logits: [batch, seq, probDim] assert logits.getShape().getShape().length == 3 : "unexpected input"; @@ -87,6 +108,15 @@ public static NDArray greedyStepGen(NDArray logits) { return logits.argMax(-1).expandDims(1); // [batch, vacDim] } + /** + * Generate the output token id and selecting indices used in beam search. + * + * @param lastProbs the probabilities of the past prefix sequences + * @param logits the logits + * @param numBatch number of batch + * @param numBeam number of beam + * @return the output token ids and selecting indices + */ public static NDList beamStepGeneration( NDArray lastProbs, NDArray logits, long numBatch, long numBeam) { // [batch * beamSource, seq, probDim] -> [batch, beamSource, probDim] diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java index e2bc0a7b33f..3d2c437f463 100644 --- a/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java +++ b/api/src/main/java/ai/djl/modality/nlp/generate/TextGenerator.java @@ -25,6 +25,11 @@ import java.util.function.Function; import java.util.stream.Collectors; +/** + * {@code TextGenerator} is an LMSearch (language model search) which contains multiple + * autoregressive search methods. It has a Predictor from NDList to CausalLMOutput, which is called + * inside an autoregressive inference loop. + */ public class TextGenerator { private String searchName; @@ -33,6 +38,13 @@ public class TextGenerator { private NDArray positionOffset; + /** + * Construct a text generator. + * + * @param predictor the language model + * @param searchName the autoregressive search name + * @param searchConfig the autoregressive search configuration + */ public TextGenerator( Predictor predictor, String searchName, @@ -42,6 +54,13 @@ public TextGenerator( this.config = searchConfig; } + /** + * Greedy search. + * + * @param inputIds the input token ids. + * @return the output token ids stored as NDArray + * @throws TranslateException if forward fails + */ @SuppressWarnings("try") public NDArray greedySearch(NDArray inputIds) throws TranslateException { NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); @@ -513,6 +532,13 @@ private NDList prepareInput( return new NDList(inputIds, positionIds, attentionMask); } + /** + * Forward function call to generate text. + * + * @param inputIds the input token ids + * @return generated token ids + * @throws TranslateException if prediction fails + */ public NDArray forward(NDArray inputIds) throws TranslateException { switch (searchName) { case "greedy":