Skip to content

Commit

Permalink
[api] Refactor nms for yolo translator
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 6, 2024
1 parent bcf8fb3 commit 6192fd6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 131 deletions.
78 changes: 78 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/output/Rectangle.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

/**
* A {@code Rectangle} specifies an area in a coordinate space that is enclosed by the {@code
Expand Down Expand Up @@ -152,4 +153,81 @@ public String toString() {
return String.format(
"{\"x\"=%.3f, \"y\"=%.3f, \"width\"=%.3f, \"height\"=%.3f}", x, y, width, height);
}

/**
 * Applies nms (non-maximum suppression) to the list of rectangles.
 *
 * <p>Rectangles are visited from the highest score to the lowest. Each visited rectangle is
 * kept, and every remaining rectangle whose IoU with it is not strictly below {@code
 * nmsThreshold} is discarded. The input lists are not modified.
 *
 * @param boxes a list of {@code Rectangle}
 * @param scores a list of scores, where {@code scores.get(i)} is the score of {@code
 *     boxes.get(i)}
 * @param nmsThreshold the nms threshold
 * @return the indices (into {@code boxes}) of the rectangles that were kept, ordered from
 *     the highest score to the lowest
 */
public static List<Integer> nms(
        List<Rectangle> boxes, List<Double> scores, float nmsThreshold) {
    List<Integer> ret = new ArrayList<>();
    // PriorityQueue rejects an initial capacity below 1, so clamp for empty input.
    PriorityQueue<Integer> pq =
            new PriorityQueue<>(
                    Math.max(1, boxes.size()),
                    (lhs, rhs) -> {
                        // Intentionally reversed to put high confidence at the head of the
                        // queue.
                        return Double.compare(scores.get(rhs), scores.get(lhs));
                    });
    for (int i = 0; i < boxes.size(); ++i) {
        pq.add(i);
    }

    // do non maximum suppression
    while (!pq.isEmpty()) {
        // poll() is the only documented way to obtain the head; the iteration order of a
        // PriorityQueue is unspecified and must not be used to find the maximum.
        int best = pq.poll();
        ret.add(best);
        Rectangle box = boxes.get(best);
        // Written as !(iou < threshold) so that a NaN IoU (0/0 for two zero-area
        // rectangles) is discarded as well, instead of being kept.
        pq.removeIf(candidate -> !(box.boxIou(boxes.get(candidate)) < nmsThreshold));
    }
    return ret;
}

/**
 * Computes the intersection-over-union of this rectangle and {@code other}.
 *
 * <p>The result is {@code NaN} when the computed union is zero (for example two zero-area
 * rectangles), because the division is not guarded.
 *
 * @param other the rectangle to compare against
 * @return the intersection area divided by the union area
 */
private double boxIou(Rectangle other) {
    double shared = intersection(other);
    double ownArea = getWidth() * getHeight();
    double otherArea = other.getWidth() * other.getHeight();
    // union = area(this) + area(other) - shared part counted twice
    return shared / (ownArea + otherArea - shared);
}

/**
 * Computes the area shared by this rectangle and {@code b}.
 *
 * @param b the other rectangle
 * @return the overlapping area, or 0 when the rectangles do not overlap on either axis
 */
private double intersection(Rectangle b) {
    // overlap() works on centre/extent pairs, so convert the corner to a centre first.
    double centerX = (getX() * 2 + getWidth()) / 2;
    double otherCenterX = (b.getX() * 2 + b.getWidth()) / 2;
    double w = overlap(centerX, getWidth(), otherCenterX, b.getWidth());

    double centerY = (getY() * 2 + getHeight()) / 2;
    double otherCenterY = (b.getY() * 2 + b.getHeight()) / 2;
    double h = overlap(centerY, getHeight(), otherCenterY, b.getHeight());

    // A negative extent on either axis means the rectangles are disjoint.
    return (w < 0 || h < 0) ? 0 : w * h;
}

/**
 * Computes the signed length of the overlap of two one-dimensional segments, each given by
 * its centre and its length.
 *
 * @param x1 the centre of the first segment
 * @param w1 the length of the first segment
 * @param x2 the centre of the second segment
 * @param w2 the length of the second segment
 * @return the overlap length; negative when the segments are disjoint
 */
private double overlap(double x1, double w1, double x2, double w2) {
    double start = Math.max(x1 - w1 / 2, x2 - w2 / 2);
    double end = Math.min(x1 + w1 / 2, x2 + w2 / 2);
    return end - start;
}
}
161 changes: 35 additions & 126 deletions api/src/main/java/ai/djl/modality/cv/translator/YoloV5Translator.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;

