Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Closes file stream #3130

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions api/src/main/java/ai/djl/repository/LocalRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

/**
* A {@code LocalRepository} is a {@link Repository} located in a filesystem directory.
Expand Down Expand Up @@ -92,29 +93,27 @@ public Artifact resolve(MRL mrl, Map<String, String> filter) throws IOException
@Override
public List<MRL> getResources() {
List<MRL> list = new ArrayList<>();
try {
Files.walk(path)
.forEach(
f -> {
if (f.endsWith("metadata.json") && Files.isRegularFile(f)) {
Path relative = path.relativize(f);
String type = relative.getName(0).toString();
try (Reader reader = Files.newBufferedReader(f)) {
Metadata metadata =
JsonUtils.GSON.fromJson(reader, Metadata.class);
Application application = metadata.getApplication();
String groupId = metadata.getGroupId();
String artifactId = metadata.getArtifactId();
if ("dataset".equals(type)) {
list.add(dataset(application, groupId, artifactId));
} else if ("model".equals(type)) {
list.add(model(application, groupId, artifactId));
}
} catch (IOException e) {
logger.warn("Failed to read metadata.json", e);
}
try (Stream<Path> stream = Files.walk(path)) {
stream.forEach(
f -> {
if (f.endsWith("metadata.json") && Files.isRegularFile(f)) {
Path relative = path.relativize(f);
String type = relative.getName(0).toString();
try (Reader reader = Files.newBufferedReader(f)) {
Metadata metadata = JsonUtils.GSON.fromJson(reader, Metadata.class);
Application application = metadata.getApplication();
String groupId = metadata.getGroupId();
String artifactId = metadata.getArtifactId();
if ("dataset".equals(type)) {
list.add(dataset(application, groupId, artifactId));
} else if ("model".equals(type)) {
list.add(model(application, groupId, artifactId));
}
});
} catch (IOException e) {
logger.warn("Failed to read metadata.json", e);
}
}
});
} catch (IOException e) {
logger.warn("", e);
}
Expand Down
35 changes: 18 additions & 17 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
Expand Down Expand Up @@ -62,10 +61,11 @@ public static <T> T findImplementation(Path path, Class<T> type, String classNam
// we only consider .class files and skip .java files
List<Path> jarFiles;
if (Files.isDirectory(path)) {
jarFiles =
Files.list(path)
.filter(p -> p.toString().endsWith(".jar"))
.collect(Collectors.toList());
try (Stream<Path> stream = Files.list(path)) {
jarFiles =
stream.filter(p -> p.toString().endsWith(".jar"))
.collect(Collectors.toList());
}
} else {
jarFiles = Collections.emptyList();
}
Expand Down Expand Up @@ -111,18 +111,19 @@ private static <T> T scanDirectory(ClassLoader cl, Class<T> type, Path dir) thro
logger.trace("Directory not exists: {}", dir);
return null;
}
Collection<Path> files =
Files.walk(dir)
.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, type, className);
if (implemented != null) {
return implemented;
try (Stream<Path> stream = Files.walk(dir)) {
List<Path> files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, type, className);
if (implemented != null) {
return implemented;
}
}
}
return null;
Expand Down
62 changes: 31 additions & 31 deletions api/src/main/java/ai/djl/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** A class containing utility methods. */
public final class Utils {
Expand Down Expand Up @@ -105,19 +106,18 @@
* @param dir the directory to be removed
*/
public static void deleteQuietly(Path dir) {
try {
Files.walk(dir)
.sorted(Comparator.reverseOrder())
.forEach(
path -> {
try {
Files.deleteIfExists(path);
} catch (IOException ignore) {
// ignore
}
});
List<Path> list;
try (Stream<Path> stream = Files.walk(dir)) {
list = stream.sorted(Comparator.reverseOrder()).collect(Collectors.toList());
} catch (IOException ignore) {
// ignore
return;
}
for (Path path : list) {
try {
Files.deleteIfExists(path);
} catch (IOException ignore) {
// ignore
}
}
}

Expand Down Expand Up @@ -255,23 +255,24 @@
*/
public static int getCurrentEpoch(Path modelDir, String modelName) throws IOException {
final Pattern pattern = Pattern.compile(Pattern.quote(modelName) + "-(\\d{4}).params");
List<Integer> checkpoints =
Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)
.map(
p -> {
Matcher m = pattern.matcher(p.toFile().getName());
if (m.matches()) {
return Integer.parseInt(m.group(1));
}
return null;
})
.filter(Objects::nonNull)
.sorted()
.collect(Collectors.toList());
if (checkpoints.isEmpty()) {
return -1;
try (Stream<Path> stream = Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)) {
List<Integer> checkpoints =
stream.map(
p -> {
Matcher m = pattern.matcher(p.toFile().getName());
if (m.matches()) {
return Integer.parseInt(m.group(1));
}
return null;
})
.filter(Objects::nonNull)
.sorted()
.collect(Collectors.toList());
if (checkpoints.isEmpty()) {
return -1;
}
return checkpoints.get(checkpoints.size() - 1);
}
return checkpoints.get(checkpoints.size() - 1);
}

/**
Expand Down Expand Up @@ -380,11 +381,10 @@
*/
public static Path getNestedModelDir(Path modelDir) {
if (Files.isDirectory(modelDir)) {
try {
try (Stream<Path> stream = Files.list(modelDir)) {
// handle actual model directory is subdirectory case
List<Path> files =
Files.list(modelDir)
.filter(p -> !p.getFileName().toString().startsWith("."))
stream.filter(p -> !p.getFileName().toString().startsWith("."))
.collect(Collectors.toList());
if (files.size() == 1 && Files.isDirectory(files.get(0))) {
return files.get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* {@code MxModel} is the MXNet implementation of {@link Model}.
Expand Down Expand Up @@ -176,9 +177,8 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
/** {@inheritDoc} */
@Override
public String[] getArtifactNames() {
try {
List<Path> files =
Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList());
try (Stream<Path> stream = Files.walk(modelDir)) {
List<Path> files = stream.filter(Files::isRegularFile).collect(Collectors.toList());
List<String> ret = new ArrayList<>(files.size());
for (Path path : files) {
String fileName = path.toFile().getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* {@code PtModel} is the PyTorch implementation of {@link Model}.
Expand Down Expand Up @@ -227,9 +228,8 @@ public Trainer newTrainer(TrainingConfig trainingConfig) {
/** {@inheritDoc} */
@Override
public String[] getArtifactNames() {
try {
List<Path> files =
Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList());
try (Stream<Path> stream = Files.walk(modelDir)) {
List<Path> files = stream.filter(Files::isRegularFile).collect(Collectors.toList());
List<String> ret = new ArrayList<>(files.size());
for (Path path : files) {
String fileName = path.toFile().getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.util.Utils;
Expand All @@ -36,6 +35,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** {@code TfModel} is the TensorFlow implementation of {@link Model}. */
public class TfModel extends BaseModel {
Expand Down Expand Up @@ -148,30 +148,17 @@ public void save(Path modelPath, String newModelName) {
throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
}

/** {@inheritDoc} */
@Override
public Block getBlock() {
return block;
}

/** {@inheritDoc} */
@Override
public void setBlock(Block block) {
throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
}

/** {@inheritDoc} */
@Override
public NDManager getNDManager() {
return manager;
}

/** {@inheritDoc} */
@Override
public String[] getArtifactNames() {
try {
List<Path> files =
Files.walk(modelDir).filter(Files::isRegularFile).collect(Collectors.toList());
try (Stream<Path> stream = Files.walk(modelDir)) {
List<Path> files = stream.filter(Files::isRegularFile).collect(Collectors.toList());
List<String> ret = new ArrayList<>(files.size());
for (Path path : files) {
String fileName = path.toFile().getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/** An example code dataset using the code within the DJL path. */
public class BertCodeDataset implements Dataset {
Expand Down Expand Up @@ -94,9 +95,8 @@ public int getMaxSequenceLength() {
}

private static List<Path> listSourceFiles(Path root) {
try {
return Files.walk(root)
.filter(Files::isRegularFile)
try (Stream<Path> stream = Files.walk(root)) {
return stream.filter(Files::isRegularFile)
.filter(path -> path.toString().toLowerCase().endsWith(".java"))
.collect(Collectors.toList());
} catch (IOException ioe) {
Expand Down
34 changes: 20 additions & 14 deletions integration/src/main/java/ai/djl/integration/IntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@SuppressWarnings("PMD.TestClassWithoutTestCases")
public class IntegrationTest {
Expand Down Expand Up @@ -219,20 +220,25 @@ private static List<Class<?>> listTestClasses(Arguments arguments, Class<?> claz

Path classPath = Paths.get(url.toURI());
if (Files.isDirectory(classPath)) {
Collection<Path> files =
Files.walk(classPath)
.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = classPath.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
if (className.startsWith(arguments.getPackageName()) && !className.contains("$")) {
try {
classList.add(Class.forName(className));
} catch (ExceptionInInitializerError ignore) {
// ignore
try (Stream<Path> stream = Files.walk(classPath)) {
Collection<Path> files =
stream.filter(
p ->
Files.isRegularFile(p)
&& p.toString().endsWith(".class"))
.collect(Collectors.toList());
for (Path file : files) {
Path p = classPath.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
if (className.startsWith(arguments.getPackageName())
&& !className.contains("$")) {
try {
classList.add(Class.forName(className));
} catch (ExceptionInInitializerError ignore) {
// ignore
}
}
}
}
Expand Down
Loading
Loading