From eeab0754108ee0075761866a352dc12a8ce6e736 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 5 Jul 2024 17:39:22 -0700 Subject: [PATCH] [api] Refactor drawMask() for instance segmentation --- .../djl/modality/cv/BufferedImageFactory.java | 38 +++++++++++------- .../java/ai/djl/modality/cv/output/Mask.java | 31 ++++++++++++++ .../InstanceSegmentationTranslator.java | 24 +++-------- .../main/java/ai/djl/opencv/OpenCVImage.java | 40 ++++++++++++------- 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java index 40735ddeca7..d18b6922b7a 100644 --- a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java @@ -403,28 +403,38 @@ private void drawMask(Mask mask) { float b = RandomUtils.nextFloat(); int imageWidth = image.getWidth(); int imageHeight = image.getHeight(); - int x = (int) (mask.getX() * imageWidth); - int y = (int) (mask.getY() * imageHeight); - float[][] probDist = mask.getProbDist(); - // Correct some coordinates of box when going out of image - if (x < 0) { - x = 0; - } - if (y < 0) { - y = 0; + int x = 0; + int y = 0; + int w = imageWidth; + int h = imageHeight; + if (!mask.isFullImageMask()) { + x = (int) (mask.getX() * imageWidth); + y = (int) (mask.getY() * imageHeight); + w = (int) (mask.getWidth() * imageWidth); + h = (int) (mask.getHeight() * imageHeight); + // Correct some coordinates of box when going out of image + if (x < 0) { + x = 0; + } + if (y < 0) { + y = 0; + } } + float[][] probDist = mask.getProbDist(); BufferedImage maskImage = new BufferedImage( - probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB); - for (int xCor = 0; xCor < probDist.length; xCor++) { - for (int yCor = 0; yCor < probDist[xCor].length; yCor++) { - float opacity = probDist[xCor][yCor] * 0.8f; + probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB); + for (int yCor = 0; yCor < probDist.length; yCor++) { + for (int xCor = 0; xCor < probDist[0].length; xCor++) { + float opacity = probDist[yCor][xCor] * 0.8f; maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB()); } } + java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH); + Graphics2D gR = (Graphics2D) image.getGraphics(); - gR.drawImage(maskImage, x, y, null); + gR.drawImage(scaled, x, y, null); gR.dispose(); } diff --git a/api/src/main/java/ai/djl/modality/cv/output/Mask.java b/api/src/main/java/ai/djl/modality/cv/output/Mask.java index 0c4a9fd6688..e9ada0710fc 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Mask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Mask.java @@ -21,6 +21,7 @@ public class Mask extends Rectangle { private static final long serialVersionUID = 1L; private float[][] probDist; + private boolean fullImageMask; /** * Constructs a Mask with the given data. @@ -32,8 +33,29 @@ public class Mask extends Rectangle { * @param dist the probability distribution for each pixel in the rectangle */ public Mask(double x, double y, double width, double height, float[][] dist) { + this(x, y, width, height, dist, false); + } + + /** + * Constructs a Mask with the given data. + * + * @param x the left coordinate of the bounding rectangle + * @param y the top coordinate of the bounding rectangle + * @param width the width of the bounding rectangle + * @param height the height of the bounding rectangle + * @param dist the probability distribution for each pixel in the rectangle + * @param fullImageMask if the mask if for full image + */ + public Mask( + double x, + double y, + double width, + double height, + float[][] dist, + boolean fullImageMask) { super(x, y, width, height); this.probDist = dist; + this.fullImageMask = fullImageMask; } /** @@ -44,4 +66,13 @@ public Mask(double x, double y, double width, double height, float[][] dist) { public float[][] getProbDist() { return probDist; } + + /** + * Returns if the mask is for full image. + * + * @return if the mask is for full image + */ + public boolean isFullImageMask() { + return fullImageMask; + } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java index efd05d7be26..458122bbc0b 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/InstanceSegmentationTranslator.java @@ -12,7 +12,6 @@ */ package ai.djl.modality.cv.translator; -import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.BoundingBox; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Mask; @@ -68,14 +67,6 @@ public void prepare(TranslatorContext ctx) throws IOException { } } - /** {@inheritDoc} */ - @Override - public NDList processInput(TranslatorContext ctx, Image image) { - ctx.setAttachment("originalHeight", image.getHeight()); - ctx.setAttachment("originalWidth", image.getWidth()); - return super.processInput(ctx, image); - } - /** {@inheritDoc} */ @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { @@ -102,18 +93,15 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { double w = box[2] / rescaledWidth - x; double h = box[3] / rescaledHeight - y; - int maskW = (int) (w * (int) ctx.getAttachment("originalWidth")); - int maskH = (int) (h * (int) ctx.getAttachment("originalHeight")); - // Reshape mask to actual image bounding box shape. NDArray array = masks.get(i); Shape maskShape = array.getShape(); - array = array.reshape(maskShape.addAll(new Shape(1))); - NDArray maskArray = NDImageUtils.resize(array, maskW, maskH).transpose(); - float[] flattened = maskArray.toFloatArray(); - float[][] maskFloat = new float[maskW][maskH]; - for (int j = 0; j < maskW; j++) { - System.arraycopy(flattened, j * maskH, maskFloat[j], 0, maskH); + int maskH = (int) maskShape.get(0); + int maskW = (int) maskShape.get(1); + float[] flattened = array.toFloatArray(); + float[][] maskFloat = new float[maskH][maskW]; + for (int j = 0; j < maskH; j++) { + System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW); } Mask mask = new Mask(x, y, w, h, maskFloat); diff --git a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java index b6a2292714e..1cbacd63464 100644 --- a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java +++ b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java @@ -356,27 +356,37 @@ private void drawMask(BufferedImage img, Mask mask) { float b = RandomUtils.nextFloat(); int imageWidth = img.getWidth(); int imageHeight = img.getHeight(); - int x = (int) (mask.getX() * imageWidth); - int y = (int) (mask.getY() * imageHeight); - float[][] probDist = mask.getProbDist(); - // Correct some coordinates of box when going out of image - if (x < 0) { - x = 0; - } - if (y < 0) { - y = 0; + int x = 0; + int y = 0; + int w = imageWidth; + int h = imageHeight; + if (!mask.isFullImageMask()) { + x = (int) (mask.getX() * imageWidth); + y = (int) (mask.getY() * imageHeight); + w = (int) (mask.getWidth() * imageWidth); + h = (int) (mask.getHeight() * imageHeight); + // Correct some coordinates of box when going out of image + if (x < 0) { + x = 0; + } + if (y < 0) { + y = 0; + } } + float[][] probDist = mask.getProbDist(); BufferedImage maskImage = - new BufferedImage(probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB); - for (int xCor = 0; xCor < probDist.length; xCor++) { - for (int yCor = 0; yCor < probDist[xCor].length; yCor++) { - float opacity = probDist[xCor][yCor] * 0.8f; - maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).getRGB()); + new BufferedImage(probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB); + for (int yCor = 0; yCor < probDist.length; yCor++) { + for (int xCor = 0; xCor < probDist[0].length; xCor++) { + float opacity = probDist[yCor][xCor] * 0.8f; + maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB()); } } + java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH); + Graphics2D gR = (Graphics2D) img.getGraphics(); - gR.drawImage(maskImage, x, y, null); + gR.drawImage(scaled, x, y, null); gR.dispose(); }