/**
* A translator for YoloV5 models. This was tested with ONNX exported Yolo models. For details check
Expand Down Expand Up @@ -68,104 +67,58 @@ public static YoloV5Translator.Builder builder(Map<String, ?> arguments) {
return builder;
}

/**
 * Computes the area shared by two rectangles.
 *
 * @param a the first rectangle
 * @param b the second rectangle
 * @return the overlapping area, or 0 when the rectangles do not overlap on either axis
 */
protected double boxIntersection(Rectangle a, Rectangle b) {
    // overlap() works on centre/extent pairs, so convert each corner to a centre first.
    double centerAx = (a.getX() * 2 + a.getWidth()) / 2;
    double centerBx = (b.getX() * 2 + b.getWidth()) / 2;
    double w = overlap(centerAx, a.getWidth(), centerBx, b.getWidth());

    double centerAy = (a.getY() * 2 + a.getHeight()) / 2;
    double centerBy = (b.getY() * 2 + b.getHeight()) / 2;
    double h = overlap(centerAy, a.getHeight(), centerBy, b.getHeight());

    // A negative extent on either axis means the rectangles are disjoint.
    return (w < 0 || h < 0) ? 0 : w * h;
}

/**
 * Computes the intersection-over-union of two rectangles.
 *
 * @param a the first rectangle
 * @param b the second rectangle
 * @return the intersection area divided by the union area
 */
protected double boxIou(Rectangle a, Rectangle b) {
    double intersection = boxIntersection(a, b);
    double union = boxUnion(a, b);
    return intersection / union;
}

/**
 * Computes the area covered by at least one of the two rectangles.
 *
 * @param a the first rectangle
 * @param b the second rectangle
 * @return the sum of both areas minus their intersection
 */
protected double boxUnion(Rectangle a, Rectangle b) {
    double shared = boxIntersection(a, b);
    double areaA = a.getWidth() * a.getHeight();
    double areaB = b.getWidth() * b.getHeight();
    // The shared part is counted in both areas, so remove it once.
    return areaA + areaB - shared;
}

