
POC of LLMDecoder on both Pt trace model and Onnx model (LMAdapter part) #2509

Closed
KexinFeng wants to merge 10 commits

Conversation


KexinFeng (Contributor) commented on Apr 7, 2023

Note: the interface StepGenerator will be renamed to LMAdapter.

Description

This is a POC showing that GPT-2 can be traced and exported into DJL for step-by-step inference. Moreover, this step inference supports the following:

  1. batch sequence input
  2. cached past_key_values input, which is a Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], ...] in Python (see the flattening sketch below).
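
For illustration, here is a minimal sketch (not this PR's code) of how that nested tuple flattens into a DJL NDList; the layer count, head count, and head dimension below are the standard GPT-2 small values:

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class KvCacheSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            int batch = 2, numLayers = 12, numHeads = 12, pastSeqLen = 5, headDim = 64;
            // Each layer contributes a (key, value) pair of shape
            // (batch, numHeads, pastSeqLen, headDim); flattening the Python
            // Tuple[Tuple[Tensor, Tensor], ...] yields 2 * numLayers arrays.
            NDList pastKeyValues = new NDList();
            for (int layer = 0; layer < numLayers; layer++) {
                pastKeyValues.add(manager.zeros(new Shape(batch, numHeads, pastSeqLen, headDim)));
                pastKeyValues.add(manager.zeros(new Shape(batch, numHeads, pastSeqLen, headDim)));
            }
            System.out.println(pastKeyValues.size()); // 24 for GPT-2 small
        }
    }
}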

In this POC, both GPT2_init.pt and GPT2.pt are used, since they take different inputs: the former has no past_key_values input.

Run TextGeneration.java to test it.

Design

The new classes are the following:

public class JavaDecoder {
    private StepGenerator generator;    // wraps the underlying language model
    private DecodeParam decodingParams; // e.g. max length, search strategy

    public Text generateText() { ... }  // runs the token-by-token inference loop
}

JavaDecoder is engine-agnostic. The inference loop can be implemented inside generateText(), along the lines of the sketch below.
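
A minimal sketch of that loop with greedy search, in the same pseudocode register as above (the Token accessors and helper methods here are hypothetical):

public Text generateText() {
    // Prefill: run the whole prompt once; there is no cache yet.
    Token token = generator.stepGen(inputIds, positionIds, attentionMask, null);
    while (!token.isEos() && outputLength() < decodingParams.getMaxLength()) {
        appendToOutput(token);
        // Incremental step: feed only the newly generated token; the cached
        // past_key_values carry the attention state of the whole prefix.
        token = generator.stepGen(token.getId(), nextPositionId(), extendAttentionMask(), token.getPastKeyValues());
    }
    return buildText();
}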

public interface StepGenerator {
    // Each implementation holds a modelUrl, used to load its model file.

    Token stepGen(inputIds, positionIds, attentionMask, pastKeyValues);
    // One decoding step; it evaluates P(w_n | w_{n-1}, w_{n-2}, ...)
}

This interface is a wrapper over the model files from different sources, e.g. gpt2.pt, gpt2.onnx, etc.
It can be seen as a Java abstraction of a causal language model, i.e. the conditional probability p_\theta(v | x_1, ..., x_{t-1}): given the past tokens x_{<t}, it gives the probability that the next token x_t equals v, for v taken from a vocabulary set V. \theta denotes the model weights.
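
For reference, stepGen evaluates one factor of the standard autoregressive factorization

p_\theta(x_1, ..., x_T) = \prod_{t=1}^{T} p_\theta(x_t | x_{<t})

so generating a sequence is one stepGen call per position t. Caching past_key_values means the attention state over x_{<t} is computed once and reused at the next step, instead of being recomputed from the full prefix every time.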

It will be implemented separately for each backend, covering the following scenarios (a sketch of case 1 follows the list):

  1. Traced GPT2.pt + the PyTorch JNI. Its POC is done.
  2. GPT2.onnx + OnnxEngine. This will be the use case corresponding to the graph above.
  3. GPT2 + neuronx
  4. FasterTransformer + TensorRT engine
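
For a rough idea of case 1, here is a minimal sketch of a PyTorch-backed implementation built on DJL's standard Criteria API. The class name and NDList-based signature are hypothetical; the actual POC goes through the PyTorch JNI with IValue to keep the nested tuple structure, whereas this sketch simply flattens the cache:

import java.io.IOException;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;

public class PtStepGenerator implements AutoCloseable {

    private final ZooModel<NDList, NDList> model;
    private final Predictor<NDList, NDList> predictor;

    public PtStepGenerator(String modelUrl)
            throws ModelNotFoundException, MalformedModelException, IOException {
        Criteria<NDList, NDList> criteria = Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrl) // e.g. the traced GPT2.pt
                .optEngine("PyTorch")
                .optTranslator(new NoopTranslator())
                .build();
        model = criteria.loadModel();
        predictor = model.newPredictor();
    }

    // One decoding step: token ids and the flattened KV cache in,
    // logits plus the updated cache out.
    public NDList stepGen(NDArray inputIds, NDArray positionIds,
            NDArray attentionMask, NDList pastKeyValues) throws TranslateException {
        NDList inputs = new NDList(inputIds, positionIds, attentionMask);
        if (pastKeyValues != null) {
            inputs.addAll(pastKeyValues);
        }
        return predictor.predict(inputs);
    }

    @Override
    public void close() {
        predictor.close();
        model.close();
    }
}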

Model tracing

The ONNX model gpt2.onnx is obtained following the guide at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-using-past-keysvalues-in-the-decoder.
See also https://github.com/huggingface/optimum/releases.

The gpt2.pt is traced with the following scripts: https://gist.github.com/KexinFeng/4876c6bfb27f40abffe4d5a92c02acff

KexinFeng requested review from zachgk, frankfliu and a team as code owners on April 7, 2023 00:49
KexinFeng marked this pull request as draft on April 7, 2023 00:53
KexinFeng force-pushed the LLMDecoder branch 2 times, most recently from fe1aa0a to a74696a, on April 7, 2023 13:29
KexinFeng changed the title from "POC of LLMDecoder" to "POC of LLMDecoder on both Pt trace model and Onnx model" on Apr 14, 2023
KexinFeng changed the title to "POC of LLMDecoder on both Pt trace model and Onnx model (stepGenerator part)" on Apr 14, 2023
KexinFeng changed the title to "POC of LLMDecoder on both Pt trace model and Onnx model (LMAdapter part)" on Apr 14, 2023
KexinFeng mentioned this pull request on Apr 17, 2023
KexinFeng marked this pull request as ready for review on April 17, 2023 17:04
Review comment on the LMAdapter declaration:

 * range(|inputIds|). This means for each i, the output probability is conditional on the past
 * sequence up to i.
 */
public interface LMAdapter extends AutoCloseable {
Reviewer (Contributor):

Per our discussion, see if you can remake the LMAdapter to be a type of Block. Then, you can load it using Model.load() rather than requiring the special handling in Engine.

KexinFeng (Author):

The type change to Block is done, but the special handling in Engine is not avoided: GPT2PtLMBlock depends on the module pytorch-engines.main and cannot be used in the engine-agnostic frontend. GPT2PtLMBlock is engine-specific because it adapts to a particular engine and uses engine-specific types, like IValue for PyTorch.

Review comment on the example code:

    mainPt(args);
}

public static void mainOnnx(String[] args) {
Reviewer (Contributor):

I'm noticing that these examples are almost identical. Is it possible to refactor out a common helper, so that each one would just initialize the LMAdapter and then pass it into the helper? Maybe the helper would look something like public static void generateInternal(LMAdapter generator);
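
For concreteness, such a helper might be shaped like this (a sketch only; the inputs are arbitrary placeholders):

public static void generateInternal(LMAdapter generator) throws TranslateException {
    try (NDManager manager = NDManager.newBaseManager()) {
        // Per-case test inputs (the values here are placeholders).
        NDArray inputIds = manager.create(new long[] {40, 2883, 6155, 351}, new Shape(1, 4));
        NDArray positionIds = manager.arange(4).reshape(1, 4);
        NDArray attentionMask = manager.ones(new Shape(1, 4));
        // The shared decoding loop over generator.stepGen(...) would live here,
        // identical for the Pt-backed and Onnx-backed LMAdapter.
    }
}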

KexinFeng (Author), May 13, 2023:

The test data used here, like inputIds, positionIds, and attentionMask, are case-by-case, like the test examples in NDIndexTest.java, and are not likely to be reused. The remaining common parts have been factored out.

KexinFeng (Author):

See the latest PR for the updates.

KexinFeng (Author):

Merged in a different PR.
