From 429a26f2fec7830633e65e2d334a2e9e42128e06 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 17 Aug 2024 21:10:52 -0700 Subject: [PATCH] [api] Adds center fit image operation for Yolo (#3425) --- .../djl/modality/cv/transform/CenterFit.java | 59 +++++++++++++++++++ .../cv/translator/BaseImageTranslator.java | 11 ++++ .../translator/ObjectDetectionTranslator.java | 46 ++------------- .../SingleShotDetectionTranslator.java | 17 ++---- .../cv/translator/YoloPoseTranslator.java | 4 -- .../YoloSegmentationTranslator.java | 4 -- .../cv/translator/YoloTranslator.java | 16 ++--- .../cv/translator/YoloV5Translator.java | 40 +++++++++---- .../cv/translator/YoloV8Translator.java | 6 +- 9 files changed, 119 insertions(+), 84 deletions(-) create mode 100644 api/src/main/java/ai/djl/modality/cv/transform/CenterFit.java diff --git a/api/src/main/java/ai/djl/modality/cv/transform/CenterFit.java b/api/src/main/java/ai/djl/modality/cv/transform/CenterFit.java new file mode 100644 index 00000000000..17b910ccfda --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/transform/CenterFit.java @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.transform; + +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.Transform; + +/** A {@link Transform} that fits an image to a given size by center cropping and padding. 
*/ +public class CenterFit implements Transform { + + private int width; + private int height; + + /** + * Creates a {@code CenterFit} {@link Transform} that fits the image to the given width and + * height by center cropping and zero padding as needed. + * + * @param width the desired width + * @param height the desired height + */ + public CenterFit(int width, int height) { + this.width = width; + this.height = height; + } + + /** {@inheritDoc} */ + @Override + public NDArray transform(NDArray array) { + Shape shape = array.getShape(); + int w = (int) shape.get(1); + int h = (int) shape.get(0); + if (w > width || h > height) { + array = NDImageUtils.centerCrop(array, Math.min(w, width), Math.min(h, height)); + } + int padW = width - w; + int padH = height - h; + if (padW > 0 || padH > 0) { + padW = Math.max(0, padW); + padH = Math.max(0, padH); + int padW1 = padW / 2; + int padH1 = padH / 2; + Shape padding = new Shape(0, 0, padW1, padW - padW1, padH1, padH - padH1); + array = array.pad(padding, 0); + } + return array; + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslator.java index a0352ff8b37..3a66df73214 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/BaseImageTranslator.java @@ -15,6 +15,7 @@ import ai.djl.Model; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.transform.CenterCrop; +import ai.djl.modality.cv.transform.CenterFit; import ai.djl.modality.cv.transform.Normalize; import ai.djl.modality.cv.transform.Resize; import ai.djl.modality.cv.transform.ToTensor; @@ -50,6 +51,8 @@ public abstract class BaseImageTranslator implements Translator { private Image.Flag flag; private Batchifier batchifier; + protected int width; + protected int height; /** * Constructs an ImageTranslator with the provided builder. 
@@ -60,6 +63,8 @@ public BaseImageTranslator(BaseBuilder builder) { flag = builder.flag; pipeline = builder.pipeline; batchifier = builder.batchifier; + width = builder.width; + height = builder.height; } /** {@inheritDoc} */ @@ -72,6 +77,8 @@ public Batchifier getBatchifier() { @Override public NDList processInput(TranslatorContext ctx, Image input) { NDArray array = input.toNDArray(ctx.getNDManager(), flag); + ctx.setAttachment("width", input.getWidth()); + ctx.setAttachment("height", input.getHeight()); return pipeline.transform(new NDList(array)); } @@ -171,6 +178,10 @@ protected void configPreProcess(Map arguments) { if (ArgumentsUtil.booleanValue(arguments, "centerCrop", false)) { addTransform(new CenterCrop(width, height)); } + String centerFit = ArgumentsUtil.stringValue(arguments, "centerFit", "false"); + if ("true".equals(centerFit)) { + addTransform(new CenterFit(width, height)); + } if (ArgumentsUtil.booleanValue(arguments, "toTensor", true)) { addTransform(new ToTensor()); } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java index 64e2feadbad..f16d1b1c687 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ObjectDetectionTranslator.java @@ -29,9 +29,8 @@ public abstract class ObjectDetectionTranslator extends BaseImageTranslator classes; - protected double imageWidth; - protected double imageHeight; protected boolean applyRatio; + protected boolean removePadding; /** * Creates the {@link ObjectDetectionTranslator} from the given builder. 
@@ -42,9 +41,8 @@ protected ObjectDetectionTranslator(ObjectDetectionBuilder builder) { super(builder); this.threshold = builder.threshold; this.synsetLoader = builder.synsetLoader; - this.imageWidth = builder.imageWidth; - this.imageHeight = builder.imageHeight; this.applyRatio = builder.applyRatio; + this.removePadding = builder.removePadding; } /** {@inheritDoc} */ @@ -61,9 +59,8 @@ public abstract static class ObjectDetectionBuilder { protected float threshold = 0.2f; - protected double imageWidth; - protected double imageHeight; protected boolean applyRatio; + protected boolean removePadding; /** * Sets the threshold for prediction accuracy. @@ -78,19 +75,6 @@ public T optThreshold(float threshold) { return self(); } - /** - * Sets the optional rescale size. - * - * @param imageWidth the width to rescale images to - * @param imageHeight the height to rescale images to - * @return this builder - */ - public T optRescaleSize(double imageWidth, double imageHeight) { - this.imageWidth = imageWidth; - this.imageHeight = imageHeight; - return self(); - } - /** * Determine Whether to divide output object width/height on the inference result. Default * false. @@ -108,37 +92,17 @@ public T optApplyRatio(boolean value) { return self(); } - /** - * Get resized image width. - * - * @return image width - */ - public double getImageWidth() { - return imageWidth; - } - - /** - * Get resized image height. 
- * - * @return image height - */ - public double getImageHeight() { - return imageHeight; - } - /** {@inheritDoc} */ @Override protected void configPostProcess(Map arguments) { super.configPostProcess(arguments); - if (ArgumentsUtil.booleanValue(arguments, "rescale")) { - optRescaleSize(width, height); - } if (ArgumentsUtil.booleanValue(arguments, "optApplyRatio") || ArgumentsUtil.booleanValue(arguments, "applyRatio")) { optApplyRatio(true); - optRescaleSize(width, height); } threshold = ArgumentsUtil.floatValue(arguments, "threshold", 0.2f); + String centerFit = ArgumentsUtil.stringValue(arguments, "centerFit", "false"); + removePadding = "true".equals(centerFit); } } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java index 40b24f26eb8..a04780f3fbb 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/SingleShotDetectionTranslator.java @@ -59,19 +59,14 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { } String className = classes.get(classId); float[] box = boundingBoxes.get(i).toFloatArray(); - // rescale box coordinates by imageWidth and imageHeight - double x = imageWidth > 0 ? box[0] / imageWidth : box[0]; - double y = imageHeight > 0 ? box[1] / imageHeight : box[1]; - double w = imageWidth > 0 ? box[2] / imageWidth - x : box[2] - x; - double h = imageHeight > 0 ? box[3] / imageHeight - y : box[3] - y; + // rescale box coordinates by width and height + double x = width > 0 ? box[0] / width : box[0]; + double y = height > 0 ? box[1] / height : box[1]; + double w = width > 0 ? box[2] / width - x : box[2] - x; + double h = height > 0 ? 
box[3] / height - y : box[3] - y; Rectangle rect; if (applyRatio) { - rect = - new Rectangle( - x / imageWidth, - y / imageHeight, - w / imageWidth, - h / imageHeight); + rect = new Rectangle(x / width, y / height, w / width, h / height); } else { rect = new Rectangle(x, y, w, h); } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java index b59da702078..45b1da77b28 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloPoseTranslator.java @@ -29,8 +29,6 @@ public class YoloPoseTranslator extends BaseImageTranslator { private static final int MAX_DETECTION = 300; - private int width; - private int height; private float threshold; private float nmsThreshold; @@ -41,8 +39,6 @@ public class YoloPoseTranslator extends BaseImageTranslator { */ public YoloPoseTranslator(Builder builder) { super(builder); - this.width = builder.width; - this.height = builder.height; this.threshold = builder.threshold; this.nmsThreshold = builder.nmsThreshold; } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloSegmentationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloSegmentationTranslator.java index 1d0ad777d1a..ff0f4b4e04e 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloSegmentationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloSegmentationTranslator.java @@ -33,8 +33,6 @@ public class YoloSegmentationTranslator extends YoloV5Translator { private float threshold; private float nmsThreshold; - private int width; - private int height; /** * Creates the instance segmentation translator from the given builder. 
@@ -45,8 +43,6 @@ public YoloSegmentationTranslator(Builder builder) { super(builder); this.threshold = builder.threshold; this.nmsThreshold = builder.nmsThreshold; - this.width = builder.width; - this.height = builder.height; } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java index 6cb411158b3..6d54336c91e 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloTranslator.java @@ -44,10 +44,10 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { NDArray boundingBoxes = list.get(2); int detected = Math.toIntExact(probs.length); - NDArray xMin = boundingBoxes.get(":, 0").clip(0, imageWidth).div(imageWidth); - NDArray yMin = boundingBoxes.get(":, 1").clip(0, imageHeight).div(imageHeight); - NDArray xMax = boundingBoxes.get(":, 2").clip(0, imageWidth).div(imageWidth); - NDArray yMax = boundingBoxes.get(":, 3").clip(0, imageHeight).div(imageHeight); + NDArray xMin = boundingBoxes.get(":, 0").clip(0, width).div(width); + NDArray yMin = boundingBoxes.get(":, 1").clip(0, height).div(height); + NDArray xMax = boundingBoxes.get(":, 2").clip(0, width).div(width); + NDArray yMax = boundingBoxes.get(":, 3").clip(0, height).div(height); float[] boxX = xMin.toFloatArray(); float[] boxY = yMin.toFloatArray(); @@ -67,10 +67,10 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { if (applyRatio) { rect = new Rectangle( - boxX[i] / imageWidth, - boxY[i] / imageHeight, - boxWidth[i] / imageWidth, - boxHeight[i] / imageHeight); + boxX[i] / width, + boxY[i] / height, + boxWidth[i] / width, + boxHeight[i] / height); } else { rect = new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]); } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java 
b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java index 67682d7c8b7..2cb80418d9f 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java @@ -27,7 +27,7 @@ /** * A translator for YoloV5 models. This was tested with ONNX exported Yolo models. For details check - * here: https://github.com/ultralytics/yolov5 + * here */ public class YoloV5Translator extends ObjectDetectionTranslator { @@ -68,7 +68,11 @@ public static YoloV5Translator.Builder builder(Map arguments) { } protected DetectedObjects nms( - List boxes, List classIds, List scores) { + int imageWidth, + int imageHeight, + List boxes, + List classIds, + List scores) { List retClasses = new ArrayList<>(); List retProbs = new ArrayList<>(); List retBB = new ArrayList<>(); @@ -94,22 +98,30 @@ protected DetectedObjects nms( retClasses.add(classes.get(id)); retProbs.add(scores.get(pos).doubleValue()); Rectangle rect = boxes.get(pos); - if (applyRatio) { - retBB.add( + if (removePadding) { + int padW = (width - imageWidth) / 2; + int padH = (height - imageHeight) / 2; + rect = new Rectangle( - rect.getX() / imageWidth, - rect.getY() / imageHeight, + (rect.getX() - padW) / imageWidth, + (rect.getY() - padH) / imageHeight, rect.getWidth() / imageWidth, - rect.getHeight() / imageHeight)); - } else { - retBB.add(rect); + rect.getHeight() / imageHeight); + } else if (applyRatio) { + rect = + new Rectangle( + rect.getX() / width, + rect.getY() / height, + rect.getWidth() / width, + rect.getHeight() / height); } + retBB.add(rect); } } return new DetectedObjects(retClasses, retProbs, retBB); } - protected DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list) { float[] flattened = list.get(0).toFloatArray(); int sizeClasses = classes.size(); int stride = 5 + sizeClasses; @@ -142,7 +154,7 @@ protected 
DetectedObjects processFromBoxOutput(NDList list) { classIds.add(maxIndex); } } - return nms(boxes, classIds, scores); + return nms(imageWidth, imageHeight, boxes, classIds, scores); } private DetectedObjects processFromDetectOutput() { @@ -153,6 +165,8 @@ private DetectedObjects processFromDetectOutput() { /** {@inheritDoc} */ @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { + int imageWidth = (Integer) ctx.getAttachment("width"); + int imageHeight = (Integer) ctx.getAttachment("height"); switch (yoloOutputLayerType) { case DETECT: return processFromDetectOutput(); @@ -160,11 +174,11 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { if (list.get(0).getShape().dimension() > 2) { return processFromDetectOutput(); } else { - return processFromBoxOutput(list); + return processFromBoxOutput(imageWidth, imageHeight, list); } case BOX: default: - return processFromBoxOutput(list); + return processFromBoxOutput(imageWidth, imageHeight, list); } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java index 2101d9bbb47..7c4270ebb4e 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/YoloV8Translator.java @@ -25,7 +25,7 @@ /** * A translator for YoloV8 models. This was tested with ONNX exported Yolo models. 
For details check - * here: https://github.com/ultralytics/ultralytics + * here */ public class YoloV8Translator extends YoloV5Translator { @@ -66,7 +66,7 @@ public static Builder builder(Map arguments) { /** {@inheritDoc} */ @Override - protected DetectedObjects processFromBoxOutput(NDList list) { + protected DetectedObjects processFromBoxOutput(int imageWidth, int imageHeight, NDList list) { NDArray rawResult = list.get(0); NDArray reshapedResult = rawResult.transpose(); Shape shape = reshapedResult.getShape(); @@ -110,7 +110,7 @@ protected DetectedObjects processFromBoxOutput(NDList list) { } } - return nms(boxes, classIds, scores); + return nms(imageWidth, imageHeight, boxes, classIds, scores); } /** The builder for {@link YoloV8Translator}. */