POC of LLMDecoder on both Pt trace model and Onnx model (LMAdapter part) #2509
Conversation
 * range(|inputIds|). This means for each i, the output probability is conditional on the past
 * sequence up to i.
 */
public interface LMAdapter extends AutoCloseable {
Per our discussion, see if you can remake the LMAdapter to be a type of Block. Then, you can load it using Model.load() rather than requiring the special handling in Engine.
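A minimal sketch of what that could look like, assuming a hypothetical local path and model name (not taken from this PR):

```java
import ai.djl.Model;
import ai.djl.nn.Block;

import java.nio.file.Paths;

// Hypothetical illustration of the suggestion above: once LMAdapter is a
// Block, the traced model can be loaded through the generic Model API
// instead of engine-internal handling. Path and name are assumptions.
public final class BlockLoadingSketch {
    public static void main(String[] args) throws Exception {
        try (Model model = Model.newInstance("gpt2")) {
            model.load(Paths.get("build/models/gpt2"), "gpt2"); // engine resolves gpt2.pt
            Block lmBlock = model.getBlock(); // the LMAdapter-as-Block
            // lmBlock.forward(...) would then run a single decoding step
        }
    }
}
```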
The type change to Block is done, but the special handling in Engine is not avoided. GPT2PtLMBlock depends on the module pytorch-engines.main and cannot be used in an engine-agnostic frontend. GPT2PtLMBlock is engine specific because it adapts to a particular engine and uses engine-specific types, like IValue for PyTorch.
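To make the coupling concrete, here is an illustrative fragment (not taken from GPT2PtLMBlock) of what such engine-specific code looks like:

```java
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.jni.IValue;

// Illustrative only: packing one attention layer's (key, value) cache as an
// IValue tuple. PtNDArray and IValue live in the PyTorch engine module, so
// any Block written against them cannot be engine agnostic.
final class IValuePackingSketch {
    static IValue packLayerPast(PtNDArray key, PtNDArray value) {
        return IValue.tupleFrom(IValue.from(key), IValue.from(value));
    }
}
```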
engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java (review thread resolved)
    mainPt(args);
}

public static void mainOnnx(String[] args) {
I'm noticing that these examples are almost identical. Is it possible to refactor out a common helper? Each one would then just initialize the LMAdapter and pass it into the helper. Maybe the helper would look something like public static void generateInternal(LMAdapter generator);
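A sketch of that shape, where newPtAdapter/newOnnxAdapter are hypothetical stand-ins for the PR's construction code and LMAdapter is the interface from this PR:

```java
// Sketch of the proposed refactor: each entry point only builds its
// engine-specific LMAdapter, then hands it to a shared generation loop.
public final class TextGenerationSketch {
    public static void mainPt(String[] args) throws Exception {
        try (LMAdapter generator = newPtAdapter()) {
            generateInternal(generator);
        }
    }

    public static void mainOnnx(String[] args) throws Exception {
        try (LMAdapter generator = newOnnxAdapter()) {
            generateInternal(generator);
        }
    }

    // Shared loop; depends only on the engine-agnostic interface.
    static void generateInternal(LMAdapter generator) {
        // greedy decoding over the generator's step calls would go here
    }

    static LMAdapter newPtAdapter() { throw new UnsupportedOperationException("sketch"); }
    static LMAdapter newOnnxAdapter() { throw new UnsupportedOperationException("sketch"); }
}
```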
The test data used here, like inputIds, positionIds, and attentionMask, are case-specific, like the test examples in NDIndexTest.java, and are unlikely to be reused. The remaining common parts have been factored out.
See the latest PR for the updates.
Merged in a different PR.
Interface stepGenerator will be renamed to LMAdapter.
Description
This is a POC which shows that GPT-2 can be traced and exported into DJL for step-wise inference. Moreover, this step inference carries past_key_values across steps: both GPT2_init.pt and GPT2.pt are used since they take different inputs, the former not having a past_key_values input.
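A minimal sketch of that two-model stepping scheme, with all names here hypothetical:

```java
import ai.djl.ndarray.NDList;

// Sketch only: GPT2_init.pt handles the first step (no past_key_values
// input); GPT2.pt handles every later step, consuming the cache returned
// by the previous call.
final class TwoPhaseDecodingSketch {
    interface StepModel {
        NDList forward(NDList inputs); // returns logits + updated past_key_values
    }

    static NDList decode(StepModel initModel, StepModel stepModel, NDList firstInputs, int steps) {
        NDList out = initModel.forward(firstInputs); // no past_key_values yet
        for (int i = 0; i < steps; i++) {
            out = stepModel.forward(out); // next token id + returned cache
        }
        return out;
    }
}
```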
Run TextGeneration.java to test it.
Design
The new classes are the following:
- JavaDecoder: made engine agnostic; the inference loop can be implemented inside generateText().
- LMAdapter: a wrapper interface over the model files from different sources, e.g. gpt2.pt, gpt2.onnx, etc. It can be seen as a Java abstraction of a causal language model, i.e. the conditional probability p_\theta(v_t | x_1, ..., x_{t-1}): given the past tokens x_{<t}, it yields the probability that the next token is v, taken from a vocabulary set V; \theta denotes the model weights. It will be implemented individually to adapt to different models and scenarios, as sketched below.
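As a reading aid, that contract might look roughly like the following; the interface and method names here are assumptions, not the PR's actual API:

```java
import ai.djl.ndarray.NDList;

// Hypothetical shape of the causal-LM abstraction described above: one call
// maps (token ids, position ids, attention mask) plus an optional
// past_key_values cache to next-token logits and an updated cache.
public interface CausalLMStep extends AutoCloseable {
    NDList stepForward(NDList input, NDList pastKeyValues);
}
```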
Model tracing
The ONNX model gpt2.onnx is exported by following the guide at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-using-past-keysvalues-in-the-decoder. See also https://github.com/huggingface/optimum/releases.
The gpt2.pt is traced with the following script: https://gist.github.com/KexinFeng/4876c6bfb27f40abffe4d5a92c02acff