From 6192fd6438cde6b603a1a78ababb4591c910563d Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 2 Jul 2024 15:11:23 -0700 Subject: [PATCH 1/2] [api] Reactor nms for yolo translator --- .../ai/djl/modality/cv/output/Rectangle.java | 78 +++++++++ .../cv/translator/YoloV5Translator.java | 161 ++++-------------- .../cv/translator/YoloV8Translator.java | 14 +- 3 files changed, 122 insertions(+), 131 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java index 92afc603272..b7b16127a00 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java @@ -14,6 +14,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.PriorityQueue; /** * A {@code Rectangle} specifies an area in a coordinate space that is enclosed by the {@code @@ -152,4 +153,81 @@ public String toString() { return String.format( "{\"x\"=%.3f, \"y\"=%.3f, \"width\"=%.3f, \"height\"=%.3f}", x, y, width, height); } + + /** + * Applies nms (non-maximum suppression) to the list of rectangles. + * + * @param boxes an list of {@code Rectangle} + * @param scores a list of scores + * @param nmsThreshold the nms threshold + * @return the filtered list with the index of the original list + */ + public static List nms( + List boxes, List scores, float nmsThreshold) { + List ret = new ArrayList<>(); + PriorityQueue pq = + new PriorityQueue<>( + 50, + (lhs, rhs) -> { + // Intentionally reversed to put high confidence at the head of the + // queue. + return Double.compare(scores.get(rhs), scores.get(lhs)); + }); + for (int i = 0; i < boxes.size(); ++i) { + pq.add(i); + } + + // do non maximum suppression + while (!pq.isEmpty()) { + // insert detection with max confidence + int[] detections = pq.stream().mapToInt(Integer::intValue).toArray(); + ret.add(detections[0]); + Rectangle box = boxes.get(detections[0]); + pq.clear(); + for (int i = 1; i < detections.length; i++) { + int detection = detections[i]; + Rectangle location = boxes.get(detection); + if (box.boxIou(location) < nmsThreshold) { + pq.add(detection); + } + } + } + return ret; + } + + private double boxIou(Rectangle other) { + double intersection = intersection(other); + double union = + getWidth() * getHeight() + other.getWidth() * other.getHeight() - intersection; + return intersection / union; + } + + private double intersection(Rectangle b) { + double w = + overlap( + (getX() * 2 + getWidth()) / 2, + getWidth(), + (b.getX() * 2 + b.getWidth()) / 2, + b.getWidth()); + double h = + overlap( + (getY() * 2 + getHeight()) / 2, + getHeight(), + (b.getY() * 2 + b.getHeight()) / 2, + b.getHeight()); + if (w < 0 || h < 0) { + return 0; + } + return w * h; + } + + private double overlap(double x1, double w1, double x2, double w2) { + double l1 = x1 - w1 / 2; + double l2 = x2 - w2 / 2; + double left = Math.max(l1, l2); + double r1 = x1 + w1 / 2; + double r2 = x2 + w2 / 2; + double right = Math.min(r1, r2); + return right - left; + } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index c31353766d3..67682d7c8b7 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.PriorityQueue; /** * A translator 
for YoloV5 models. This was tested with ONNX exported Yolo models. For details check @@ -68,104 +67,58 @@ public static YoloV5Translator.Builder builder(Map arguments) { return builder; } - protected double boxIntersection(Rectangle a, Rectangle b) { - double w = - overlap( - (a.getX() * 2 + a.getWidth()) / 2, - a.getWidth(), - (b.getX() * 2 + b.getWidth()) / 2, - b.getWidth()); - double h = - overlap( - (a.getY() * 2 + a.getHeight()) / 2, - a.getHeight(), - (b.getY() * 2 + b.getHeight()) / 2, - b.getHeight()); - if (w < 0 || h < 0) { - return 0; - } - return w * h; - } - - protected double boxIou(Rectangle a, Rectangle b) { - return boxIntersection(a, b) / boxUnion(a, b); - } - - protected double boxUnion(Rectangle a, Rectangle b) { - double i = boxIntersection(a, b); - return (a.getWidth()) * (a.getHeight()) + (b.getWidth()) * (b.getHeight()) - i; - } - - protected DetectedObjects nms(List list) { + protected DetectedObjects nms( + List boxes, List classIds, List scores) { List retClasses = new ArrayList<>(); List retProbs = new ArrayList<>(); List retBB = new ArrayList<>(); - for (int k = 0; k < classes.size(); k++) { - // 1.find max confidence per class - PriorityQueue pq = - new PriorityQueue<>( - 50, - (lhs, rhs) -> { - // Intentionally reversed to put high confidence at the head of the - // queue. - return Double.compare(rhs.getConfidence(), lhs.getConfidence()); - }); - - for (IntermediateResult intermediateResult : list) { - if (intermediateResult.getDetectedClass() == k) { - pq.add(intermediateResult); + for (int classId = 0; classId < classes.size(); classId++) { + List r = new ArrayList<>(); + List s = new ArrayList<>(); + List map = new ArrayList<>(); + for (int j = 0; j < classIds.size(); ++j) { + if (classIds.get(j) == classId) { + r.add(boxes.get(j)); + s.add(scores.get(j).doubleValue()); + map.add(j); } } - - // 2.do non maximum suppression - while (pq.size() > 0) { - // insert detection with max confidence - IntermediateResult[] a = new IntermediateResult[pq.size()]; - IntermediateResult[] detections = pq.toArray(a); - Rectangle rec = detections[0].getLocation(); - retClasses.add(detections[0].id); - retProbs.add(detections[0].confidence); + if (r.isEmpty()) { + continue; + } + List nms = Rectangle.nms(r, s, nmsThreshold); + for (int index : nms) { + int pos = map.get(index); + int id = classIds.get(pos); + retClasses.add(classes.get(id)); + retProbs.add(scores.get(pos).doubleValue()); + Rectangle rect = boxes.get(pos); if (applyRatio) { retBB.add( new Rectangle( - rec.getX() / imageWidth, - rec.getY() / imageHeight, - rec.getWidth() / imageWidth, - rec.getHeight() / imageHeight)); + rect.getX() / imageWidth, + rect.getY() / imageHeight, + rect.getWidth() / imageWidth, + rect.getHeight() / imageHeight)); } else { - retBB.add( - new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight())); - } - pq.clear(); - for (int j = 1; j < detections.length; j++) { - IntermediateResult detection = detections[j]; - Rectangle location = detection.getLocation(); - if (boxIou(rec, location) < nmsThreshold) { - pq.add(detection); - } + retBB.add(rect); } } } return new DetectedObjects(retClasses, retProbs, retBB); } - protected double overlap(double x1, double w1, double x2, double w2) { - double l1 = x1 - w1 / 2; - double l2 = x2 - w2 / 2; - double left = Math.max(l1, l2); - double r1 = x1 + w1 / 2; - double r2 = x2 + w2 / 2; - double right = Math.min(r1, r2); - return right - left; - } - protected DetectedObjects processFromBoxOutput(NDList list) { float[] flattened = 
list.get(0).toFloatArray(); - ArrayList intermediateResults = new ArrayList<>(); int sizeClasses = classes.size(); int stride = 5 + sizeClasses; int size = flattened.length / stride; + + ArrayList boxes = new ArrayList<>(); + ArrayList scores = new ArrayList<>(); + ArrayList classIds = new ArrayList<>(); + for (int i = 0; i < size; i++) { int indexBase = i * stride; float maxClass = 0; @@ -184,11 +137,12 @@ protected DetectedObjects processFromBoxOutput(NDList list) { float h = flattened[indexBase + 3]; Rectangle rect = new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); - intermediateResults.add( - new IntermediateResult(classes.get(maxIndex), score, maxIndex, rect)); + boxes.add(rect); + scores.add(score); + classIds.add(maxIndex); } } - return nms(intermediateResults); + return nms(boxes, classIds, scores); } private DetectedObjects processFromDetectOutput() { @@ -279,49 +233,4 @@ public YoloV5Translator build() { return new YoloV5Translator(this); } } - - protected static final class IntermediateResult { - - /** - * A sortable score for how good the recognition is relative to others. Higher should be - * better. - */ - private double confidence; - - /** Display name for the recognition. */ - private int detectedClass; - - /** - * A unique identifier for what has been recognized. Specific to the class, not the instance - * of the object. - */ - private String id; - - /** Optional location within the source image for the location of the recognized object. */ - private Rectangle location; - - IntermediateResult(String id, double confidence, int detectedClass, Rectangle location) { - this.confidence = confidence; - this.id = id; - this.detectedClass = detectedClass; - this.location = location; - } - - public double getConfidence() { - return confidence; - } - - public int getDetectedClass() { - return detectedClass; - } - - public String getId() { - return id; - } - - public Rectangle getLocation() { - return new Rectangle( - location.getX(), location.getY(), location.getWidth(), location.getHeight()); - } - } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java index 7fb1fda627a..7895ad15c07 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -79,7 +79,10 @@ protected DetectedObjects processFromBoxOutput(NDList list) { "Expected classes: " + (nClasses - 4) + ", got " + classes.size()); } - ArrayList intermediateResults = new ArrayList<>(); + ArrayList boxes = new ArrayList<>(); + ArrayList scores = new ArrayList<>(); + ArrayList classIds = new ArrayList<>(); + // reverse order search in heap; searches through #maxBoxes for optimization when set for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) { int index = i * nClasses; @@ -101,12 +104,13 @@ protected DetectedObjects processFromBoxOutput(NDList list) { float h = buf[index + 3]; Rectangle rect = new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h); - intermediateResults.add( - new IntermediateResult( - classes.get(maxIndex), maxClassProb, maxIndex, rect)); + boxes.add(rect); + scores.add(maxClassProb); + classIds.add(maxIndex); } } - return nms(intermediateResults); + + return nms(boxes, classIds, scores); } /** The builder for {@link YoloV8Translator}. 
*/ From 6911840c1f84dd180db03b0b9fb30cb8f7632d3c Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 30 Jun 2024 09:53:44 -0700 Subject: [PATCH 2/2] [pytorch] Adds yolov8n pose estimation model --- .../cv/translator/YoloPoseTranslator.java | 188 ++++++++++++++++++ .../translator/YoloPoseTranslatorFactory.java | 61 ++++++ .../java/ai/djl/pytorch/zoo/PtModelZoo.java | 1 + .../ai/djl/pytorch/yolov8n-pose/metadata.json | 43 ++++ .../examples/inference/cv/PoseEstimation.java | 105 ++-------- .../inference/PoseEstimationTest.java | 9 +- 6 files changed, 316 insertions(+), 91 deletions(-) create mode 100644 api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java create mode 100644 api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java create mode 100644 engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java new file mode 100644 index 00000000000..b59da702078 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java @@ -0,0 +1,188 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.output.Joints; +import ai.djl.modality.cv.output.Joints.Joint; +import ai.djl.modality.cv.output.Rectangle; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.TranslatorContext; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** A translator for Yolov8 pose estimation models. */ +public class YoloPoseTranslator extends BaseImageTranslator { + + private static final int MAX_DETECTION = 300; + + private int width; + private int height; + private float threshold; + private float nmsThreshold; + + /** + * Creates the Pose Estimation translator from the given builder. 
+ * + * @param builder the builder for the translator + */ + public YoloPoseTranslator(Builder builder) { + super(builder); + this.width = builder.width; + this.height = builder.height; + this.threshold = builder.threshold; + this.nmsThreshold = builder.nmsThreshold; + } + + /** {@inheritDoc} */ + @Override + public Joints[] processOutput(TranslatorContext ctx, NDList list) { + NDArray pred = list.singletonOrThrow(); + NDArray candidates = pred.get(4).gt(threshold); + pred = pred.transpose(); + NDArray sub = pred.get("..., :4"); + sub = xywh2xyxy(sub); + pred = sub.concat(pred.get("..., 4:"), -1); + pred = pred.get(candidates); + + NDList split = pred.split(new long[] {4, 5}, 1); + NDArray box = split.get(0); + + int numBox = Math.toIntExact(box.getShape().get(0)); + + float[] buf = box.toFloatArray(); + float[] confidences = split.get(1).toFloatArray(); + float[] mask = split.get(2).toFloatArray(); + + List boxes = new ArrayList<>(numBox); + List scores = new ArrayList<>(numBox); + + for (int i = 0; i < numBox; ++i) { + float xPos = buf[i * 4]; + float yPos = buf[i * 4 + 1]; + float w = buf[i * 4 + 2] - xPos; + float h = buf[i * 4 + 3] - yPos; + Rectangle rect = new Rectangle(xPos, yPos, w, h); + boxes.add(rect); + scores.add((double) confidences[i]); + } + List nms = Rectangle.nms(boxes, scores, nmsThreshold); + if (nms.size() > MAX_DETECTION) { + nms = nms.subList(0, MAX_DETECTION); + } + Joints[] ret = new Joints[nms.size()]; + for (int i = 0; i < ret.length; ++i) { + List joints = new ArrayList<>(); + ret[i] = new Joints(joints); + + int index = nms.get(i); + int pos = index * 51; + for (int j = 0; j < 17; ++j) { + joints.add( + new Joints.Joint( + mask[pos + j * 3] / width, + mask[pos + j * 3 + 1] / height, + mask[pos + j * 3 + 2])); + } + } + return ret; + } + + private NDArray xywh2xyxy(NDArray array) { + NDArray xy = array.get("..., :2"); + NDArray wh = array.get("..., 2:").div(2); + return xy.sub(wh).concat(xy.add(wh), -1); + } + + /** + * Creates a builder to build a {@code YoloPoseTranslator}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a builder to build a {@code YoloPoseTranslator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static Builder builder(Map arguments) { + Builder builder = new Builder(); + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** The builder for Pose Estimation translator. */ + public static class Builder extends BaseBuilder { + + float threshold = 0.25f; + float nmsThreshold = 0.7f; + + Builder() {} + + /** + * Sets the threshold for prediction accuracy. + * + *
<p>
Predictions below the threshold will be dropped. + * + * @param threshold the threshold for prediction accuracy + * @return the builder + */ + public Builder optThreshold(float threshold) { + this.threshold = threshold; + return self(); + } + + /** + * Sets the NMS threshold. + * + * @param nmsThreshold the NMS threshold + * @return this builder + */ + public Builder optNmsThreshold(float nmsThreshold) { + this.nmsThreshold = nmsThreshold; + return this; + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold)); + optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold)); + } + + /** + * Builds the translator. + * + * @return the new translator + */ + public YoloPoseTranslator build() { + validate(); + return new YoloPoseTranslator(this); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java new file mode 100644 index 00000000000..e27620acdf7 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslatorFactory.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.Joints; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** An {@link TranslatorFactory} that creates a {@link YoloPoseTranslator} instance. 
*/ +public class YoloPoseTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(Image.class, Joints[].class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + YoloPoseTranslator translator = YoloPoseTranslator.builder(arguments).build(); + if (input == Image.class && output == Joints[].class) { + return (Translator) translator; + } else if (input == Input.class && output == Output.class) { + return (Translator) new ImageServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } +} diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index eacfe74c8d3..336e72e98d8 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -45,6 +45,7 @@ public class PtModelZoo extends ModelZoo { addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); + addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1")); addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1")); addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1")); addModel(REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1")); diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json new file mode 100644 index 00000000000..62741a61ef8 --- /dev/null +++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/pose_estimation/ai/djl/pytorch/yolov8n-pose/metadata.json @@ -0,0 +1,43 @@ +{ + "metadataVersion": "0.2", + "resourceType": "model", + "application": "cv/pose_estimation", + "groupId": "ai.djl.pytorch", + "artifactId": "yolov8n-pose", + "name": "Yolov8n pose", + "description": "Yolov8n Pose Estimation", + "website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "yolov8n-pose", + "properties": { + }, + "arguments": { + "width": 640, + "height": 640, + "resize": true, + "threshold": 0.25, + "translatorFactory": "ai.djl.modality.cv.translator.YoloPoseTranslatorFactory" + }, + "options": { + "mapLocation": "true" + }, + "files": { + "model": { + "uri": "0.0.1/yolov8n-pose.zip", + "sha1Hash": "b2358198586ea35189f30aa5f49285dc1842ec1b", + "name": "", + "size": 11684792 + } + } + } + ] +} diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java 
b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java index 3a5866c7b55..9187c1f611b 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java @@ -12,18 +12,14 @@ */ package ai.djl.examples.inference.cv; -import ai.djl.MalformedModelException; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Joints; -import ai.djl.modality.cv.output.Rectangle; +import ai.djl.modality.cv.translator.YoloPoseTranslatorFactory; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; -import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -33,9 +29,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; +import java.util.Arrays; /** * An example of inference using a pose estimation model. @@ -51,98 +45,39 @@ public final class PoseEstimation { private PoseEstimation() {} public static void main(String[] args) throws IOException, ModelException, TranslateException { - List joints = predict(); - logger.info("{}", joints); + Joints[] joints = predict(); + logger.info("{}", Arrays.toString(joints)); } - public static List predict() throws IOException, ModelException, TranslateException { + public static Joints[] predict() throws IOException, ModelException, TranslateException { Path imageFile = Paths.get("src/test/resources/pose_soccer.png"); Image img = ImageFactory.getInstance().fromFile(imageFile); - List people = predictPeopleInImage(img); - - if (people.isEmpty()) { - logger.warn("No people found in image."); - return Collections.emptyList(); - } - - return predictJointsForPeople(people); - } - - private static List predictPeopleInImage(Image img) - throws MalformedModelException, - ModelNotFoundException, - IOException, - TranslateException { - - Criteria criteria = - Criteria.builder() - .setTypes(Image.class, DetectedObjects.class) - .optModelUrls("djl://ai.djl.mxnet/ssd/0.0.1/ssd_512_resnet50_v1_voc") - .optEngine("MXNet") - .optProgress(new ProgressBar()) - .build(); - - DetectedObjects detectedObjects; - try (ZooModel ssd = criteria.loadModel(); - Predictor predictor = ssd.newPredictor()) { - detectedObjects = predictor.predict(img); - } - - List items = detectedObjects.items(); - List people = new ArrayList<>(); - for (DetectedObjects.DetectedObject item : items) { - if ("person".equals(item.getClassName())) { - Rectangle rect = item.getBoundingBox().getBounds(); - int width = img.getWidth(); - int height = img.getHeight(); - people.add( - img.getSubImage( - (int) (rect.getX() * width), - (int) (rect.getY() * height), - (int) (rect.getWidth() * width), - (int) (rect.getHeight() * height))); - } - } - return people; - } - - private static List predictJointsForPeople(List people) - throws MalformedModelException, - ModelNotFoundException, - IOException, - TranslateException { - - // Use DJL MXNet model zoo model, model can be found: - // https://mlrepo.djl.ai/model/cv/pose_estimation/ai/djl/mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b-0000.params.gz - // 
https://mlrepo.djl.ai/model/cv/pose_estimation/ai/djl/mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b-symbol.json - Criteria criteria = + // Use DJL PyTorch model zoo model + Criteria criteria = Criteria.builder() - .setTypes(Image.class, Joints.class) - .optModelUrls( - "djl://ai.djl.mxnet/simple_pose/0.0.1/simple_pose_resnet18_v1b") + .setTypes(Image.class, Joints[].class) + .optModelUrls("djl://ai.djl.pytorch/yolov8n-pose") + .optTranslatorFactory(new YoloPoseTranslatorFactory()) .build(); - List allJoints = new ArrayList<>(); - try (ZooModel pose = criteria.loadModel(); - Predictor predictor = pose.newPredictor()) { - int count = 0; - for (Image person : people) { - Joints joints = predictor.predict(person); - saveJointsImage(person, joints, count++); - allJoints.add(joints); - } + try (ZooModel pose = criteria.loadModel(); + Predictor predictor = pose.newPredictor()) { + Joints[] allJoints = predictor.predict(img); + saveJointsImage(img, allJoints); + return allJoints; } - return allJoints; } - private static void saveJointsImage(Image img, Joints joints, int count) throws IOException { + private static void saveJointsImage(Image img, Joints[] allJoints) throws IOException { Path outputDir = Paths.get("build/output"); Files.createDirectories(outputDir); - img.drawJoints(joints); + for (Joints joints : allJoints) { + img.drawJoints(joints); + } - Path imagePath = outputDir.resolve("joints-" + count + ".png"); + Path imagePath = outputDir.resolve("joints.png"); // Must use png format because you can't save as jpg with an alpha channel img.save(Files.newOutputStream(imagePath), "png"); logger.info("Pose image has been saved in: {}", imagePath); diff --git a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java index d0afdb90be1..398c71ced6e 100644 --- a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java @@ -15,22 +15,19 @@ import ai.djl.ModelException; import ai.djl.examples.inference.cv.PoseEstimation; import ai.djl.modality.cv.output.Joints; -import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; import org.testng.Assert; import org.testng.annotations.Test; import java.io.IOException; -import java.util.List; public class PoseEstimationTest { @Test public void testPoseEstimation() throws ModelException, TranslateException, IOException { - TestRequirements.linux(); - - List result = PoseEstimation.predict(); - Assert.assertTrue(result.get(0).getJoints().get(0).getConfidence() > 0.6d); + Joints[] result = PoseEstimation.predict(); + Assert.assertEquals(result.length, 3); + Assert.assertEquals(result[0].getJoints().size(), 17); } }
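Note (not part of the patch): the first commit consolidates the duplicated IoU/overlap math from YoloV5Translator into a single Rectangle.nms utility that the V5, V8 and pose translators all call. A minimal sketch of exercising that utility directly is shown below. The generic type parameters (List<Rectangle>, List<Double>, List<Integer>) are reconstructed from context because they are not legible in this copy of the patch, and the NmsSketch class is purely illustrative; the method returns the indices of the surviving boxes, ordered by descending score.

    import ai.djl.modality.cv.output.Rectangle;

    import java.util.Arrays;
    import java.util.List;

    public final class NmsSketch {

        public static void main(String[] args) {
            // Two heavily overlapping boxes and one separate box, each given as (x, y, width, height).
            List<Rectangle> boxes =
                    Arrays.asList(
                            new Rectangle(10, 10, 100, 100),
                            new Rectangle(12, 8, 100, 102),
                            new Rectangle(300, 300, 50, 50));
            List<Double> scores = Arrays.asList(0.9, 0.6, 0.8);

            // A box whose IoU with an already-kept, higher-scoring box reaches 0.5 is dropped.
            List<Integer> keep = Rectangle.nms(boxes, scores, 0.5f);

            // The second box overlaps the first almost completely, so only indices 0 and 2 remain.
            System.out.println(keep); // [0, 2]
        }
    }

Keeping the suppression logic in Rectangle also means the translators no longer need the per-class PriorityQueue bookkeeping that the removed IntermediateResult helper existed for.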
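Note (not part of the patch): the second commit registers the yolov8n-pose artifact with threshold 0.25 in its metadata, while YoloPoseTranslator.Builder.configPostProcess reads "threshold" and "nmsThreshold" from the translator arguments (defaults 0.25f and 0.7f). A rough sketch of overriding those values at load time is below; the PoseThresholdSketch class name and the 0.5f/0.6f values are illustrative, and it assumes Criteria's standard optArgument pass-through is used to reach the translator factory.

    import ai.djl.ModelException;
    import ai.djl.inference.Predictor;
    import ai.djl.modality.cv.Image;
    import ai.djl.modality.cv.ImageFactory;
    import ai.djl.modality.cv.output.Joints;
    import ai.djl.modality.cv.translator.YoloPoseTranslatorFactory;
    import ai.djl.repository.zoo.Criteria;
    import ai.djl.repository.zoo.ZooModel;
    import ai.djl.translate.TranslateException;

    import java.io.IOException;
    import java.nio.file.Paths;

    public final class PoseThresholdSketch {

        public static void main(String[] args)
                throws IOException, ModelException, TranslateException {
            Image img =
                    ImageFactory.getInstance()
                            .fromFile(Paths.get("src/test/resources/pose_soccer.png"));

            Criteria<Image, Joints[]> criteria =
                    Criteria.builder()
                            .setTypes(Image.class, Joints[].class)
                            .optModelUrls("djl://ai.djl.pytorch/yolov8n-pose")
                            .optTranslatorFactory(new YoloPoseTranslatorFactory())
                            // Overrides the defaults read by configPostProcess().
                            .optArgument("threshold", 0.5f)
                            .optArgument("nmsThreshold", 0.6f)
                            .build();

            try (ZooModel<Image, Joints[]> model = criteria.loadModel();
                    Predictor<Image, Joints[]> predictor = model.newPredictor()) {
                Joints[] people = predictor.predict(img);
                System.out.println("Detected " + people.length + " person(s)");
            }
        }
    }

Raising the confidence threshold or lowering the NMS threshold trades recall for fewer duplicate skeletons, which may change how many Joints entries come back compared with the three that PoseEstimationTest expects under the defaults.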