Skip to content

Commit

Permalink
[api] Ignore hidden files for nested model directory (#1754)
Browse files Browse the repository at this point in the history
Change-Id: I298e136a2d8db1cc6a68e24ddc8dc26b7b2fbb4d
  • Loading branch information
frankfliu authored Jun 28, 2022
1 parent e69f23f commit 51b38b6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
11 changes: 1 addition & 10 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -224,15 +223,7 @@ public String getProperty(String key) {
}

protected void setModelDir(Path modelDir) {
if (Files.isDirectory(modelDir)) {
File[] files = modelDir.toFile().listFiles();
if (files != null && files.length == 1 && files[0].isDirectory()) {
// handle archive file contains folder name case
this.modelDir = files[0].toPath().toAbsolutePath();
return;
}
}
this.modelDir = modelDir.toAbsolutePath();
this.modelDir = Utils.getNestedModelDir(modelDir);
}

/** {@inheritDoc} */
Expand Down
24 changes: 24 additions & 0 deletions api/src/main/java/ai/djl/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,28 @@ public static Path getCacheDir() {
}
return Paths.get(cacheDir);
}

/**
* Returns nested model directory if the directory contains only one subdirectory.
*
* @param modelDir the model directory
* @return subdirectory if the model directory only contains one subdirectory
*/
public static Path getNestedModelDir(Path modelDir) {
if (Files.isDirectory(modelDir)) {
try {
// handle actual model directory is subdirectory case
List<Path> files =
Files.list(modelDir)
.filter(p -> !p.getFileName().toString().startsWith("."))
.collect(Collectors.toList());
if (files.size() == 1 && Files.isDirectory(files.get(0))) {
return files.get(0);
}
} catch (IOException e) {
throw new AssertionError("Failed to list files: " + modelDir, e);
}
}
return modelDir.toAbsolutePath();
}
}

0 comments on commit 51b38b6

Please sign in to comment.