Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GR-59148] Read JNI dictionary from all layers. #10456

Merged
merged 1 commit into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static int getThreadLocalPinnedObjectCount() {
}

public static long getMethodID(Class<?> clazz, String name, String signature, boolean isStatic) {
return JNIReflectionDictionary.singleton().getMethodID(clazz, name, signature, isStatic).rawValue();
return JNIReflectionDictionary.getMethodID(clazz, name, signature, isStatic).rawValue();
}

public static int getThreadLocalOwnedMonitorsCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static com.oracle.svm.core.SubstrateOptions.JNIVerboseLookupErrors;

import java.io.PrintStream;
import java.util.EnumSet;
import java.util.Map;
import java.util.function.Function;

Expand All @@ -46,6 +47,9 @@
import com.oracle.svm.core.jni.MissingJNIRegistrationUtils;
import com.oracle.svm.core.jni.headers.JNIFieldId;
import com.oracle.svm.core.jni.headers.JNIMethodId;
import com.oracle.svm.core.layeredimagesingleton.LayeredImageSingletonBuilderFlags;
import com.oracle.svm.core.layeredimagesingleton.MultiLayeredImageSingleton;
import com.oracle.svm.core.layeredimagesingleton.UnsavedSingleton;
import com.oracle.svm.core.log.Log;
import com.oracle.svm.core.snippets.KnownIntrinsics;
import com.oracle.svm.core.util.ImageHeapMap;
Expand All @@ -61,7 +65,7 @@
/**
* Provides JNI access to predetermined classes, methods and fields at runtime.
*/
public final class JNIReflectionDictionary {
public final class JNIReflectionDictionary implements MultiLayeredImageSingleton, UnsavedSingleton {
/**
* Enables lookups with {@link WrappedAsciiCString}, which avoids many unnecessary character set
* conversions and allocations.
Expand Down Expand Up @@ -91,47 +95,56 @@ public static void create() {
ImageSingletons.add(JNIReflectionDictionary.class, new JNIReflectionDictionary());
}

@Platforms(HOSTED_ONLY.class)
public static JNIReflectionDictionary singleton() {
return ImageSingletons.lookup(JNIReflectionDictionary.class);
}

private static JNIReflectionDictionary[] layeredSingletons() {
return MultiLayeredImageSingleton.getAllLayers(JNIReflectionDictionary.class);
}

private final EconomicMap<CharSequence, JNIAccessibleClass> classesByName = ImageHeapMap.create(WRAPPED_CSTRING_EQUIVALENCE);
private final EconomicMap<Class<?>, JNIAccessibleClass> classesByClassObject = ImageHeapMap.create();
private final EconomicMap<JNINativeLinkage, JNINativeLinkage> nativeLinkages = ImageHeapMap.create();

private JNIReflectionDictionary() {
}

private void dump(boolean condition, String label) {
private static void dump(boolean condition, String label) {
if (JNIVerboseLookupErrors.getValue() && condition) {
PrintStream ps = Log.logStream();
ps.println(label);
ps.println(" classesByName:");
MapCursor<CharSequence, JNIAccessibleClass> nameCursor = classesByName.getEntries();
while (nameCursor.advance()) {
ps.print(" ");
ps.println(nameCursor.getKey());
JNIAccessibleClass clazz = nameCursor.getValue();
ps.println(" methods:");
MapCursor<JNIAccessibleMethodDescriptor, JNIAccessibleMethod> methodsCursor = clazz.getMethods();
while (methodsCursor.advance()) {
ps.print(" ");
ps.print(methodsCursor.getKey().getName());
ps.println(methodsCursor.getKey().getSignature());
int layerNum = 0;
for (var dictionary : layeredSingletons()) {
PrintStream ps = Log.logStream();
ps.println("Layer " + layerNum);
ps.println(label);
ps.println(" classesByName:");
MapCursor<CharSequence, JNIAccessibleClass> nameCursor = dictionary.classesByName.getEntries();
while (nameCursor.advance()) {
ps.print(" ");
ps.println(nameCursor.getKey());
JNIAccessibleClass clazz = nameCursor.getValue();
ps.println(" methods:");
MapCursor<JNIAccessibleMethodDescriptor, JNIAccessibleMethod> methodsCursor = clazz.getMethods();
while (methodsCursor.advance()) {
ps.print(" ");
ps.print(methodsCursor.getKey().getName());
ps.println(methodsCursor.getKey().getSignature());
}
ps.println(" fields:");
UnmodifiableMapCursor<CharSequence, JNIAccessibleField> fieldsCursor = clazz.getFields();
while (fieldsCursor.advance()) {
ps.print(" ");
ps.println(fieldsCursor.getKey());
}
}
ps.println(" fields:");
UnmodifiableMapCursor<CharSequence, JNIAccessibleField> fieldsCursor = clazz.getFields();
while (fieldsCursor.advance()) {
ps.print(" ");
ps.println(fieldsCursor.getKey());
}
}

ps.println(" classesByClassObject:");
MapCursor<Class<?>, JNIAccessibleClass> cursor = classesByClassObject.getEntries();
while (cursor.advance()) {
ps.print(" ");
ps.println(cursor.getKey());
ps.println(" classesByClassObject:");
MapCursor<Class<?>, JNIAccessibleClass> cursor = dictionary.classesByClassObject.getEntries();
while (cursor.advance()) {
ps.print(" ");
ps.println(cursor.getKey());
}
}
}
}
Expand All @@ -151,6 +164,7 @@ public JNIAccessibleClass addClassIfAbsent(Class<?> classObj, Function<Class<?>,
return classesByClassObject.get(classObj);
}

@Platforms(HOSTED_ONLY.class)
public void addNegativeClassLookupIfAbsent(String typeName) {
String internalName = MetaUtil.toInternalName(typeName);
String queryName = internalName.startsWith("L") ? internalName.substring(1, internalName.length() - 1) : internalName;
Expand All @@ -162,15 +176,21 @@ public void addLinkages(Map<JNINativeLinkage, JNINativeLinkage> linkages) {
nativeLinkages.putAll(EconomicMap.wrapMap(linkages));
}

@Platforms(HOSTED_ONLY.class)
public Iterable<JNIAccessibleClass> getClasses() {
return classesByClassObject.getValues();
}

public Class<?> getClassObjectByName(CharSequence name) {
JNIAccessibleClass clazz = classesByName.get(name);
clazz = checkClass(clazz, name);
dump(clazz == null, "getClassObjectByName");
return (clazz != null) ? clazz.getClassObject() : null;
public static Class<?> getClassObjectByName(CharSequence name) {
for (var dictionary : layeredSingletons()) {
JNIAccessibleClass clazz = dictionary.classesByName.get(name);
clazz = checkClass(clazz, name);
if (clazz != null) {
return clazz.getClassObject();
}
}
dump(true, "getClassObjectByName");
return null;
}

private static JNIAccessibleClass checkClass(JNIAccessibleClass clazz, CharSequence name) {
Expand All @@ -192,20 +212,28 @@ private static JNIAccessibleClass checkClass(JNIAccessibleClass clazz, CharSeque
* method
* @return the linkage for the native method or {@code null} if no linkage exists
*/
public JNINativeLinkage getLinkage(CharSequence declaringClass, CharSequence name, CharSequence descriptor) {
public static JNINativeLinkage getLinkage(CharSequence declaringClass, CharSequence name, CharSequence descriptor) {
JNINativeLinkage key = new JNINativeLinkage(declaringClass, name, descriptor);
return nativeLinkages.get(key);
for (var dictionary : layeredSingletons()) {
var linkage = dictionary.nativeLinkages.get(key);
if (linkage != null) {
return linkage;
}
}
return null;
}

public void unsetEntryPoints(String declaringClass) {
for (JNINativeLinkage linkage : nativeLinkages.getKeys()) {
if (declaringClass.equals(linkage.getDeclaringClassName())) {
linkage.unsetEntryPoint();
public static void unsetEntryPoints(String declaringClass) {
for (var dictionary : layeredSingletons()) {
for (JNINativeLinkage linkage : dictionary.nativeLinkages.getKeys()) {
if (declaringClass.equals(linkage.getDeclaringClassName())) {
linkage.unsetEntryPoint();
}
}
}
}

private JNIAccessibleMethod findMethod(Class<?> clazz, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) {
private static JNIAccessibleMethod findMethod(Class<?> clazz, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) {
JNIAccessibleMethod method = getDeclaredMethod(clazz, descriptor, dumpLabel);
if (descriptor.isConstructor() || descriptor.isClassInitializer()) { // never recurse
return method;
Expand All @@ -220,7 +248,7 @@ private JNIAccessibleMethod findMethod(Class<?> clazz, JNIAccessibleMethodDescri
return method;
}

private JNIAccessibleMethod findSuperinterfaceMethod(Class<?> clazz, JNIAccessibleMethodDescriptor descriptor) {
private static JNIAccessibleMethod findSuperinterfaceMethod(Class<?> clazz, JNIAccessibleMethodDescriptor descriptor) {
for (Class<?> parent : clazz.getInterfaces()) {
JNIAccessibleMethod method = getDeclaredMethod(parent, descriptor, null);
if (method == null) {
Expand All @@ -234,23 +262,29 @@ private JNIAccessibleMethod findSuperinterfaceMethod(Class<?> clazz, JNIAccessib
return null;
}

public JNIMethodId getDeclaredMethodID(Class<?> classObject, JNIAccessibleMethodDescriptor descriptor, boolean isStatic) {
public static JNIMethodId getDeclaredMethodID(Class<?> classObject, JNIAccessibleMethodDescriptor descriptor, boolean isStatic) {
JNIAccessibleMethod method = getDeclaredMethod(classObject, descriptor, "getDeclaredMethodID");
boolean match = (method != null && method.isStatic() == isStatic);
return toMethodID(match ? method : null);
}

private JNIAccessibleMethod getDeclaredMethod(Class<?> classObject, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) {
JNIAccessibleClass clazz = classesByClassObject.get(classObject);
dump(clazz == null && dumpLabel != null, dumpLabel);
JNIAccessibleMethod method = null;
if (clazz != null) {
method = clazz.getMethod(descriptor);
private static JNIAccessibleMethod getDeclaredMethod(Class<?> classObject, JNIAccessibleMethodDescriptor descriptor, String dumpLabel) {
boolean foundClass = false;
for (var dictionary : layeredSingletons()) {
JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject);
if (clazz != null) {
foundClass = true;
JNIAccessibleMethod method = clazz.getMethod(descriptor);
if (method != null) {
return method;
}
}
}
return method;
dump(!foundClass && dumpLabel != null, dumpLabel);
return null;
}

public JNIMethodId getMethodID(Class<?> classObject, CharSequence name, CharSequence signature, boolean isStatic) {
public static JNIMethodId getMethodID(Class<?> classObject, CharSequence name, CharSequence signature, boolean isStatic) {
JNIAccessibleMethod method = findMethod(classObject, new JNIAccessibleMethodDescriptor(name, signature), "getMethodID");
method = checkMethod(method, classObject, name, signature);
boolean match = (method != null && method.isStatic() == isStatic && method.isDiscoverableIn(classObject));
Expand Down Expand Up @@ -289,25 +323,29 @@ private static JNIAccessibleMethod checkMethod(JNIAccessibleMethod method, Class
return method;
}

private JNIAccessibleField getDeclaredField(Class<?> classObject, CharSequence name, boolean isStatic, String dumpLabel) {
JNIAccessibleClass clazz = classesByClassObject.get(classObject);
dump(clazz == null && dumpLabel != null, dumpLabel);
if (clazz != null) {
JNIAccessibleField field = clazz.getField(name);
if (field != null && (field.isStatic() == isStatic || field.isNegative())) {
return field;
private static JNIAccessibleField getDeclaredField(Class<?> classObject, CharSequence name, boolean isStatic, String dumpLabel) {
boolean foundClass = false;
for (var dictionary : layeredSingletons()) {
JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject);
if (clazz != null) {
foundClass = true;
JNIAccessibleField field = clazz.getField(name);
if (field != null && (field.isStatic() == isStatic || field.isNegative())) {
return field;
}
}
}
dump(!foundClass && dumpLabel != null, dumpLabel);
return null;
}

public JNIFieldId getDeclaredFieldID(Class<?> classObject, String name, boolean isStatic) {
public static JNIFieldId getDeclaredFieldID(Class<?> classObject, String name, boolean isStatic) {
JNIAccessibleField field = getDeclaredField(classObject, name, isStatic, "getDeclaredFieldID");
field = checkField(field, classObject, name);
return (field != null) ? field.getId() : Word.nullPointer();
}

private JNIAccessibleField findField(Class<?> clazz, CharSequence name, boolean isStatic, String dumpLabel) {
private static JNIAccessibleField findField(Class<?> clazz, CharSequence name, boolean isStatic, String dumpLabel) {
// Lookup according to JVM spec 5.4.3.2: local fields, superinterfaces, superclasses
JNIAccessibleField field = getDeclaredField(clazz, name, isStatic, dumpLabel);
if (field == null && isStatic) {
Expand All @@ -319,7 +357,7 @@ private JNIAccessibleField findField(Class<?> clazz, CharSequence name, boolean
return field;
}

private JNIAccessibleField findSuperinterfaceField(Class<?> clazz, CharSequence name) {
private static JNIAccessibleField findSuperinterfaceField(Class<?> clazz, CharSequence name) {
for (Class<?> parent : clazz.getInterfaces()) {
JNIAccessibleField field = getDeclaredField(parent, name, true, null);
if (field == null) {
Expand All @@ -332,21 +370,23 @@ private JNIAccessibleField findSuperinterfaceField(Class<?> clazz, CharSequence
return null;
}

public JNIFieldId getFieldID(Class<?> clazz, CharSequence name, boolean isStatic) {
public static JNIFieldId getFieldID(Class<?> clazz, CharSequence name, boolean isStatic) {
JNIAccessibleField field = findField(clazz, name, isStatic, "getFieldID");
field = checkField(field, clazz, name);
return (field != null && field.isDiscoverableIn(clazz)) ? field.getId() : Word.nullPointer();
}

public String getFieldNameByID(Class<?> classObject, JNIFieldId id) {
JNIAccessibleClass clazz = classesByClassObject.get(classObject);
if (clazz != null) {
UnmodifiableMapCursor<CharSequence, JNIAccessibleField> fieldsCursor = clazz.getFields();
while (fieldsCursor.advance()) {
JNIAccessibleField field = fieldsCursor.getValue();
if (id.equal(field.getId())) {
VMError.guarantee(!field.isNegative(), "Existing fields can't correspond to a negative query");
return (String) fieldsCursor.getKey();
public static String getFieldNameByID(Class<?> classObject, JNIFieldId id) {
for (var dictionary : layeredSingletons()) {
JNIAccessibleClass clazz = dictionary.classesByClassObject.get(classObject);
if (clazz != null) {
UnmodifiableMapCursor<CharSequence, JNIAccessibleField> fieldsCursor = clazz.getFields();
while (fieldsCursor.advance()) {
JNIAccessibleField field = fieldsCursor.getValue();
if (id.equal(field.getId())) {
VMError.guarantee(!field.isNegative(), "Existing fields can't correspond to a negative query");
return (String) fieldsCursor.getKey();
}
}
}
}
Expand Down Expand Up @@ -375,4 +415,8 @@ public static JNIAccessibleMethodDescriptor getMethodDescriptor(JNIAccessibleMet
return null;
}

@Override
public EnumSet<LayeredImageSingletonBuilderFlags> getImageBuilderFlags() {
return LayeredImageSingletonBuilderFlags.ALL_ACCESS;
}
}
Loading
Loading