protected DetectedObjects nms(List<IntermediateResult> list) {
protected DetectedObjects nms(
List<Rectangle> boxes, List<Integer> classIds, List<Float> scores) {
List<String> retClasses = new ArrayList<>();
List<Double> retProbs = new ArrayList<>();
List<BoundingBox> retBB = new ArrayList<>();

for (int k = 0; k < classes.size(); k++) {
// 1.find max confidence per class
PriorityQueue<IntermediateResult> pq =
new PriorityQueue<>(
50,
(lhs, rhs) -> {
// Intentionally reversed to put high confidence at the head of the
// queue.
return Double.compare(rhs.getConfidence(), lhs.getConfidence());
});

for (IntermediateResult intermediateResult : list) {
if (intermediateResult.getDetectedClass() == k) {
pq.add(intermediateResult);
for (int classId = 0; classId < classes.size(); classId++) {
List<Rectangle> r = new ArrayList<>();
List<Double> s = new ArrayList<>();
List<Integer> map = new ArrayList<>();
for (int j = 0; j < classIds.size(); ++j) {
if (classIds.get(j) == classId) {
r.add(boxes.get(j));
s.add(scores.get(j).doubleValue());
map.add(j);
}
}

// 2.do non maximum suppression
while (pq.size() > 0) {
// insert detection with max confidence
IntermediateResult[] a = new IntermediateResult[pq.size()];
IntermediateResult[] detections = pq.toArray(a);
Rectangle rec = detections[0].getLocation();
retClasses.add(detections[0].id);
retProbs.add(detections[0].confidence);
if (r.isEmpty()) {
continue;
}
List<Integer> nms = Rectangle.nms(r, s, nmsThreshold);
for (int index : nms) {
int pos = map.get(index);
int id = classIds.get(pos);
retClasses.add(classes.get(id));
retProbs.add(scores.get(pos).doubleValue());
Rectangle rect = boxes.get(pos);
if (applyRatio) {
retBB.add(
new Rectangle(
rec.getX() / imageWidth,
rec.getY() / imageHeight,
rec.getWidth() / imageWidth,
rec.getHeight() / imageHeight));
rect.getX() / imageWidth,
rect.getY() / imageHeight,
rect.getWidth() / imageWidth,
rect.getHeight() / imageHeight));
} else {
retBB.add(
new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight()));
}
pq.clear();
for (int j = 1; j < detections.length; j++) {
IntermediateResult detection = detections[j];
Rectangle location = detection.getLocation();
if (boxIou(rec, location) < nmsThreshold) {
pq.add(detection);
}
retBB.add(rect);
}
}
}
return new DetectedObjects(retClasses, retProbs, retBB);
}

/**
 * Computes the signed length of the overlap of two one-dimensional segments, each given by
 * its centre and its length.
 *
 * @param x1 the centre of the first segment
 * @param w1 the length of the first segment
 * @param x2 the centre of the second segment
 * @param w2 the length of the second segment
 * @return the overlap length; negative when the segments are disjoint
 */
protected double overlap(double x1, double w1, double x2, double w2) {
    double start = Math.max(x1 - w1 / 2, x2 - w2 / 2);
    double end = Math.min(x1 + w1 / 2, x2 + w2 / 2);
    return end - start;
}

protected DetectedObjects processFromBoxOutput(NDList list) {
float[] flattened = list.get(0).toFloatArray();
ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();
int sizeClasses = classes.size();
int stride = 5 + sizeClasses;
int size = flattened.length / stride;

ArrayList<Rectangle> boxes = new ArrayList<>();
ArrayList<Float> scores = new ArrayList<>();
ArrayList<Integer> classIds = new ArrayList<>();

for (int i = 0; i < size; i++) {
int indexBase = i * stride;
float maxClass = 0;
Expand All @@ -184,11 +137,12 @@ protected DetectedObjects processFromBoxOutput(NDList list) {
float h = flattened[indexBase + 3];
Rectangle rect =
new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h);
intermediateResults.add(
new IntermediateResult(classes.get(maxIndex), score, maxIndex, rect));
boxes.add(rect);
scores.add(score);
classIds.add(maxIndex);
}
}
return nms(intermediateResults);
return nms(boxes, classIds, scores);
}

private DetectedObjects processFromDetectOutput() {
Expand Down Expand Up @@ -279,49 +233,4 @@ public YoloV5Translator build() {
return new YoloV5Translator(this);
}
}

/** An immutable holder for a single candidate detection before nms is applied. */
protected static final class IntermediateResult {

    /**
     * A sortable score for how good the recognition is relative to others. Higher should be
     * better.
     */
    private final double confidence;

    /** The integer value identifying the detected class, as passed to the constructor. */
    private final int detectedClass;

    /**
     * An identifier for what has been recognized. Specific to the class, not the instance of
     * the object.
     */
    private final String id;

    /** The location within the source image of the recognized object. */
    private final Rectangle location;

    /**
     * Creates a new candidate detection.
     *
     * @param id the identifier of what has been recognized
     * @param confidence the score of the detection
     * @param detectedClass the integer value identifying the detected class
     * @param location the location of the detection
     */
    IntermediateResult(String id, double confidence, int detectedClass, Rectangle location) {
        this.confidence = confidence;
        this.id = id;
        this.detectedClass = detectedClass;
        this.location = location;
    }

    /**
     * Returns the score of this detection.
     *
     * @return the score of this detection
     */
    public double getConfidence() {
        return confidence;
    }

    /**
     * Returns the integer value identifying the detected class.
     *
     * @return the integer value identifying the detected class
     */
    public int getDetectedClass() {
        return detectedClass;
    }

    /**
     * Returns the identifier of what has been recognized.
     *
     * @return the identifier of what has been recognized
     */
    public String getId() {
        return id;
    }

    /**
     * Returns a copy of the location of this detection.
     *
     * @return a new {@code Rectangle} with the same x, y, width and height
     */
    public Rectangle getLocation() {
        // A fresh Rectangle is returned on every call rather than the stored instance.
        return new Rectangle(
                location.getX(), location.getY(), location.getWidth(), location.getHeight());
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ protected DetectedObjects processFromBoxOutput(NDList list) {
"Expected classes: " + (nClasses - 4) + ", got " + classes.size());
}

ArrayList<IntermediateResult> intermediateResults = new ArrayList<>();
ArrayList<Rectangle> boxes = new ArrayList<>();
ArrayList<Float> scores = new ArrayList<>();
ArrayList<Integer> classIds = new ArrayList<>();

// reverse order search in heap; searches through #maxBoxes for optimization when set
for (int i = numberRows - 1; i > numberRows - maxBoxes; --i) {
int index = i * nClasses;
Expand All @@ -101,12 +104,13 @@ protected DetectedObjects processFromBoxOutput(NDList list) {
float h = buf[index + 3];
Rectangle rect =
new Rectangle(Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), w, h);
intermediateResults.add(
new IntermediateResult(
classes.get(maxIndex), maxClassProb, maxIndex, rect));
boxes.add(rect);
scores.add(maxClassProb);
classIds.add(maxIndex);
}
}
return nms(intermediateResults);

return nms(boxes, classIds, scores);
}

/** The builder for {@link YoloV8Translator}. */
Expand Down

0 comments on commit 6192fd6

Please sign in to comment.