Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Refactor drawMask() for instance segmentation #3304

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -403,28 +403,38 @@ private void drawMask(Mask mask) {
float b = RandomUtils.nextFloat();
int imageWidth = image.getWidth();
int imageHeight = image.getHeight();
int x = (int) (mask.getX() * imageWidth);
int y = (int) (mask.getY() * imageHeight);
float[][] probDist = mask.getProbDist();
// Correct some coordinates of box when going out of image
if (x < 0) {
x = 0;
}
if (y < 0) {
y = 0;
int x = 0;
int y = 0;
int w = imageWidth;
int h = imageHeight;
if (!mask.isFullImageMask()) {
x = (int) (mask.getX() * imageWidth);
y = (int) (mask.getY() * imageHeight);
w = (int) (mask.getWidth() * imageWidth);
h = (int) (mask.getHeight() * imageHeight);
// Correct some coordinates of box when going out of image
if (x < 0) {
x = 0;
}
if (y < 0) {
y = 0;
}
}
float[][] probDist = mask.getProbDist();

BufferedImage maskImage =
new BufferedImage(
probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
for (int xCor = 0; xCor < probDist.length; xCor++) {
for (int yCor = 0; yCor < probDist[xCor].length; yCor++) {
float opacity = probDist[xCor][yCor] * 0.8f;
probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB);
for (int yCor = 0; yCor < probDist.length; yCor++) {
for (int xCor = 0; xCor < probDist[0].length; xCor++) {
float opacity = probDist[yCor][xCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB());
}
}
java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH);

Graphics2D gR = (Graphics2D) image.getGraphics();
gR.drawImage(maskImage, x, y, null);
gR.drawImage(scaled, x, y, null);
gR.dispose();
}

Expand Down
31 changes: 31 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Mask.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class Mask extends Rectangle {

private static final long serialVersionUID = 1L;
private float[][] probDist;
private boolean fullImageMask;

/**
* Constructs a Mask with the given data.
Expand All @@ -32,8 +33,29 @@ public class Mask extends Rectangle {
* @param dist the probability distribution for each pixel in the rectangle
*/
public Mask(double x, double y, double width, double height, float[][] dist) {
this(x, y, width, height, dist, false);
}

/**
* Constructs a Mask with the given data.
*
* @param x the left coordinate of the bounding rectangle
* @param y the top coordinate of the bounding rectangle
* @param width the width of the bounding rectangle
* @param height the height of the bounding rectangle
* @param dist the probability distribution for each pixel in the rectangle
* @param fullImageMask if the mask if for full image
*/
public Mask(
double x,
double y,
double width,
double height,
float[][] dist,
boolean fullImageMask) {
super(x, y, width, height);
this.probDist = dist;
this.fullImageMask = fullImageMask;
}

/**
Expand All @@ -44,4 +66,13 @@ public Mask(double x, double y, double width, double height, float[][] dist) {
public float[][] getProbDist() {
return probDist;
}

/**
* Returns if the mask is for full image.
*
* @return if the mask is for full image
*/
public boolean isFullImageMask() {
return fullImageMask;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
Expand Down Expand Up @@ -68,14 +67,6 @@ public void prepare(TranslatorContext ctx) throws IOException {
}
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image image) {
ctx.setAttachment("originalHeight", image.getHeight());
ctx.setAttachment("originalWidth", image.getWidth());
return super.processInput(ctx, image);
}

/** {@inheritDoc} */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
Expand All @@ -102,18 +93,15 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
double w = box[2] / rescaledWidth - x;
double h = box[3] / rescaledHeight - y;

int maskW = (int) (w * (int) ctx.getAttachment("originalWidth"));
int maskH = (int) (h * (int) ctx.getAttachment("originalHeight"));

// Reshape mask to actual image bounding box shape.
NDArray array = masks.get(i);
Shape maskShape = array.getShape();
array = array.reshape(maskShape.addAll(new Shape(1)));
NDArray maskArray = NDImageUtils.resize(array, maskW, maskH).transpose();
float[] flattened = maskArray.toFloatArray();
float[][] maskFloat = new float[maskW][maskH];
for (int j = 0; j < maskW; j++) {
System.arraycopy(flattened, j * maskH, maskFloat[j], 0, maskH);
int maskH = (int) maskShape.get(0);
int maskW = (int) maskShape.get(1);
float[] flattened = array.toFloatArray();
float[][] maskFloat = new float[maskH][maskW];
for (int j = 0; j < maskH; j++) {
System.arraycopy(flattened, j * maskW, maskFloat[j], 0, maskW);
}
Mask mask = new Mask(x, y, w, h, maskFloat);

Expand Down
40 changes: 25 additions & 15 deletions extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java
Original file line number Diff line number Diff line change
Expand Up @@ -356,27 +356,37 @@ private void drawMask(BufferedImage img, Mask mask) {
float b = RandomUtils.nextFloat();
int imageWidth = img.getWidth();
int imageHeight = img.getHeight();
int x = (int) (mask.getX() * imageWidth);
int y = (int) (mask.getY() * imageHeight);
float[][] probDist = mask.getProbDist();
// Correct some coordinates of box when going out of image
if (x < 0) {
x = 0;
}
if (y < 0) {
y = 0;
int x = 0;
int y = 0;
int w = imageWidth;
int h = imageHeight;
if (!mask.isFullImageMask()) {
x = (int) (mask.getX() * imageWidth);
y = (int) (mask.getY() * imageHeight);
w = (int) (mask.getWidth() * imageWidth);
h = (int) (mask.getHeight() * imageHeight);
// Correct some coordinates of box when going out of image
if (x < 0) {
x = 0;
}
if (y < 0) {
y = 0;
}
}
float[][] probDist = mask.getProbDist();

BufferedImage maskImage =
new BufferedImage(probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
for (int xCor = 0; xCor < probDist.length; xCor++) {
for (int yCor = 0; yCor < probDist[xCor].length; yCor++) {
float opacity = probDist[xCor][yCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).getRGB());
new BufferedImage(probDist[0].length, probDist.length, BufferedImage.TYPE_INT_ARGB);
for (int yCor = 0; yCor < probDist.length; yCor++) {
for (int xCor = 0; xCor < probDist[0].length; xCor++) {
float opacity = probDist[yCor][xCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB());
}
}
java.awt.Image scaled = maskImage.getScaledInstance(w, h, java.awt.Image.SCALE_SMOOTH);

Graphics2D gR = (Graphics2D) img.getGraphics();
gR.drawImage(maskImage, x, y, null);
gR.drawImage(scaled, x, y, null);
gR.dispose();
}

Expand Down
Loading