diff --git a/api/src/main/java/ai/djl/util/Platform.java b/api/src/main/java/ai/djl/util/Platform.java
index 2ab25c2f30d..dd31a383349 100644
--- a/api/src/main/java/ai/djl/util/Platform.java
+++ b/api/src/main/java/ai/djl/util/Platform.java
@@ -20,7 +20,9 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
+import java.util.ArrayList;
 import java.util.Enumeration;
+import java.util.List;
 import java.util.Properties;
 
 /**
@@ -71,32 +73,84 @@ public static Platform detectPlatform(String engine) {
         }
 
         Platform systemPlatform = Platform.fromSystem(engine);
-        Platform placeholder = null;
+        List<Platform> availablePlatforms = new ArrayList<>();
         while (urls.hasMoreElements()) {
             URL url = urls.nextElement();
             Platform platform = Platform.fromUrl(url);
             platform.apiVersion = systemPlatform.apiVersion;
             if (platform.isPlaceholder()) {
-                placeholder = platform;
+                availablePlatforms.add(platform);
             } else if (platform.matches(systemPlatform)) {
                 logger.info("Found matching platform from: {}", url);
-                return platform;
+                availablePlatforms.add(platform);
             } else {
                 logger.info("Ignore mismatching platform from: {}", url);
             }
         }
-        if (placeholder != null) {
-            logger.info("Found placeholder platform from: {}", placeholder);
-            return placeholder;
-        }
-
-        if (systemPlatform.version == null) {
-            throw new AssertionError("No " + engine + " version found in property file.");
-        }
-        if (systemPlatform.apiVersion == null) {
-            throw new AssertionError("No " + engine + " djl_version found in property file.");
+        if (availablePlatforms.isEmpty()) {
+            if (systemPlatform.version == null) {
+                throw new AssertionError("No " + engine + " version found in property file.");
+            }
+            if (systemPlatform.apiVersion == null) {
+                throw new AssertionError("No " + engine + " djl_version found in property file.");
+            }
+            return systemPlatform;
+        } else if (availablePlatforms.size() == 1) {
+            Platform ret = availablePlatforms.get(0);
+            if (ret.isPlaceholder()) {
+                logger.info("Found placeholder platform from: {}", ret);
+            }
+            return ret;
         }
-        return systemPlatform;
+        availablePlatforms.sort(
+                (o1, o2) -> {
+                    // A placeholder only wins when nothing else is available. Two
+                    // placeholders tie, which keeps the comparator symmetric.
+                    if (o1.isPlaceholder() || o2.isPlaceholder()) {
+                        return Boolean.compare(o1.isPlaceholder(), o2.isPlaceholder());
+                    }
+                    // cu121-precxx11 > cu121 > cu118-precxx11 > cpu-precxx11 > cpu
+                    int ret = o2.getFlavor().compareTo(o1.getFlavor());
+                    if (ret == 0) {
+                        return compareVersion(o2.getVersion(), o1.getVersion());
+                    }
+                    return ret;
+                });
+        return availablePlatforms.get(0);
     }
+
+    /**
+     * Compares two version strings segment by segment.
+     *
+     * <p>Segments made of digits only are compared as numbers, so {@code 2.10.0} is newer than
+     * {@code 2.9.0}; other segments fall back to {@code String.compareTo}. A {@code null}
+     * version is older than any other version.
+     *
+     * @param v1 the first version, may be {@code null}
+     * @param v2 the second version, may be {@code null}
+     * @return a negative value, zero or a positive value if {@code v1} is older than, equal to
+     *     or newer than {@code v2}
+     */
+    private static int compareVersion(String v1, String v2) {
+        if (v1 == null || v2 == null) {
+            return Boolean.compare(v1 != null, v2 != null);
+        }
+        String[] left = v1.split("[.-]");
+        String[] right = v2.split("[.-]");
+        int size = Math.min(left.length, right.length);
+        for (int i = 0; i < size; ++i) {
+            int ret;
+            if (left[i].matches("\\d{1,18}") && right[i].matches("\\d{1,18}")) {
+                ret = Long.compare(Long.parseLong(left[i]), Long.parseLong(right[i]));
+            } else {
+                ret = left[i].compareTo(right[i]);
+            }
+            if (ret != 0) {
+                return ret;
+            }
+        }
+        // All shared segments are equal: the version with more segments is the newer one
+        return Integer.compare(left.length, right.length);
+    }
 
     /**
diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
index 859596a6be6..b234eb479db 100644
--- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
+++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java
@@ -139,11 +139,15 @@ public static String getCudaVersionString() {
      */
     public static String getComputeCapability(int device) {
         if (Boolean.getBoolean("ai.djl.util.cuda.fork")) {
-            String[] ret = execute(device);
-            if (ret.length != 3) {
-                throw new IllegalArgumentException(ret[0]);
+            if (gpuInfo == null) { // NOPMD
+                gpuInfo = execute(-1);
             }
-            return ret[0];
+            // The compute capabilities start at index 2 of gpuInfo, so a negative device
+            // must be rejected too, otherwise an unrelated element would be returned
+            if (device < 0 || device >= gpuInfo.length - 2) {
+                throw new IllegalArgumentException("Invalid device: " + device);
+            }
+            return gpuInfo[device + 2];
         }
 
         if (LIB == null) {
@@ -214,7 +218,12 @@ public static void main(String[] args) {
                 return;
             }
             int cudaVersion = getCudaVersion();
-            System.out.println(gpuCount + "," + cudaVersion);
+            StringBuilder sb = new StringBuilder();
+            sb.append(gpuCount).append(',').append(cudaVersion);
+            for (int i = 0; i < gpuCount; ++i) {
+                sb.append(',').append(getComputeCapability(i));
+            }
+            System.out.println(sb);
             return;
         }
         try {
diff --git a/api/src/test/java/ai/djl/util/PlatformTest.java b/api/src/test/java/ai/djl/util/PlatformTest.java
index 039cbbea875..3e20609a0f9 100644
--- a/api/src/test/java/ai/djl/util/PlatformTest.java
+++ b/api/src/test/java/ai/djl/util/PlatformTest.java
@@ -12,16 +12,23 @@
  */
 package ai.djl.util;
 
+import ai.djl.util.cuda.CudaUtils;
+
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.Test;
 
 import java.io.BufferedWriter;
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.jar.JarEntry;
+import java.util.jar.JarOutputStream;
 
 public class PlatformTest {
 
@@ -94,6 +101,47 @@ public void testPlatform() throws IOException {
         Assert.assertFalse(platform.matches(system));
     }
 
+    /** Checks that the best matching native jar on the class path is selected. */
+    @Test
+    public void testDetectPlatform() throws IOException, ReflectiveOperationException {
+        Path dir = Paths.get("build/tmp/");
+        Files.createDirectories(dir);
+        Platform system = Platform.fromSystem();
+        String classifier = system.getClassifier();
+        createZipFile(0, "1.0", "cpu", classifier, true);
+        createZipFile(1, "1.0", "cpu", classifier, false);
+        createZipFile(2, "1.0", "cpu-precxx11", classifier, false);
+        createZipFile(3, "1.0", "cu117", classifier, false);
+        createZipFile(4, "1.0", "cu117-precxx11", classifier, false);
+        createZipFile(5, "1.0", "cu999", classifier, false);
+        createZipFile(6, "1.0", "cu999-precxx11", classifier, false);
+        createZipFile(7, "99.99", "cu999-precxx11", classifier, false);
+        URL[] urls = new URL[8];
+        for (int i = 0; i < 8; ++i) {
+            urls[i] = dir.resolve(i + ".jar").toUri().toURL();
+        }
+
+        Field field = CudaUtils.class.getDeclaredField("gpuInfo");
+        field.setAccessible(true);
+        Object savedGpuInfo = field.get(null);
+        ClassLoader savedLoader = Thread.currentThread().getContextClassLoader();
+        System.setProperty("ai.djl.util.cuda.fork", "true");
+        try (URLClassLoader cl = new URLClassLoader(urls)) {
+            // Injected GPU info: gpuCount, cudaVersion, compute capability of device 0
+            field.set(null, new String[] {"1", "99990", "90"});
+            Thread.currentThread().setContextClassLoader(cl);
+
+            Platform detected = Platform.detectPlatform("pytorch");
+            Assert.assertEquals(detected.getFlavor(), "cu999-precxx11");
+            Assert.assertEquals(detected.getVersion(), "99.99");
+        } finally {
+            // Undo every global side effect, even when an assertion above fails
+            Thread.currentThread().setContextClassLoader(savedLoader);
+            field.set(null, savedGpuInfo);
+            System.clearProperty("ai.djl.util.cuda.fork");
+        }
+    }
+
     private URL createPropertyFile(String content) throws IOException {
         Path dir = Paths.get("build/tmp/testFile/");
         Files.createDirectories(dir);
@@ -104,4 +152,25 @@ private URL createPropertyFile(String content) throws IOException {
         }
         return file.toUri().toURL();
     }
+
+    /** Writes build/tmp/{index}.jar containing a native/lib/pytorch.properties entry. */
+    private void createZipFile(
+            int index, String version, String flavor, String classifier, boolean placeHolder)
+            throws IOException {
+        Path file = Paths.get("build/tmp/" + index + ".jar");
+        try (JarOutputStream jos = new JarOutputStream(Files.newOutputStream(file))) {
+            JarEntry entry = new JarEntry("native/lib/pytorch.properties");
+            jos.putNextEntry(entry);
+            if (placeHolder) {
+                jos.write("placeholder=true\nversion=2.3.1".getBytes(StandardCharsets.UTF_8));
+            } else {
+                jos.write("version=".getBytes(StandardCharsets.UTF_8));
+                jos.write(version.getBytes(StandardCharsets.UTF_8));
+                jos.write("\nflavor=".getBytes(StandardCharsets.UTF_8));
+                jos.write(flavor.getBytes(StandardCharsets.UTF_8));
+                jos.write("\nclassifier=".getBytes(StandardCharsets.UTF_8));
+                jos.write(classifier.getBytes(StandardCharsets.UTF_8));
+            }
+        }
+    }
 }
diff --git a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
index df135938a3b..f6ecc66acc1 100644
--- a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
+++ b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java
@@ -55,6 +55,8 @@ public void testCudaUtilsWithFork() {
         System.setProperty("ai.djl.util.cuda.fork", "true");
         try {
             testCudaUtils();
+            CudaUtils.main(new String[0]);
+            CudaUtils.main(new String[] {"-1"});
         } finally {
             System.clearProperty("ai.djl.util.cuda.fork");
         }