 import android.app.Dialog;
 import android.content.Context;
 import android.os.Bundle;
+import android.system.ErrnoException;
 import android.text.method.ScrollingMovementMethod;
 import android.util.Log;
 import android.util.Pair;
 import android.widget.Toast;
 
 import java.io.File;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.LongBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import ai.onnxruntime.genai.Generator;
 import ai.onnxruntime.genai.GeneratorParams;
 import ai.onnxruntime.genai.Sequences;
+import ai.onnxruntime.genai.Tensor;
 import ai.onnxruntime.genai.TokenizerStream;
 import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
 import ai.onnxruntime.genai.Model;
@@ -45,7 +51,7 @@ public class MainActivity extends AppCompatActivity implements Consumer<String>
     private TextView progressText;
     private ImageButton settingsButton;
     private static final String TAG = "genai.demo.MainActivity";
-    private int maxLength = 100;
+    private int maxLength = 256;
     private float lengthPenalty = 1.0f;
 
     private static boolean fileExists(Context context, String fileName) {
@@ -55,6 +61,14 @@ private static boolean fileExists(Context context, String fileName) {
 
     @Override
     protected void onCreate(Bundle savedInstanceState) {
+        try {
+            // set ADSP_LIBRARY_PATH, QNN-specific
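+            // (the Hexagon DSP runtime looks up QNN's skel libraries through this variable,
+            // so it is pointed at the app's native library directory before anything else runs)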
+            String adspLibraryPath = getApplicationContext().getApplicationInfo().nativeLibraryDir;
+            android.system.Os.setenv("ADSP_LIBRARY_PATH", adspLibraryPath, true);
+        } catch (ErrnoException e) {
+            throw new RuntimeException(e);
+        }
+
         super.onCreate(savedInstanceState);
 
         binding = ActivityMainBinding.inflate(getLayoutInflater());
@@ -69,8 +83,8 @@ protected void onCreate(Bundle savedInstanceState) {
 
         // Trigger the download operation when the application is created
        try {
-            downloadModels(
-                    getApplicationContext());
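+            // load a QNN model that is assumed to have already been pushed to the device (e.g. via adb push) instead of downloading one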
+            createModelFromPath("/data/local/tmp/phi3.5_qnn_qc/phi3.5-split-qnn-qc");
+            //downloadModels(getApplicationContext());
         } catch (GenAIException e) {
             throw new RuntimeException(e);
         }
@@ -135,17 +149,63 @@ public void run() {
                 GeneratorParams generatorParams = null;
                 Generator generator = null;
                 Sequences encodedPrompt = null;
+                Tensor attentionMask = null, positionIds = null;
                 try {
+                    encodedPrompt = tokenizer.encode(promptQuestion_formatted);
+
                     stream = tokenizer.createStream();
 
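+                    // fixed sizes assumed to match the static shapes the QNN model was exported with:
+                    // the prompt is left-padded into a 128-token window, while the attention mask spans the full 4096-token context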
+                    int maxSequenceLength = 128;
+                    int contextLength = 4096;
+
+                    int[] promptTokens = encodedPrompt.getSequence(0);
+                    int numPromptTokens = promptTokens.length;
+
+                    if (numPromptTokens > maxSequenceLength) {
+                        throw new RuntimeException("numPromptTokens is greater than maxSequenceLength");
+                    }
+                    if (numPromptTokens > contextLength) {
+                        throw new RuntimeException("numPromptTokens is greater than contextLength");
+                    }
+
+                    int paddingSize = maxSequenceLength - numPromptTokens;
+
+                    // paddedInputIds
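+                    // left-pad with token id 0 so the real prompt occupies the last numPromptTokens slots of the fixed-length window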
+                    int[] paddedInputIds = new int[maxSequenceLength];
+                    for (int i = 0; i < maxSequenceLength; ++i) {
+                        paddedInputIds[i] = i < paddingSize ? 0 : promptTokens[i - paddingSize];
+                    }
+
+                    ByteOrder nativeOrder = ByteOrder.nativeOrder();
+
+                    // attentionMask
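+                    // float32 mask over the whole context window: 0.0f for unused positions, 1.0f for the prompt tokens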
+                    int attentionMaskPaddingSize = contextLength - numPromptTokens;
+                    ByteBuffer attentionMaskBuffer = ByteBuffer.allocateDirect(contextLength * 4);
+                    attentionMaskBuffer.order(nativeOrder);
+                    FloatBuffer attentionMaskFloatBuffer = attentionMaskBuffer.asFloatBuffer();
+                    for (int i = 0; i < contextLength; i++) {
+                        attentionMaskFloatBuffer.put(i < attentionMaskPaddingSize ? 0.0f : 1.0f);
+                    }
+                    attentionMask = new Tensor(attentionMaskBuffer, new long[]{1, contextLength}, Tensor.ElementType.float32);
+
+                    // positionIds
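+                    // int64 positions for the padded window: padding slots stay at 0, prompt tokens count up from 0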
+                    ByteBuffer positionIdsBuffer = ByteBuffer.allocateDirect(maxSequenceLength * 8);
+                    positionIdsBuffer.order(nativeOrder);
+                    LongBuffer positionIdsLongBuffer = positionIdsBuffer.asLongBuffer();
+                    for (int i = 0; i < maxSequenceLength; ++i) {
+                        positionIdsLongBuffer.put(i < paddingSize ? 0 : i - paddingSize);
+                    }
+                    positionIds = new Tensor(positionIdsBuffer, new long[]{1, maxSequenceLength}, Tensor.ElementType.int64);
+
                     generatorParams = model.createGeneratorParams();
                     //examples for optional parameters to format AI response
                     // https://onnxruntime.ai/docs/genai/reference/config.html
                     generatorParams.setSearchOption("length_penalty", lengthPenalty);
                     generatorParams.setSearchOption("max_length", maxLength);
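+                    // pass the precomputed mask and position ids as extra named model inputs
+                    // (the input names below must match what the exported QNN model expects)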
+                    generatorParams.setInput("attention_mask_before_processor", attentionMask);
+                    generatorParams.setInput("position_ids", positionIds);
 
-                    encodedPrompt = tokenizer.encode(promptQuestion_formatted);
-                    generatorParams.setInput(encodedPrompt);
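+                    // feed the left-padded token ids directly instead of the raw encoded prompt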
+                    generatorParams.setInput(paddedInputIds, maxSequenceLength, 1);
 
                     generator = new Generator(model, generatorParams);
 
@@ -175,7 +235,7 @@ public void run() {
                     long totalTime = System.currentTimeMillis() - firstTokenTime;
 
                     float promptProcessingTime = (firstTokenTime - startTime) / 1000.0f;
-                    float tokensPerSecond = (1000 * (numTokens - 1)) / totalTime;
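+                    // compute the rate in floating point so integer division no longer truncates it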
+                    float tokensPerSecond = (1000.0f * (numTokens - 1)) / totalTime;
 
                     runOnUiThread(() -> {
                         sendMsgIB.setEnabled(true);
@@ -192,6 +252,8 @@ public void run() {
                     Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
                 }
                 finally {
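+                    // release the native buffers backing the extra input tensors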
+                    if (positionIds != null) positionIds.close();
+                    if (attentionMask != null) attentionMask.close();
                     if (generator != null) generator.close();
                     if (encodedPrompt != null) encodedPrompt.close();
                     if (stream != null) stream.close();
@@ -217,8 +279,12 @@ protected void onDestroy() {
         super.onDestroy();
     }
 
-    private void downloadModels(Context context) throws GenAIException {
+    private void createModelFromPath(String path) throws GenAIException {
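+        // build the Model and Tokenizer from a model folder that is already present on the device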
+        model = new Model(path);
+        tokenizer = model.createTokenizer();
+    }
 
+    private void downloadModels(Context context) throws GenAIException {
         final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
         List<String> files = Arrays.asList(
                 "added_tokens.json",