Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Object Detection #1930

Merged
merged 2 commits into from
Oct 7, 2022
Merged

Conversation

warthecatalyst
Copy link
Contributor

@warthecatalyst warthecatalyst commented Aug 19, 2022

This PR is mainly about yolov3 model in ObjectDetection.

  1. add Yolov3 Model in basic model zoo and can be trained with a simple loss function for Pytorch engine. And corresponding test.
  2. add Yolov3 Loss function, maybe still have some bugs to fix
  3. add an example of yolov3Training: TrainPikachuWithYOLOV3
  4. About ObjectDetectionDataSet, add a new function getClasses() because for most ObjectDetection Models and Translators, we need to know how many types of objects in the dataset.
  5. Add ObjectDetection for DJL-Zero, and due to small bugs for yolov3, currently just train a SingleShotDetection Model.

TODO List:

  1. Translator of Yolov3(actually I'm written a demo but it can't pass the code check so I stash it.
  2. For COCODataSet, we should update synset.

@zachgk zachgk marked this pull request as draft August 19, 2022 16:56
@warthecatalyst warthecatalyst force-pushed the ObjectDetection branch 2 times, most recently from 3b530c1 to 7474f50 Compare August 26, 2022 01:12
@@ -48,6 +48,8 @@ public SingleShotDetectionLoss() {
@Override
protected Pair<NDList, NDList> inputForComponent(
int componentIndex, NDList labels, NDList predictions) {
System.out.println(labels.singletonOrThrow()); //print labels for test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove testing code

import java.util.ArrayList;
import java.util.Arrays;

public class YOLOv3Loss extends Loss{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

./gradlew formatJava

return inter.div(union);
}

public ArrayList<PairList<Long,Rectangle>> getTargetFromCurrentLabel(NDArray labels){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use List instead of ArrayList in function declaration

}

//calculate IOU is already defined in Rectangle
public NDArray calculateIOU(NDArray boxA,NDArray boxB){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

javadoc

@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());

Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
Shape inputShape = new Shape(1, 3, 256, 256);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably shouldn't modify the example to use a fixed batch size of 1 rather than the variable version

@warthecatalyst
Copy link
Contributor Author

warthecatalyst commented Aug 29, 2022 via email

@warthecatalyst warthecatalyst force-pushed the ObjectDetection branch 2 times, most recently from afa3241 to efcaa88 Compare October 5, 2022 15:31
@warthecatalyst warthecatalyst marked this pull request as ready for review October 6, 2022 06:18
@warthecatalyst warthecatalyst changed the title Object detection Object Detection Oct 6, 2022
/** {@inheritDoc} */
@Override
public List<String> getClasses() {
return synset;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to fail clearly rather than have something partial that may confuse people that it works. Here, just throw an exception, specifically at UnsupportedOperationException to indicate that it is not yet implemented.

@@ -87,7 +87,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());

Shape inputShape = new Shape(arguments.getBatchSize(), 3, 256, 256);
Shape inputShape = new Shape(1, 3, 256, 256);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably shouldn't modify the example to use a fixed batch size of 1 rather than the variable version


public class YOLOv3Test {
@Test
public void testDarkNet() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add a test for the full yolov3 model as well

1. add Model of Yolov3(completely right I think, at least for Pytorch engine)
2. add Loss function of Yolov3 (still have some small bugs, it can be successfully trained for 4-5 minutes)...
3. add ObjectDetection in djl-zero, and currently just training SingleShotDetection Model
4. For ObjectDetectionDataSet, add a new method getClasses(), because for most models and Translators, we still need the classes of dataset
@zachgk zachgk merged commit b25ec9c into deepjavalibrary:master Oct 7, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants