diff --git a/api/src/main/java/ai/djl/repository/LocalRepository.java b/api/src/main/java/ai/djl/repository/LocalRepository.java index d7d292936a6..a755d393ccc 100644 --- a/api/src/main/java/ai/djl/repository/LocalRepository.java +++ b/api/src/main/java/ai/djl/repository/LocalRepository.java @@ -27,6 +27,7 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.stream.Stream; /** * A {@code LocalRepository} is a {@link Repository} located in a filesystem directory. @@ -92,29 +93,27 @@ public Artifact resolve(MRL mrl, Map filter) throws IOException @Override public List getResources() { List list = new ArrayList<>(); - try { - Files.walk(path) - .forEach( - f -> { - if (f.endsWith("metadata.json") && Files.isRegularFile(f)) { - Path relative = path.relativize(f); - String type = relative.getName(0).toString(); - try (Reader reader = Files.newBufferedReader(f)) { - Metadata metadata = - JsonUtils.GSON.fromJson(reader, Metadata.class); - Application application = metadata.getApplication(); - String groupId = metadata.getGroupId(); - String artifactId = metadata.getArtifactId(); - if ("dataset".equals(type)) { - list.add(dataset(application, groupId, artifactId)); - } else if ("model".equals(type)) { - list.add(model(application, groupId, artifactId)); - } - } catch (IOException e) { - logger.warn("Failed to read metadata.json", e); - } + try (Stream stream = Files.walk(path)) { + stream.forEach( + f -> { + if (f.endsWith("metadata.json") && Files.isRegularFile(f)) { + Path relative = path.relativize(f); + String type = relative.getName(0).toString(); + try (Reader reader = Files.newBufferedReader(f)) { + Metadata metadata = JsonUtils.GSON.fromJson(reader, Metadata.class); + Application application = metadata.getApplication(); + String groupId = metadata.getGroupId(); + String artifactId = metadata.getArtifactId(); + if ("dataset".equals(type)) { + list.add(dataset(application, groupId, artifactId)); + } else if ("model".equals(type)) { + list.add(model(application, groupId, artifactId)); } - }); + } catch (IOException e) { + logger.warn("Failed to read metadata.json", e); + } + } + }); } catch (IOException e) { logger.warn("", e); } diff --git a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java index 47141fe8de6..dc5bfba8678 100644 --- a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java +++ b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java @@ -26,7 +26,6 @@ import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.Collection; import java.util.Collections; import java.util.Enumeration; import java.util.List; @@ -62,10 +61,11 @@ public static T findImplementation(Path path, Class type, String classNam // we only consider .class files and skip .java files List jarFiles; if (Files.isDirectory(path)) { - jarFiles = - Files.list(path) - .filter(p -> p.toString().endsWith(".jar")) - .collect(Collectors.toList()); + try (Stream stream = Files.list(path)) { + jarFiles = + stream.filter(p -> p.toString().endsWith(".jar")) + .collect(Collectors.toList()); + } } else { jarFiles = Collections.emptyList(); } @@ -111,18 +111,19 @@ private static T scanDirectory(ClassLoader cl, Class type, Path dir) thro logger.trace("Directory not exists: {}", dir); return null; } - Collection files = - Files.walk(dir) - .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) - .collect(Collectors.toList()); - for (Path file : files) { - Path p = dir.relativize(file); - String className = p.toString(); - className = className.substring(0, className.lastIndexOf('.')); - className = className.replace(File.separatorChar, '.'); - T implemented = initClass(cl, type, className); - if (implemented != null) { - return implemented; + try (Stream stream = Files.walk(dir)) { + List files = + stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = dir.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); + T implemented = initClass(cl, type, className); + if (implemented != null) { + return implemented; + } } } return null; diff --git a/api/src/main/java/ai/djl/util/Utils.java b/api/src/main/java/ai/djl/util/Utils.java index ec611c8ed91..99da664f97a 100644 --- a/api/src/main/java/ai/djl/util/Utils.java +++ b/api/src/main/java/ai/djl/util/Utils.java @@ -42,6 +42,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import java.util.stream.Stream; /** A class containing utility methods. */ public final class Utils { @@ -105,19 +106,18 @@ public static void pad(StringBuilder sb, char c, int count) { * @param dir the directory to be removed */ public static void deleteQuietly(Path dir) { - try { - Files.walk(dir) - .sorted(Comparator.reverseOrder()) - .forEach( - path -> { - try { - Files.deleteIfExists(path); - } catch (IOException ignore) { - // ignore - } - }); + List list; + try (Stream stream = Files.walk(dir)) { + list = stream.sorted(Comparator.reverseOrder()).collect(Collectors.toList()); } catch (IOException ignore) { - // ignore + return; + } + for (Path path : list) { + try { + Files.deleteIfExists(path); + } catch (IOException ignore) { + // ignore + } } } @@ -255,23 +255,24 @@ public static float[] toFloatArray(List list) { */ public static int getCurrentEpoch(Path modelDir, String modelName) throws IOException { final Pattern pattern = Pattern.compile(Pattern.quote(modelName) + "-(\\d{4}).params"); - List checkpoints = - Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS) - .map( - p -> { - Matcher m = pattern.matcher(p.toFile().getName()); - if (m.matches()) { - return Integer.parseInt(m.group(1)); - } - return null; - }) - .filter(Objects::nonNull) - .sorted() - .collect(Collectors.toList()); - if (checkpoints.isEmpty()) { - return -1; + try (Stream stream = Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)) { + List checkpoints = + stream.map( + p -> { + Matcher m = pattern.matcher(p.toFile().getName()); + if (m.matches()) { + return Integer.parseInt(m.group(1)); + } + return null; + }) + .filter(Objects::nonNull) + .sorted() + .collect(Collectors.toList()); + if (checkpoints.isEmpty()) { + return -1; + } + return checkpoints.get(checkpoints.size() - 1); } - return checkpoints.get(checkpoints.size() - 1); } /** @@ -380,11 +381,10 @@ public static boolean isOfflineMode() { */ public static Path getNestedModelDir(Path modelDir) { if (Files.isDirectory(modelDir)) { - try { + try (Stream stream = Files.list(modelDir)) { // handle actual model directory is subdirectory case List files = - Files.list(modelDir) - .filter(p -> !p.getFileName().toString().startsWith(".")) + stream.filter(p -> !p.getFileName().toString().startsWith(".")) .collect(Collectors.toList()); if (files.size() == 1 && Files.isDirectory(files.get(0))) { return files.get(0); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java index 81a27d01d1f..87099bf875e 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxModel.java @@ -42,6 +42,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * {@code MxModel} is the MXNet implementation of {@link Model}. @@ -176,9 +177,8 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { /** {@inheritDoc} */ @Override public String[] getArtifactNames() { - try { - List files = - Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList()); + try (Stream stream = Files.walk(modelDir)) { + List files = stream.filter(Files::isRegularFile).collect(Collectors.toList()); List ret = new ArrayList<>(files.size()); for (Path path : files) { String fileName = path.toFile().getName(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index e409918a091..6b5251c357c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -38,6 +38,7 @@ import java.util.Map; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * {@code PtModel} is the PyTorch implementation of {@link Model}. @@ -227,9 +228,8 @@ public Trainer newTrainer(TrainingConfig trainingConfig) { /** {@inheritDoc} */ @Override public String[] getArtifactNames() { - try { - List files = - Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList()); + try (Stream stream = Files.walk(modelDir)) { + List files = stream.filter(Files::isRegularFile).collect(Collectors.toList()); List ret = new ArrayList<>(files.size()); for (Path path : files) { String fileName = path.toFile().getName(); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java index 53f3981a8f6..25e54d4721f 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfModel.java @@ -16,7 +16,6 @@ import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.Model; -import ai.djl.ndarray.NDManager; import ai.djl.nn.Block; import ai.djl.tensorflow.engine.javacpp.JavacppUtils; import ai.djl.util.Utils; @@ -36,6 +35,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import java.util.stream.Stream; /** {@code TfModel} is the TensorFlow implementation of {@link Model}. */ public class TfModel extends BaseModel { @@ -148,30 +148,17 @@ public void save(Path modelPath, String newModelName) { throw new UnsupportedOperationException("Not supported for TensorFlow Engine"); } - /** {@inheritDoc} */ - @Override - public Block getBlock() { - return block; - } - /** {@inheritDoc} */ @Override public void setBlock(Block block) { throw new UnsupportedOperationException("Not supported for TensorFlow Engine"); } - /** {@inheritDoc} */ - @Override - public NDManager getNDManager() { - return manager; - } - /** {@inheritDoc} */ @Override public String[] getArtifactNames() { - try { - List files = - Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList()); + try (Stream stream = Files.walk(modelDir)) { + List files = stream.filter(Files::isRegularFile).collect(Collectors.toList()); List ret = new ArrayList<>(files.size()); for (Path path : files) { String fileName = path.toFile().getName(); diff --git a/examples/src/main/java/ai/djl/examples/training/util/BertCodeDataset.java b/examples/src/main/java/ai/djl/examples/training/util/BertCodeDataset.java index f83f5a3d1dc..d8d538d5749 100644 --- a/examples/src/main/java/ai/djl/examples/training/util/BertCodeDataset.java +++ b/examples/src/main/java/ai/djl/examples/training/util/BertCodeDataset.java @@ -38,6 +38,7 @@ import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; /** An example code dataset using the code within the DJL path. */ public class BertCodeDataset implements Dataset { @@ -94,9 +95,8 @@ public int getMaxSequenceLength() { } private static List listSourceFiles(Path root) { - try { - return Files.walk(root) - .filter(Files::isRegularFile) + try (Stream stream = Files.walk(root)) { + return stream.filter(Files::isRegularFile) .filter(path -> path.toString().toLowerCase().endsWith(".java")) .collect(Collectors.toList()); } catch (IOException ioe) { diff --git a/integration/src/main/java/ai/djl/integration/IntegrationTest.java b/integration/src/main/java/ai/djl/integration/IntegrationTest.java index a55cec9a22c..f3871ff2783 100644 --- a/integration/src/main/java/ai/djl/integration/IntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/IntegrationTest.java @@ -50,6 +50,7 @@ import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.stream.Collectors; +import java.util.stream.Stream; @SuppressWarnings("PMD.TestClassWithoutTestCases") public class IntegrationTest { @@ -219,20 +220,25 @@ private static List> listTestClasses(Arguments arguments, Class claz Path classPath = Paths.get(url.toURI()); if (Files.isDirectory(classPath)) { - Collection files = - Files.walk(classPath) - .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) - .collect(Collectors.toList()); - for (Path file : files) { - Path p = classPath.relativize(file); - String className = p.toString(); - className = className.substring(0, className.lastIndexOf('.')); - className = className.replace(File.separatorChar, '.'); - if (className.startsWith(arguments.getPackageName()) && !className.contains("$")) { - try { - classList.add(Class.forName(className)); - } catch (ExceptionInInitializerError ignore) { - // ignore + try (Stream stream = Files.walk(classPath)) { + Collection files = + stream.filter( + p -> + Files.isRegularFile(p) + && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = classPath.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); + if (className.startsWith(arguments.getPackageName()) + && !className.contains("$")) { + try { + classList.add(Class.forName(className)); + } catch (ExceptionInInitializerError ignore) { + // ignore + } } } } diff --git a/testing/src/main/java/ai/djl/testing/CoverageUtils.java b/testing/src/main/java/ai/djl/testing/CoverageUtils.java index 42cdecb5fbd..a409e74f150 100644 --- a/testing/src/main/java/ai/djl/testing/CoverageUtils.java +++ b/testing/src/main/java/ai/djl/testing/CoverageUtils.java @@ -38,6 +38,7 @@ import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.stream.Collectors; +import java.util.stream.Stream; @SuppressWarnings({"PMD.AvoidAccessibilityAlteration", "PMD.TestClassWithoutTestCases"}) public final class CoverageUtils { @@ -125,20 +126,24 @@ private static List> getClasses(Class clazz) Path classPath = Paths.get(url.toURI()); if (Files.isDirectory(classPath)) { - Collection files = - Files.walk(classPath) - .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) - .collect(Collectors.toList()); - for (Path file : files) { - Path p = classPath.relativize(file); - String className = p.toString(); - className = className.substring(0, className.lastIndexOf('.')); - className = className.replace(File.separatorChar, '.'); + try (Stream stream = Files.walk(classPath)) { + Collection files = + stream.filter( + p -> + Files.isRegularFile(p) + && p.toString().endsWith(".class")) + .collect(Collectors.toList()); + for (Path file : files) { + Path p = classPath.relativize(file); + String className = p.toString(); + className = className.substring(0, className.lastIndexOf('.')); + className = className.replace(File.separatorChar, '.'); - try { - classList.add(Class.forName(className, true, cl)); - } catch (Throwable ignore) { - // ignore + try { + classList.add(Class.forName(className, true, cl)); + } catch (Throwable ignore) { + // ignore + } } } } else if (path.toLowerCase().endsWith(".jar")) {