diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java new file mode 100644 index 000000000000..b59da7020785 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java @@ -0,0 +1,188 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.Joints; +import ai.djl.modality.cv.output.Joints.Joint; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.TranslatorContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** A translator for Yolov8 pose estimation models. */ +public class YoloPoseTranslator extends BaseImageTranslator { + + private static final int MAX_DETECTION = 300; + + private int width; + private int height; + private float threshold; + private float nmsThreshold; + + /** + * Creates the Pose Estimation translator from the given builder. + * + * @param builder the builder for the translator + */ + public YoloPoseTranslator(Builder builder) { + super(builder); + this.width = builder.width; + this.height = builder.height; + this.threshold = builder.threshold; + this.nmsThreshold = builder.nmsThreshold; + } + + /** {@inheritDoc} */ + @Override + public Joints[] processOutput(TranslatorContext ctx, NDList list) { + NDArray pred = list.singletonOrThrow(); + NDArray candidates = pred.get(4).gt(threshold); + pred = pred.transpose(); + NDArray sub = pred.get("..., :4"); + sub = xywh2xyxy(sub); + pred = sub.concat(pred.get("..., 4:"), -1); + pred = pred.get(candidates); + + NDList split = pred.split(new long[] {4, 5}, 1); + NDArray box = split.get(0); + + int numBox = Math.toIntExact(box.getShape().get(0)); + + float[] buf = box.toFloatArray(); + float[] confidences = split.get(1).toFloatArray(); + float[] mask = split.get(2).toFloatArray(); + + List boxes = new ArrayList<>(numBox); + List scores = new ArrayList<>(numBox); + + for (int i = 0; i < numBox; ++i) { + float xPos = buf[i * 4]; + float yPos = buf[i * 4 + 1]; + float w = buf[i * 4 + 2] - xPos; + float h = buf[i * 4 + 3] - yPos; + Rectangle rect = new Rectangle(xPos, yPos, w, h); + boxes.add(rect); + scores.add((double) confidences[i]); + } + List nms = Rectangle.nms(boxes, scores, nmsThreshold); + if (nms.size() > MAX_DETECTION) { + nms = nms.subList(0, MAX_DETECTION); + } + Joints[] ret = new Joints[nms.size()]; + for (int i = 0; i < ret.length; ++i) { + List joints = new ArrayList<>(); + ret[i] = new Joints(joints); + + int index = nms.get(i); + int pos = index * 51; + for (int j = 0; j < 17; ++j) { + joints.add( + new Joints.Joint( + mask[pos + j * 3] / width, + mask[pos + j * 3 + 1] / height, + mask[pos + j * 3 + 2])); + } + } + return ret; + } + + private NDArray xywh2xyxy(NDArray array) { + NDArray xy = array.get("..., :2"); + NDArray wh = array.get("..., 2:").div(2); + return xy.sub(wh).concat(xy.add(wh), -1); + } + + /** + * Creates a builder to build a {@code YoloPoseTranslator}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a builder to build a {@code YoloPoseTranslator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static Builder builder(Map arguments) { + Builder builder = new Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** The builder for Pose Estimation translator. */ + public static class Builder extends BaseBuilder { + + float threshold = 0.25f; + float nmsThreshold = 0.7f; + + Builder() {} + + /** + * Sets the threshold for prediction accuracy. + * + *

Predictions below the threshold will be dropped. + * + * @param threshold the threshold for prediction accuracy + * @return the builder + */ + public Builder optThreshold(float threshold) { + this.threshold = threshold; + return self(); + } + + /** + * Sets the NMS threshold. + * + * @param nmsThreshold the NMS threshold + * @return this builder + */ + public Builder optNmsThreshold(float nmsThreshold) { + this.nmsThreshold = nmsThreshold; + return this; + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold)); + optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold)); + } + + /** + * Builds the translator. + * + * @return the new translator + */ + public YoloPoseTranslator build() { + validate(); + return new YoloPoseTranslator(this); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java new file mode 100644 index 000000000000..e27620acdf7b --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.Joints; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** An {@link TranslatorFactory} that creates a {@link YoloPoseTranslator} instance. */ +public class YoloPoseTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(Image.class, Joints[].class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + YoloPoseTranslator translator = YoloPoseTranslator.builder(arguments).build(); + if (input == Image.class && output == Joints[].class) { + return (Translator) translator; + } else if (input == Input.class && output == Output.class) { + return (Translator) new ImageServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } +} diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index eacfe74c8d32..2317a45b2259 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -39,6 +39,7 @@ public class PtModelZoo extends ModelZoo { GROUP_ID, "Human-Action-Recognition-VIT-Base-patch16-224", "0.0.1")); + addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1")); addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1")); addModel( REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1")); diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json new file mode 100644 index 000000000000..505256d982d3 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json @@ -0,0 +1,44 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/pose_estimation", + "groupId": "ai.djl.pytorch", + "artifactId": "yolov8n-pose", + "name": "Yolov8n pose", + "description": "Yolov8n Pose Estimation", + "website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n-pose", + "properties": { + }, + "arguments": { + "width": 640, + "height": 640, + "centerCrop": true, + "resize": true, + "threshold": 0.25, + "translatorFactory": "ai.djl.modality.cv.translator.YoloPoseTranslatorFactory" + }, + "options": { + "mapLocation": "true" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n-pose.zip", + "sha1Hash": "36bbeda33b2221a94b26de70c68ef78437d28861", + "name": "", + "size": 11684801 + } + } + } + ] +} diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java index 3a5866c7b550..9187c1f611b5 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java @@ -12,18 +12,14 @@ */ package ai.djl.examples.inference.cv; -import ai.djl.MalformedModelException; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Joints; -import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.translator.YoloPoseTranslatorFactory; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -33,9 +29,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; +import java.util.Arrays; /** * An example of inference using a pose estimation model. @@ -51,98 +45,39 @@ public final class PoseEstimation { private PoseEstimation() {} public static void main(String[] args) throws IOException, ModelException, TranslateException { - List joints = predict(); - logger.info("{}", joints); + Joints[] joints = predict(); + logger.info("{}", Arrays.toString(joints)); } - public static List predict() throws IOException, ModelException, TranslateException { + public static Joints[] predict() throws IOException, ModelException, TranslateException { Path imageFile = Paths.get("src/test/resources/pose_soccer.png"); Image img = ImageFactory.getInstance().fromFile(imageFile); - List people = predictPeopleInImage(img); - - if (people.isEmpty()) { - logger.warn("No people found in image."); - return Collections.emptyList(); - } - - return predictJointsForPeople(people); - } - - private static List predictPeopleInImage(Image img) - throws MalformedModelException, - ModelNotFoundException, - IOException, - TranslateException { - - Criteria criteria = - Criteria.builder() - .setTypes(Image.class, DetectedObjects.class) - .optModelUrls("djl://ai.djl.mxnet/ssd/0.0.1/ssd_512_resnet50_v1_voc") - .optEngine("MXNet") - .optProgress(new ProgressBar()) - .build(); - - DetectedObjects detectedObjects; - try (ZooModel ssd = criteria.loadModel(); - Predictor predictor = ssd.newPredictor()) { - detectedObjects = predictor.predict(img); - } - - List items = detectedObjects.items(); - List people = new ArrayList<>(); - for (DetectedObjects.DetectedObject item : items) { - if ("person".equals(item.getClassName())) { - Rectangle rect = item.getBoundingBox().getBounds(); - int width = img.getWidth(); - int height = img.getHeight(); - people.add( - img.getSubImage( - (int) (rect.getX() * width), - (int) (rect.getY() * height), - (int) (rect.getWidth() * width), - (int) (rect.getHeight() * height))); - } - } - return people; - } - - private static List predictJointsForPeople(List people) - throws MalformedModelException, - ModelNotFoundException, - IOException, - TranslateException { - - // Use DJL MXNet model zoo model, model can be found: - // https://mlrepo.djl.ai/model/cv/pose_estimation/ai/djl/mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b-0000.params.gz - // https://mlrepo.djl.ai/model/cv/pose_estimation/ai/djl/mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b-symbol.json - Criteria criteria = + // Use DJL PyTorch model zoo model + Criteria criteria = Criteria.builder() - .setTypes(Image.class, Joints.class) - .optModelUrls( - "djl://ai.djl.mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b") + .setTypes(Image.class, Joints[].class) + .optModelUrls("djl://ai.djl.pytorch/yolov8n-pose") + .optTranslatorFactory(new YoloPoseTranslatorFactory()) .build(); - List allJoints = new ArrayList<>(); - try (ZooModel pose = criteria.loadModel(); - Predictor predictor = pose.newPredictor()) { - int count = 0; - for (Image person : people) { - Joints joints = predictor.predict(person); - saveJointsImage(person, joints, count++); - allJoints.add(joints); - } + try (ZooModel pose = criteria.loadModel(); + Predictor predictor = pose.newPredictor()) { + Joints[] allJoints = predictor.predict(img); + saveJointsImage(img, allJoints); + return allJoints; } - return allJoints; } - private static void saveJointsImage(Image img, Joints joints, int count) throws IOException { + private static void saveJointsImage(Image img, Joints[] allJoints) throws IOException { Path outputDir = Paths.get("build/output"); Files.createDirectories(outputDir); - img.drawJoints(joints); + for (Joints joints : allJoints) { + img.drawJoints(joints); + } - Path imagePath = outputDir.resolve("joints-" + count + ".png"); + Path imagePath = outputDir.resolve("joints.png"); // Must use png format because you can't save as jpg with an alpha channel img.save(Files.newOutputStream(imagePath), "png"); logger.info("Pose image has been saved in: {}", imagePath); diff --git a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java index d0afdb90be12..398c71ced6e8 100644 --- a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java @@ -15,22 +15,19 @@ import ai.djl.ModelException; import ai.djl.examples.inference.cv.PoseEstimation; import ai.djl.modality.cv.output.Joints; -import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; import org.testng.Assert; import org.testng.annotations.Test; import java.io.IOException; -import java.util.List; public class PoseEstimationTest { @Test public void testPoseEstimation() throws ModelException, TranslateException, IOException { - TestRequirements.linux(); - - List result = PoseEstimation.predict(); - Assert.assertTrue(result.get(0).getJoints().get(0).getConfidence() > 0.6d); + Joints[] result = PoseEstimation.predict(); + Assert.assertEquals(result.length, 3); + Assert.assertEquals(result[0].getJoints().size(), 17); } }