Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Loads native engine in deterministic order #3300

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions api/src/main/java/ai/djl/util/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Properties;

/**
Expand Down Expand Up @@ -71,32 +73,50 @@ public static Platform detectPlatform(String engine) {
}

Platform systemPlatform = Platform.fromSystem(engine);
Platform placeholder = null;
List<Platform> availablePlatforms = new ArrayList<>();
while (urls.hasMoreElements()) {
URL url = urls.nextElement();
Platform platform = Platform.fromUrl(url);
platform.apiVersion = systemPlatform.apiVersion;
if (platform.isPlaceholder()) {
placeholder = platform;
availablePlatforms.add(platform);
} else if (platform.matches(systemPlatform)) {
logger.info("Found matching platform from: {}", url);
return platform;
availablePlatforms.add(platform);
} else {
logger.info("Ignore mismatching platform from: {}", url);
}
}
if (placeholder != null) {
logger.info("Found placeholder platform from: {}", placeholder);
return placeholder;
}

if (systemPlatform.version == null) {
throw new AssertionError("No " + engine + " version found in property file.");
}
if (systemPlatform.apiVersion == null) {
throw new AssertionError("No " + engine + " djl_version found in property file.");
if (availablePlatforms.isEmpty()) {
if (systemPlatform.version == null) {
throw new AssertionError("No " + engine + " version found in property file.");
}
if (systemPlatform.apiVersion == null) {
throw new AssertionError("No " + engine + " djl_version found in property file.");
}
return systemPlatform;
} else if (availablePlatforms.size() == 1) {
Platform ret = availablePlatforms.get(0);
if (ret.isPlaceholder()) {
logger.info("Found placeholder platform from: {}", ret);
}
return ret;
}
return systemPlatform;
availablePlatforms.sort(
(o1, o2) -> {
if (o1.isPlaceholder()) {
return 1;
} else if (o2.isPlaceholder()) {
return -1;
}
// cu121-precxx11 > cu121 > cu118-precxx11 > cpu-precxx11 > cpu
int ret = o2.getFlavor().compareTo(o1.getFlavor());
if (ret == 0) {
return o2.getVersion().compareTo(o1.getVersion());
}
return ret;
});
return availablePlatforms.get(0);
}

/**
Expand Down
17 changes: 12 additions & 5 deletions api/src/main/java/ai/djl/util/cuda/CudaUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,13 @@ public static String getCudaVersionString() {
*/
public static String getComputeCapability(int device) {
if (Boolean.getBoolean("ai.djl.util.cuda.fork")) {
String[] ret = execute(device);
if (ret.length != 3) {
throw new IllegalArgumentException(ret[0]);
if (gpuInfo == null) { // NOPMD
gpuInfo = execute(-1);
}
return ret[0];
if (device >= gpuInfo.length - 2) {
throw new IllegalArgumentException("Invalid device: " + device);
}
return gpuInfo[device + 2];
}

if (LIB == null) {
Expand Down Expand Up @@ -214,7 +216,12 @@ public static void main(String[] args) {
return;
}
int cudaVersion = getCudaVersion();
System.out.println(gpuCount + "," + cudaVersion);
StringBuilder sb = new StringBuilder();
sb.append(gpuCount).append(',').append(cudaVersion);
for (int i = 0; i < gpuCount; ++i) {
sb.append(',').append(getComputeCapability(i));
}
System.out.println(sb);
return;
}
try {
Expand Down
64 changes: 64 additions & 0 deletions api/src/test/java/ai/djl/util/PlatformTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
*/
package ai.djl.util;

import ai.djl.util.cuda.CudaUtils;

import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

import java.io.BufferedWriter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;

public class PlatformTest {

Expand Down Expand Up @@ -94,6 +101,43 @@ public void testPlatform() throws IOException {
Assert.assertFalse(platform.matches(system));
}

/**
 * Checks that {@code Platform.detectPlatform} selects the expected flavor when several
 * candidate native jars are visible through the thread context class loader.
 *
 * <p>Eight jars (one placeholder plus seven flavor/version combinations) are generated under
 * {@code build/tmp/} and exposed through a temporary {@link URLClassLoader}. The private
 * static {@code CudaUtils.gpuInfo} field is overwritten through reflection so the test does
 * not depend on real GPU hardware; the injected values appear to follow the
 * "gpuCount, cudaVersion, computeCapability..." layout printed by {@code CudaUtils.main}.
 *
 * <p>All global state touched by the test (the {@code gpuInfo} field, the context class
 * loader and the {@code ai.djl.util.cuda.fork} system property) is reset in {@code finally},
 * so a failing assertion cannot leak the fake configuration into other tests.
 *
 * @throws IOException if the jars cannot be written or the class loader cannot be closed
 * @throws ReflectiveOperationException if the {@code gpuInfo} field cannot be accessed
 */
@Test
public void testDetectPlatform() throws IOException, ReflectiveOperationException {
    Path dir = Paths.get("build/tmp/");
    Files.createDirectories(dir);
    Platform system = Platform.fromSystem();
    String classifier = system.getClassifier();
    createZipFile(0, "1.0", "cpu", classifier, true);
    createZipFile(1, "1.0", "cpu", classifier, false);
    createZipFile(2, "1.0", "cpu-precxx11", classifier, false);
    createZipFile(3, "1.0", "cu117", classifier, false);
    createZipFile(4, "1.0", "cu117-precxx11", classifier, false);
    createZipFile(5, "1.0", "cu999", classifier, false);
    createZipFile(6, "1.0", "cu999-precxx11", classifier, false);
    createZipFile(7, "99.99", "cu999-precxx11", classifier, false);

    URL[] urls = new URL[8];
    for (int i = 0; i < 8; ++i) {
        urls[i] = dir.resolve(i + ".jar").toUri().toURL();
    }

    // Capture everything we are about to override so it can be put back afterwards.
    Field field = CudaUtils.class.getDeclaredField("gpuInfo");
    field.setAccessible(true);
    Object previousGpuInfo = field.get(null);
    Thread thread = Thread.currentThread();
    ClassLoader previousLoader = thread.getContextClassLoader();

    System.setProperty("ai.djl.util.cuda.fork", "true");
    // try-with-resources releases the jar file handles held by the class loader.
    try (URLClassLoader cl = new URLClassLoader(urls)) {
        field.set(null, new String[] {"1", "99990", "90"});
        thread.setContextClassLoader(cl);

        Platform detected = Platform.detectPlatform("pytorch");
        Assert.assertEquals(detected.getFlavor(), "cu999-precxx11");
    } finally {
        // Runs even when the assertion fails, so later tests see the original state.
        field.set(null, previousGpuInfo);
        thread.setContextClassLoader(previousLoader);
        System.clearProperty("ai.djl.util.cuda.fork");
    }
}

private URL createPropertyFile(String content) throws IOException {
Path dir = Paths.get("build/tmp/testFile/");
Files.createDirectories(dir);
Expand All @@ -104,4 +148,24 @@ private URL createPropertyFile(String content) throws IOException {
}
return file.toUri().toURL();
}

/**
 * Writes a minimal jar to {@code build/tmp/<index>.jar} containing a single
 * {@code native/lib/pytorch.properties} entry.
 *
 * <p>For a placeholder jar the entry holds a fixed placeholder marker and version, and the
 * {@code version}, {@code flavor} and {@code classifier} arguments are not used. Otherwise
 * the entry lists the supplied version, flavor and classifier as properties.
 *
 * @param index number used as the jar file name
 * @param version value of the {@code version} property (non-placeholder jars only)
 * @param flavor value of the {@code flavor} property (non-placeholder jars only)
 * @param classifier value of the {@code classifier} property (non-placeholder jars only)
 * @param placeHolder {@code true} to write the fixed placeholder properties instead
 * @throws IOException if the jar cannot be written
 */
private void createZipFile(
        int index, String version, String flavor, String classifier, boolean placeHolder)
        throws IOException {
    Path target = Paths.get("build/tmp/" + index + ".jar");
    try (JarOutputStream jar = new JarOutputStream(Files.newOutputStream(target))) {
        jar.putNextEntry(new JarEntry("native/lib/pytorch.properties"));

        // Assemble the whole properties text first, then emit it in a single write.
        String content;
        if (placeHolder) {
            content = "placeholder=true\nversion=2.3.1";
        } else {
            content =
                    "version=" + version + "\nflavor=" + flavor + "\nclassifier=" + classifier;
        }
        jar.write(content.getBytes(StandardCharsets.UTF_8));
    }
}
}
2 changes: 2 additions & 0 deletions api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public void testCudaUtilsWithFork() {
System.setProperty("ai.djl.util.cuda.fork", "true");
try {
testCudaUtils();
CudaUtils.main(new String[0]);
CudaUtils.main(new String[] {"-1"});
} finally {
System.clearProperty("ai.djl.util.cuda.fork");
}
Expand Down
Loading