Skip to content

Commit

Permalink
[examples] update pose estimation example to detect joints for all pe…
Browse files Browse the repository at this point in the history
…ople (#2002)

* update pose estimation example to detect joints for all people

* Move image to S3

Change-Id: I12517d7eb08c5e345f7d891b179371f5993e86cb

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
siddvenk and frankfliu authored Sep 8, 2022
1 parent e885b6d commit bd45dff
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 28 deletions.
Binary file removed examples/docs/img/joints.png
Binary file not shown.
53 changes: 47 additions & 6 deletions examples/docs/pose_estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Pose estimation is a computer vision technique for determining the pose of an object in an image.

In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect dogs in an image.
In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect people and their joints in an image.

The source code can be found at [PoseEstimation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java).

Expand All @@ -28,8 +28,10 @@ cd examples
Your output should look like the following:

```text
[INFO ] - Pose image has been saved in: build/output/joints.png
[INFO ] -
[INFO ] - Pose image has been saved in: build/output/joints-0.png
[INFO ] - Pose image has been saved in: build/output/joints-1.png
[INFO ] - Pose image has been saved in: build/output/joints-2.png
[INFO ] - [
[ Joint [x=0.333, y=0.063], confidence: 0.6940,
Joint [x=0.333, y=0.031], confidence: 0.7182,
Joint [x=0.354, y=0.047], confidence: 0.4949,
Expand All @@ -47,9 +49,48 @@ Your output should look like the following:
Joint [x=0.625, y=0.719], confidence: 0.8233,
Joint [x=0.125, y=0.969], confidence: 0.7007,
Joint [x=0.958, y=0.844], confidence: 0.7480
]
],
[ Joint [x=0.354, y=0.125], confidence: 0.8993,
Joint [x=0.375, y=0.109], confidence: 0.9235,
Joint [x=0.354, y=0.109], confidence: 0.8176,
Joint [x=0.438, y=0.094], confidence: 0.9242,
Joint [x=0.458, y=0.094], confidence: 0.6368,
Joint [x=0.500, y=0.156], confidence: 0.8452,
Joint [x=0.688, y=0.156], confidence: 0.6121,
Joint [x=0.479, y=0.250], confidence: 0.9007,
Joint [x=0.854, y=0.234], confidence: 0.7352,
Joint [x=0.208, y=0.250], confidence: 0.7154,
Joint [x=0.958, y=0.313], confidence: 0.5030,
Joint [x=0.625, y=0.484], confidence: 0.6673,
Joint [x=0.500, y=0.500], confidence: 0.7583,
Joint [x=0.708, y=0.719], confidence: 0.7621,
Joint [x=0.271, y=0.641], confidence: 0.8008,
Joint [x=0.250, y=0.906], confidence: 0.8605
],
[ Joint [x=0.271, y=0.156], confidence: 0.8428,
Joint [x=0.292, y=0.141], confidence: 0.8469,
Joint [x=0.271, y=0.125], confidence: 0.8029,
Joint [x=0.333, y=0.141], confidence: 0.9200,
Joint [x=0.354, y=0.141], confidence: 0.4879,
Joint [x=0.542, y=0.250], confidence: 0.8573,
Joint [x=0.292, y=0.250], confidence: 0.8553,
Joint [x=0.771, y=0.359], confidence: 0.9046,
Joint [x=0.167, y=0.391], confidence: 0.6416,
Joint [x=0.854, y=0.469], confidence: 0.9166,
Joint [x=0.188, y=0.359], confidence: 0.6091,
Joint [x=0.458, y=0.563], confidence: 0.5665,
Joint [x=0.375, y=0.563], confidence: 0.5728,
Joint [x=0.146, y=0.750], confidence: 0.6888,
Joint [x=0.667, y=0.766], confidence: 0.7807,
Joint [x=0.000, y=0.938], confidence: 0.2272,
Joint [x=0.396, y=0.828], confidence: 0.4885
]]
```

An output image with the detected joints will be saved as build/output/joints.png:
Output images with the detected joints for each person will be saved in the build/output directory:

![joints](img/joints.png)
![joints-0](https://resources.djl.ai/images/joints-0.png)

![joints-1](https://resources.djl.ai/images/joints-1.png)

![joints-2](https://resources.djl.ai/images/joints-2.png)
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

Expand All @@ -51,25 +52,25 @@ public final class PoseEstimation {
private PoseEstimation() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
Joints joints = PoseEstimation.predict();
List<Joints> joints = PoseEstimation.predict();
logger.info("{}", joints);
}

public static Joints predict() throws IOException, ModelException, TranslateException {
public static List<Joints> predict() throws IOException, ModelException, TranslateException {
Path imageFile = Paths.get("src/test/resources/pose_soccer.png");
Image img = ImageFactory.getInstance().fromFile(imageFile);

Image person = predictPersonInImage(img);
List<Image> people = predictPeopleInImage(img);

if (person == null) {
logger.warn("No person found in image.");
return new Joints(Collections.emptyList());
if (people.isEmpty()) {
logger.warn("No people found in image.");
return Collections.emptyList();
}

return predictJointsInPerson(person);
return predictJointsForPeople(people);
}

private static Image predictPersonInImage(Image img)
private static List<Image> predictPeopleInImage(Image img)
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {

Expand All @@ -93,22 +94,24 @@ private static Image predictPersonInImage(Image img)
}

List<DetectedObjects.DetectedObject> items = detectedObjects.items();
List<Image> people = new ArrayList<>();
for (DetectedObjects.DetectedObject item : items) {
if ("person".equals(item.getClassName())) {
Rectangle rect = item.getBoundingBox().getBounds();
int width = img.getWidth();
int height = img.getHeight();
return img.getSubImage(
(int) (rect.getX() * width),
(int) (rect.getY() * height),
(int) (rect.getWidth() * width),
(int) (rect.getHeight() * height));
people.add(
img.getSubImage(
(int) (rect.getX() * width),
(int) (rect.getY() * height),
(int) (rect.getWidth() * width),
(int) (rect.getHeight() * height)));
}
}
return null;
return people;
}

private static Joints predictJointsInPerson(Image person)
private static List<Joints> predictJointsForPeople(List<Image> people)
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {

Expand All @@ -121,21 +124,26 @@ private static Joints predictJointsInPerson(Image person)
.optFilter("dataset", "imagenet")
.build();

List<Joints> allJoints = new ArrayList<>();
try (ZooModel<Image, Joints> pose = criteria.loadModel();
Predictor<Image, Joints> predictor = pose.newPredictor()) {
Joints joints = predictor.predict(person);
saveJointsImage(person, joints);
return joints;
int count = 0;
for (Image person : people) {
Joints joints = predictor.predict(person);
saveJointsImage(person, joints, count++);
allJoints.add(joints);
}
}
return allJoints;
}

private static void saveJointsImage(Image img, Joints joints) throws IOException {
private static void saveJointsImage(Image img, Joints joints, int count) throws IOException {
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

img.drawJoints(joints);

Path imagePath = outputDir.resolve("joints.png");
Path imagePath = outputDir.resolve("joints-" + count + ".png");
// Must use png format because you can't save as jpg with an alpha channel
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Pose image has been saved in: {}", imagePath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.util.List;

public class PoseEstimationTest {

@Test
public void testPoseEstimation() throws ModelException, TranslateException, IOException {
TestRequirements.engine("MXNet");

Joints result = PoseEstimation.predict();
Assert.assertTrue(result.getJoints().get(0).getConfidence() > 0.6d);
List<Joints> result = PoseEstimation.predict();
Assert.assertTrue(result.get(0).getJoints().get(0).getConfidence() > 0.6d);
}
}

0 comments on commit bd45dff

Please sign in to comment.