Skip to content

Commit 064ba81

Browse files
committed
adapt android phi3 example for qnn phi3.5 model
1 parent d0639c8 commit 064ba81

File tree

5 files changed

+85
-10
lines changed

5 files changed

+85
-10
lines changed

mobile/examples/phi-3/android/app/build.gradle.kts

+9 -3
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,8 @@ android {
1717

1818
ndk {
1919
//noinspection ChromeOsAbiSupport
20-
abiFilters += listOf("arm64-v8a", "x86_64")
20+
//abiFilters += listOf("arm64-v8a", "x86_64")
21+
abiFilters += listOf("arm64-v8a")
2122
}
2223
}
2324

@@ -39,6 +40,9 @@ android {
3940
buildFeatures {
4041
viewBinding = true
4142
}
43+
44+
// set this so QNN libs will show up in nativeLibraryDir
45+
packaging.jniLibs.useLegacyPackaging = true
4246
}
4347

4448
dependencies {
@@ -51,7 +55,9 @@ dependencies {
5155
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
5256

5357
// ONNX Runtime with GenAI
54-
implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
55-
implementation(files("libs/onnxruntime-genai-android-0.4.0-dev.aar"))
58+
//implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
59+
implementation(files("libs/onnxruntime-android-qnn-1.20.0.aar"))
60+
implementation(files("libs/onnxruntime-genai-android-0.5.0-dev.aar"))
61+
implementation("com.qualcomm.qti:qnn-runtime:2.27.0")
5662

5763
}
Binary file not shown.
Binary file not shown.

mobile/examples/phi-3/android/app/src/main/AndroidManifest.xml

+3
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,9 @@
2424
<category android:name="android.intent.category.LAUNCHER" />
2525
</intent-filter>
2626
</activity>
27+
<uses-native-library
28+
android:name="libcdsprpc.so"
29+
android:required="false" />
2730
</application>
2831

2932
</manifest>

mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java

+73 -7
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import android.app.Dialog;
66
import android.content.Context;
77
import android.os.Bundle;
8+
import android.system.ErrnoException;
89
import android.text.method.ScrollingMovementMethod;
910
import android.util.Log;
1011
import android.util.Pair;
@@ -17,6 +18,10 @@
1718
import android.widget.Toast;
1819

1920
import java.io.File;
21+
import java.nio.ByteBuffer;
22+
import java.nio.ByteOrder;
23+
import java.nio.FloatBuffer;
24+
import java.nio.LongBuffer;
2025
import java.util.ArrayList;
2126
import java.util.Arrays;
2227
import java.util.List;
@@ -28,6 +33,7 @@
2833
import ai.onnxruntime.genai.Generator;
2934
import ai.onnxruntime.genai.GeneratorParams;
3035
import ai.onnxruntime.genai.Sequences;
36+
import ai.onnxruntime.genai.Tensor;
3137
import ai.onnxruntime.genai.TokenizerStream;
3238
import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
3339
import ai.onnxruntime.genai.Model;
@@ -45,7 +51,7 @@ public class MainActivity extends AppCompatActivity implements Consumer<String>
4551
private TextView progressText;
4652
private ImageButton settingsButton;
4753
private static final String TAG = "genai.demo.MainActivity";
48-
private int maxLength = 100;
54+
private int maxLength = 256;
4955
private float lengthPenalty = 1.0f;
5056

5157
private static boolean fileExists(Context context, String fileName) {
@@ -55,6 +61,14 @@ private static boolean fileExists(Context context, String fileName) {
5561

5662
@Override
5763
protected void onCreate(Bundle savedInstanceState) {
64+
try {
65+
// set ADSP_LIBRARY_PATH, QNN-specific
66+
String adspLibraryPath = getApplicationContext().getApplicationInfo().nativeLibraryDir;
67+
android.system.Os.setenv("ADSP_LIBRARY_PATH", adspLibraryPath, true);
68+
} catch (ErrnoException e) {
69+
throw new RuntimeException(e);
70+
}
71+
5872
super.onCreate(savedInstanceState);
5973

6074
binding = ActivityMainBinding.inflate(getLayoutInflater());
@@ -69,8 +83,8 @@ protected void onCreate(Bundle savedInstanceState) {
6983

7084
// Trigger the download operation when the application is created
7185
try {
72-
downloadModels(
73-
getApplicationContext());
86+
createModelFromPath("/data/local/tmp/phi3.5_qnn_qc/phi3.5-split-qnn-qc");
87+
//downloadModels(getApplicationContext());
7488
} catch (GenAIException e) {
7589
throw new RuntimeException(e);
7690
}
@@ -135,17 +149,63 @@ public void run() {
135149
GeneratorParams generatorParams = null;
136150
Generator generator = null;
137151
Sequences encodedPrompt = null;
152+
Tensor attentionMask = null, positionIds = null;
138153
try {
154+
encodedPrompt = tokenizer.encode(promptQuestion_formatted);
155+
139156
stream = tokenizer.createStream();
140157

158+
int maxSequenceLength = 128;
159+
int contextLength = 4096;
160+
161+
int[] promptTokens = encodedPrompt.getSequence(0);
162+
int numPromptTokens = promptTokens.length;
163+
164+
if (numPromptTokens > maxSequenceLength) {
165+
throw new RuntimeException("numPromptTokens is greater than maxSequenceLength");
166+
}
167+
if (numPromptTokens > contextLength) {
168+
throw new RuntimeException("numPromptTokens is greater than contextLength");
169+
}
170+
171+
int paddingSize = maxSequenceLength - numPromptTokens;
172+
173+
// paddedInputIds
174+
int[] paddedInputIds = new int[maxSequenceLength];
175+
for (int i = 0; i < maxSequenceLength; ++i) {
176+
paddedInputIds[i] = i < paddingSize ? 0 : promptTokens[i - paddingSize];
177+
}
178+
179+
ByteOrder nativeOrder = ByteOrder.nativeOrder();
180+
181+
// attentionMask
182+
int attentionMaskPaddingSize = contextLength - numPromptTokens;
183+
ByteBuffer attentionMaskBuffer = ByteBuffer.allocateDirect(contextLength * 4);
184+
attentionMaskBuffer.order(nativeOrder);
185+
FloatBuffer attentionMaskFloatBuffer = attentionMaskBuffer.asFloatBuffer();
186+
for (int i = 0; i < contextLength; i++) {
187+
attentionMaskFloatBuffer.put(i < attentionMaskPaddingSize ? 0.0f : 1.0f);
188+
}
189+
attentionMask = new Tensor(attentionMaskBuffer, new long[]{1, contextLength}, Tensor.ElementType.float32);
190+
191+
// positionIds
192+
ByteBuffer positionIdsBuffer = ByteBuffer.allocateDirect(maxSequenceLength * 8);
193+
positionIdsBuffer.order(nativeOrder);
194+
LongBuffer positionIdsLongBuffer = positionIdsBuffer.asLongBuffer();
195+
for (int i = 0; i < maxSequenceLength; ++i) {
196+
positionIdsLongBuffer.put(i < paddingSize ? 0 : i - paddingSize);
197+
}
198+
positionIds = new Tensor(positionIdsBuffer, new long[]{1, maxSequenceLength}, Tensor.ElementType.int64);
199+
141200
generatorParams = model.createGeneratorParams();
142201
//examples for optional parameters to format AI response
143202
// https://onnxruntime.ai/docs/genai/reference/config.html
144203
generatorParams.setSearchOption("length_penalty", lengthPenalty);
145204
generatorParams.setSearchOption("max_length", maxLength);
205+
generatorParams.setInput("attention_mask_before_processor", attentionMask);
206+
generatorParams.setInput("position_ids", positionIds);
146207

147-
encodedPrompt = tokenizer.encode(promptQuestion_formatted);
148-
generatorParams.setInput(encodedPrompt);
208+
generatorParams.setInput(paddedInputIds, maxSequenceLength, 1);
149209

150210
generator = new Generator(model, generatorParams);
151211

@@ -175,7 +235,7 @@ public void run() {
175235
long totalTime = System.currentTimeMillis() - firstTokenTime;
176236

177237
float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
178-
float tokensPerSecond = (1000 * (numTokens -1)) / totalTime;
238+
float tokensPerSecond = (1000.0f * (numTokens - 1)) / totalTime;
179239

180240
runOnUiThread(() -> {
181241
sendMsgIB.setEnabled(true);
@@ -192,6 +252,8 @@ public void run() {
192252
Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
193253
}
194254
finally {
255+
if (positionIds != null) positionIds.close();
256+
if (attentionMask != null) attentionMask.close();
195257
if (generator != null) generator.close();
196258
if (encodedPrompt != null) encodedPrompt.close();
197259
if (stream != null) stream.close();
@@ -217,8 +279,12 @@ protected void onDestroy() {
217279
super.onDestroy();
218280
}
219281

220-
private void downloadModels(Context context) throws GenAIException {
282+
private void createModelFromPath(String path) throws GenAIException {
283+
model = new Model(path);
284+
tokenizer = model.createTokenizer();
285+
}
221286

287+
private void downloadModels(Context context) throws GenAIException {
222288
final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
223289
List<String> files = Arrays.asList(
224290
"added_tokens.json",

0 commit comments

Comments
 (0)