Skip to content

Commit

Permalink
[api] Move compileJava() into ClassLoaderUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed May 14, 2023
1 parent aa61c10 commit df50871
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 137 deletions.
140 changes: 3 additions & 137 deletions api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

/** A {@link TranslatorFactory} that creates a generic {@link Translator}. */
public class ServingTranslatorFactory implements TranslatorFactory {
Expand Down Expand Up @@ -99,94 +84,9 @@ public <I, O> Translator<I, O> newInstance(
}

private ServingTranslator findTranslator(Path path, String className) {
try {
Path classesDir = path.resolve("classes");
compileJavaClass(classesDir);

List<Path> jarFiles = new ArrayList<>();
if (Files.isDirectory(path)) {
try (Stream<Path> stream = Files.list(path)) {
stream.forEach(
p -> {
if (p.toString().endsWith(".jar")) {
jarFiles.add(p);
}
});
}
}
List<URL> urls = new ArrayList<>(jarFiles.size() + 1);
urls.add(classesDir.toUri().toURL());
for (Path p : jarFiles) {
urls.add(p.toUri().toURL());
}

ClassLoader parentCl = ClassLoaderUtils.getContextClassLoader();
ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl);
if (className != null && !className.isEmpty()) {
logger.info("Trying to loading specified Translator: {}", className);
return initTranslator(cl, className);
}

ServingTranslator translator = scanDirectory(cl, classesDir);
if (translator != null) {
return translator;
}

for (Path p : jarFiles) {
translator = scanJarFile(cl, p);
if (translator != null) {
return translator;
}
}
} catch (IOException e) {
logger.debug("Failed to find Translator", e);
}
return null;
}

private ServingTranslator scanDirectory(ClassLoader cl, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return null;
}
Collection<Path> files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
}
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
ServingTranslator translator = initTranslator(cl, className);
if (translator != null) {
logger.info("Found translator in model directory: {}", className);
return translator;
}
}
return null;
}

private ServingTranslator scanJarFile(ClassLoader cl, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
JarEntry entry = en.nextElement();
String fileName = entry.getName();
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
ServingTranslator translator = initTranslator(cl, fileName);
if (translator != null) {
logger.info("Found translator {} in jar {}", fileName, path);
return translator;
}
}
}
}
return null;
Path classesDir = path.resolve("classes");
ClassLoaderUtils.compileJavaClass(classesDir);
return ClassLoaderUtils.findImplementation(path, ServingTranslator.class, className);
}

private TranslatorFactory loadTranslatorFactory(String className) {
Expand All @@ -201,18 +101,6 @@ private TranslatorFactory loadTranslatorFactory(String className) {
return null;
}

private ServingTranslator initTranslator(ClassLoader cl, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Class<? extends ServingTranslator> subclass = clazz.asSubclass(ServingTranslator.class);
Constructor<? extends ServingTranslator> constructor = subclass.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Translator: " + className, e);
}
return null;
}

private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments) {
String appName = ArgumentsUtil.stringValue(arguments, "application");
if (appName != null) {
Expand All @@ -228,26 +116,4 @@ private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments
private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
}

private void compileJavaClass(Path dir) {
try {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return;
}
String[] files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
.map(p -> p.toAbsolutePath().toString())
.toArray(String[]::new);
}
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (files.length > 0) {
compiler.run(null, null, null, files);
}
} catch (Throwable e) {
logger.warn("Failed to compile bundled java file", e);
}
}
}
31 changes: 31 additions & 0 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

/** A utility class that load classes from specific URLs. */
public final class ClassLoaderUtils {
Expand Down Expand Up @@ -229,4 +233,31 @@ public static void nativeLoad(String nativeHelper, String path) {
throw new IllegalArgumentException("Invalid native_helper: " + nativeHelper, e);
}
}

/**
* Tries to compile java classes in the directory.
*
* @param dir the directory to scan java file.
*/
public static void compileJavaClass(Path dir) {
try {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return;
}
String[] files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
.map(p -> p.toAbsolutePath().toString())
.toArray(String[]::new);
}
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (files.length > 0) {
compiler.run(null, null, null, files);
}
} catch (Throwable e) {
logger.warn("Failed to compile bundled java file", e);
}
}
}

0 comments on commit df50871

Please sign in to comment.