-
Notifications
You must be signed in to change notification settings - Fork 688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Object Detection #1930
Object Detection #1930
Conversation
model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/yolo/YOLOV3.java
Outdated
Show resolved
Hide resolved
model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/object_detection/yolo/YOLOV3.java
Show resolved
Hide resolved
3b530c1
to
7474f50
Compare
@@ -48,6 +48,8 @@ public SingleShotDetectionLoss() { | |||
@Override | |||
protected Pair<NDList, NDList> inputForComponent( | |||
int componentIndex, NDList labels, NDList predictions) { | |||
System.out.println(labels.singletonOrThrow()); //print labels for test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove testing code
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
|
||
public class YOLOv3Loss extends Loss{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
./gradlew formatJava
return inter.div(union); | ||
} | ||
|
||
public ArrayList<PairList<Long,Rectangle>> getTargetFromCurrentLabel(NDArray labels){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use List instead of ArrayList
in function declaration
} | ||
|
||
//calculate IOU is already defined in Rectangle | ||
public NDArray calculateIOU(NDArray boxA,NDArray boxB){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
javadoc
@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans | |||
try (Trainer trainer = model.newTrainer(config)) { | |||
trainer.setMetrics(new Metrics()); | |||
|
|||
Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256); | |||
Shape inputShape = new Shape(1, 3, 256, 256); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably shouldn't modify the example to use a fixed batch size of 1 rather than the variable version
emm,this is just a test pr for my mentor Zack to get a view on my current progress and make suggestions
At 2022-08-29 13:50:35, "Frank Liu" ***@***.***> wrote:
@frankfliu commented on this pull request.
In api/src/main/java/ai/djl/training/loss/SingleShotDetectionLoss.java:
@@ -48,6 +48,8 @@ public SingleShotDetectionLoss() {
@OverRide
protected Pair<NDList, NDList> inputForComponent(
int componentIndex, NDList labels, NDList predictions) {
+ System.out.println(labels.singletonOrThrow()); //print labels for test
Remove testing code
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
@@ -0,0 +1,275 @@
+package ai.djl.training.loss;
Add license header
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
+import ai.djl.modality.cv.output.Rectangle;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrays;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Activation;
+import ai.djl.util.Pair;
+import ai.djl.util.PairList;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+public class YOLOv3Loss extends Loss{
./gradlew formatJava
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
+ boxBTrue.get(":,2:").expandDims(0).broadcast(A,B,2));
+ NDArray minXY = NDArrays.maximum(boxATrue.get(":,:2").expandDims(1).broadcast(A,B,2),
+ boxBTrue.get(":,:2").expandDims(0).broadcast(A,B,2));
+
+ NDArray inter = NDArrays.minimum(maxXY.sub(minXY),0);
+ inter = inter.get(":,:,0").mul(inter.get(":,:,1"));
+
+ //to calculate the area of prediction bbox and true bbox
+ NDArray areaA = boxATrue.get(":,2").sub(boxATrue.get(":,0")).mul(boxATrue.get(":,3").sub(boxATrue.get(":,1"))).expandDims( 1).broadcast(inter.getShape()),
+ areaB = boxBTrue.get(":,2").sub(boxBTrue.get(":,0")).mul(boxBTrue.get(":,3").sub(boxBTrue.get(":,1"))).expandDims( 0).broadcast(inter.getShape());
+
+ NDArray union = areaA.add(areaB).sub(inter);
+ return inter.div(union);
+ }
+
+ public ArrayList<PairList<Long,Rectangle>> getTargetFromCurrentLabel(NDArray labels){
Use List instead of ArrayList in function declaration
In api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java:
+ float interHeight = Math.min(trueBottom,predBottom)-Math.max(trueTop,predTop);
+ float inter = interWidth*interHeight, union = wgt.get(i).getFloat()*inW*hgt.get(i).getFloat()*inH
+ + anchors.get(j,0).getFloat()*anchors.get(j,1).getFloat()-inter;
+ iou.set(new NDIndex(curIndex),inter/union);
+ }
+ }
+ return new NDList(iou,boxLossScale,groundTruth);
+ }
+
+
+
+ return null;
+ }
+
+ //calculate IOU is already defined in Rectangle
+ public NDArray calculateIOU(NDArray boxA,NDArray boxB){
javadoc
In examples/src/main/java/ai/djl/examples/training/TrainPikachu.java:
@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
- Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
+ Shape inputShape = new Shape(1, 3, 256, 256);
Why?
—
Reply to this email directly, view it on GitHub, or unsubscribe.
You are receiving this because you authored the thread.Message ID: ***@***.***>
|
afa3241
to
efcaa88
Compare
/** {@inheritDoc} */ | ||
@Override | ||
public List<String> getClasses() { | ||
return synset; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is better to fail clearly rather than have something partial that may confuse people that it works. Here, just throw an exception, specifically at UnsupportedOperationException to indicate that it is not yet implemented.
@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans | |||
try (Trainer trainer = model.newTrainer(config)) { | |||
trainer.setMetrics(new Metrics()); | |||
|
|||
Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256); | |||
Shape inputShape = new Shape(1, 3, 256, 256); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably shouldn't modify the example to use a fixed batch size of 1 rather than the variable version
|
||
public class YOLOv3Test { | ||
@Test | ||
public void testDarkNet() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should add a test for the full yolov3 model as well
1. add Model of Yolov3(completely right I think, at least for Pytorch engine) 2. add Loss function of Yolov3 (still have some small bugs, it can be successfully trained for 4-5 minutes)... 3. add ObjectDetection in djl-zero, and currently just training SingleShotDetection Model 4. For ObjectDetectionDataSet, add a new method getClasses(), because for most models and Translators, we still need the classes of dataset
efcaa88
to
79900ea
Compare
This PR is mainly about yolov3 model in ObjectDetection.
TODO List: