diff --git a/examples/docs/img/joints.png b/examples/docs/img/joints.png deleted file mode 100644 index eb4004a56fc..00000000000 Binary files a/examples/docs/img/joints.png and /dev/null differ diff --git a/examples/docs/pose_estimation.md b/examples/docs/pose_estimation.md index 394fc02ab99..e56a2a4e305 100644 --- a/examples/docs/pose_estimation.md +++ b/examples/docs/pose_estimation.md @@ -2,7 +2,7 @@ Pose estimation is a computer vision technique for determining the pose of an object in an image. -In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect dogs in an image. +In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect people and their joints in an image. The source code can be found at [PoseEstimation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java). @@ -28,8 +28,10 @@ cd examples Your output should look like the following: ```text -[INFO ] - Pose image has been saved in: build/output/joints.png -[INFO ] - +[INFO ] - Pose image has been saved in: build/output/joints-0.png +[INFO ] - Pose image has been saved in: build/output/joints-1.png +[INFO ] - Pose image has been saved in: build/output/joints-2.png +[INFO ] - [ [ Joint [x=0.333, y=0.063], confidence: 0.6940, Joint [x=0.333, y=0.031], confidence: 0.7182, Joint [x=0.354, y=0.047], confidence: 0.4949, @@ -47,9 +49,48 @@ Your output should look like the following: Joint [x=0.625, y=0.719], confidence: 0.8233, Joint [x=0.125, y=0.969], confidence: 0.7007, Joint [x=0.958, y=0.844], confidence: 0.7480 -] +], +[ Joint [x=0.354, y=0.125], confidence: 0.8993, + Joint [x=0.375, y=0.109], confidence: 0.9235, + Joint [x=0.354, y=0.109], confidence: 0.8176, + Joint [x=0.438, y=0.094], confidence: 0.9242, + Joint [x=0.458, y=0.094], confidence: 0.6368, + Joint [x=0.500, y=0.156], confidence: 0.8452, + Joint [x=0.688, y=0.156], confidence: 0.6121, + Joint [x=0.479, y=0.250], confidence: 0.9007, + Joint [x=0.854, y=0.234], confidence: 0.7352, + Joint [x=0.208, y=0.250], confidence: 0.7154, + Joint [x=0.958, y=0.313], confidence: 0.5030, + Joint [x=0.625, y=0.484], confidence: 0.6673, + Joint [x=0.500, y=0.500], confidence: 0.7583, + Joint [x=0.708, y=0.719], confidence: 0.7621, + Joint [x=0.271, y=0.641], confidence: 0.8008, + Joint [x=0.250, y=0.906], confidence: 0.8605 +], +[ Joint [x=0.271, y=0.156], confidence: 0.8428, + Joint [x=0.292, y=0.141], confidence: 0.8469, + Joint [x=0.271, y=0.125], confidence: 0.8029, + Joint [x=0.333, y=0.141], confidence: 0.9200, + Joint [x=0.354, y=0.141], confidence: 0.4879, + Joint [x=0.542, y=0.250], confidence: 0.8573, + Joint [x=0.292, y=0.250], confidence: 0.8553, + Joint [x=0.771, y=0.359], confidence: 0.9046, + Joint [x=0.167, y=0.391], confidence: 0.6416, + Joint [x=0.854, y=0.469], confidence: 0.9166, + Joint [x=0.188, y=0.359], confidence: 0.6091, + Joint [x=0.458, y=0.563], confidence: 0.5665, + Joint [x=0.375, y=0.563], confidence: 0.5728, + Joint [x=0.146, y=0.750], confidence: 0.6888, + Joint [x=0.667, y=0.766], confidence: 0.7807, + Joint [x=0.000, y=0.938], confidence: 0.2272, + Joint [x=0.396, y=0.828], confidence: 0.4885 +]] ``` -An output image with the detected joints will be saved as build/output/joints.png: +Output images with the detected joints for each person will be saved in the build/output directory: -![joints](img/joints.png) +![joints-0](https://resources.djl.ai/images/joints-0.png) + +![joints-1](https://resources.djl.ai/images/joints-1.png) + +![joints-2](https://resources.djl.ai/images/joints-2.png) diff --git a/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java b/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java index 080ede4af4d..5932fd3c169 100644 --- a/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java +++ b/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java @@ -34,6 +34,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -51,25 +52,25 @@ public final class PoseEstimation { private PoseEstimation() {} public static void main(String[] args) throws IOException, ModelException, TranslateException { - Joints joints = PoseEstimation.predict(); + List joints = PoseEstimation.predict(); logger.info("{}", joints); } - public static Joints predict() throws IOException, ModelException, TranslateException { + public static List predict() throws IOException, ModelException, TranslateException { Path imageFile = Paths.get("src/test/resources/pose_soccer.png"); Image img = ImageFactory.getInstance().fromFile(imageFile); - Image person = predictPersonInImage(img); + List people = predictPeopleInImage(img); - if (person == null) { - logger.warn("No person found in image."); - return new Joints(Collections.emptyList()); + if (people.isEmpty()) { + logger.warn("No people found in image."); + return Collections.emptyList(); } - return predictJointsInPerson(person); + return predictJointsForPeople(people); } - private static Image predictPersonInImage(Image img) + private static List predictPeopleInImage(Image img) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException { @@ -93,22 +94,24 @@ private static Image predictPersonInImage(Image img) } List items = detectedObjects.items(); + List people = new ArrayList<>(); for (DetectedObjects.DetectedObject item : items) { if ("person".equals(item.getClassName())) { Rectangle rect = item.getBoundingBox().getBounds(); int width = img.getWidth(); int height = img.getHeight(); - return img.getSubImage( - (int) (rect.getX() * width), - (int) (rect.getY() * height), - (int) (rect.getWidth() * width), - (int) (rect.getHeight() * height)); + people.add( + img.getSubImage( + (int) (rect.getX() * width), + (int) (rect.getY() * height), + (int) (rect.getWidth() * width), + (int) (rect.getHeight() * height))); } } - return null; + return people; } - private static Joints predictJointsInPerson(Image person) + private static List predictJointsForPeople(List people) throws MalformedModelException, ModelNotFoundException, IOException, TranslateException { @@ -121,21 +124,26 @@ private static Joints predictJointsInPerson(Image person) .optFilter("dataset", "imagenet") .build(); + List allJoints = new ArrayList<>(); try (ZooModel pose = criteria.loadModel(); Predictor predictor = pose.newPredictor()) { - Joints joints = predictor.predict(person); - saveJointsImage(person, joints); - return joints; + int count = 0; + for (Image person : people) { + Joints joints = predictor.predict(person); + saveJointsImage(person, joints, count++); + allJoints.add(joints); + } } + return allJoints; } - private static void saveJointsImage(Image img, Joints joints) throws IOException { + private static void saveJointsImage(Image img, Joints joints, int count) throws IOException { Path outputDir = Paths.get("build/output"); Files.createDirectories(outputDir); img.drawJoints(joints); - Path imagePath = outputDir.resolve("joints.png"); + Path imagePath = outputDir.resolve("joints-" + count + ".png"); // Must use png format because you can't save as jpg with an alpha channel img.save(Files.newOutputStream(imagePath), "png"); logger.info("Pose image has been saved in: {}", imagePath); diff --git a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java index 3709433affe..f963eb95615 100644 --- a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java @@ -21,6 +21,7 @@ import org.testng.annotations.Test; import java.io.IOException; +import java.util.List; public class PoseEstimationTest { @@ -28,7 +29,7 @@ public class PoseEstimationTest { public void testPoseEstimation() throws ModelException, TranslateException, IOException { TestRequirements.engine("MXNet"); - Joints result = PoseEstimation.predict(); - Assert.assertTrue(result.getJoints().get(0).getConfidence() > 0.6d); + List result = PoseEstimation.predict(); + Assert.assertTrue(result.get(0).getJoints().get(0).getConfidence() > 0.6d); } }