Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Object Detection #1930

Merged
merged 2 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
487 changes: 487 additions & 0 deletions api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public PairList<Long, Rectangle> getObjects(long index) {
return new PairList<>(Collections.singletonList(labels.get((int) index)));
}

/**
 * {@inheritDoc}
 *
 * <p>The returned list holds exactly one class name, {@code "banana"}.
 */
@Override
public List<String> getClasses() {
    // Single-entry immutable list: the only class name reported here.
    List<String> classNames = Collections.singletonList("banana");
    return classNames;
}

/** {@inheritDoc} */
@Override
protected long availableSize() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@
* <p>Each image might have different {@link ai.djl.ndarray.types.Shape}s.
*/
public class CocoDetection extends ObjectDetectionDataset {

// TODO: Add synset logic for coco dataset
private static final String ARTIFACT_ID = "coco";
private static final String VERSION = "1.0";

private Usage usage;
private List<Path> imagePaths;
private List<PairList<Long, Rectangle>> labels;

private MRL mrl;
private boolean prepared;

Expand All @@ -79,6 +78,13 @@ public PairList<Long, Rectangle> getObjects(long index) {
return labels.get(Math.toIntExact(index));
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always, because no class list is provided here yet
 */
@Override
public List<String> getClasses() {
    // Placeholder: fail loudly instead of returning a misleading (e.g. empty) class list.
    String message = "getClasses() for CocoDetection has not been implemented yet.";
    throw new UnsupportedOperationException(message);
}

/** {@inheritDoc} */
@Override
public void prepare(Progress progress) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import ai.djl.util.PairList;

import java.io.IOException;
import java.util.List;

/**
* A helper to create {@link ai.djl.training.dataset.Dataset}s for {@link
* ai.djl.Application.CV#OBJECT_DETECTION}.
*/
public abstract class ObjectDetectionDataset extends ImageDataset {

/**
* Creates a new instance of {@link ObjectDetectionDataset} with the given necessary
* configurations.
Expand Down Expand Up @@ -68,4 +68,11 @@ public Record get(NDManager manager, long index) throws IOException {
* @throws IOException if the data could not be loaded
*/
public abstract PairList<Long, Rectangle> getObjects(long index) throws IOException;

/**
 * Returns the classes that detected objects in the dataset can be classified into.
 *
 * <p>NOTE(review): presumably the {@code Long} class id in each pair returned by {@link
 * #getObjects(long)} is an index into this list; confirm against the implementations.
 *
 * @return the names of the classes that detected objects in the dataset can be classified into
 */
public abstract List<String> getClasses();
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ public PairList<Long, Rectangle> getObjects(long index) {
return new PairList<>(Collections.singletonList(labels.get((int) index)));
}

/**
 * {@inheritDoc}
 *
 * <p>The returned list holds exactly one class name, {@code "pikachu"}.
 */
@Override
public List<String> getClasses() {
    // Single-entry immutable list: the only class name reported here.
    List<String> classNames = Collections.singletonList("pikachu");
    return classNames;
}

/** {@inheritDoc} */
@Override
protected long availableSize() {
Expand Down
137 changes: 137 additions & 0 deletions djl-zero/src/main/java/ai/djl/zero/cv/ObjectDetection.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.zero.cv;

import ai.djl.Model;
import ai.djl.basicdataset.cv.ObjectDetectionDataset;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.BoundingBoxError;
import ai.djl.training.evaluator.SingleShotDetectionAccuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** ObjectDetection takes an image and extracts one or more main subjects in the image. */
public final class ObjectDetection {

    /** Number of epochs the model is fitted for. */
    private static final int NUM_EPOCHS = 50;

    /** Relative weight of the training part when the dataset is randomly split. */
    private static final int TRAIN_SPLIT_RATIO = 8;

    /** Relative weight of the validation part when the dataset is randomly split. */
    private static final int VALIDATE_SPLIT_RATIO = 2;

    /** Threshold handed to the translator that is bundled with the trained model. */
    private static final float DETECTION_THRESHOLD = 0.6f;

    private ObjectDetection() {}

    /**
     * Trains the recommended object detection model on a custom dataset. Currently, trains a
     * SingleShotDetection Model.
     *
     * <p>In order to train on a custom dataset, you must create a custom {@link
     * ObjectDetectionDataset} to load your data.
     *
     * <p>The dataset is randomly split 8:2 into a training and a validation part. If anything
     * fails after the underlying {@link Model} has been created, the model is closed before the
     * exception is propagated, so a failed training run does not leak it. On success, ownership
     * of the model passes to the returned {@link ZooModel}, which the caller should close.
     *
     * @param dataset the data to train with; it must report a fixed image width and height
     * @param performance to determine the desired model tradeoffs (currently not consulted by
     *     this implementation)
     * @return the model as a {@link ZooModel} with the {@link Translator} included
     * @throws IOException if the dataset could not be loaded
     * @throws TranslateException if the translator has errors
     * @throws IllegalArgumentException if the dataset has no fixed image width or height
     */
    public static ZooModel<Image, DetectedObjects> train(
            ObjectDetectionDataset dataset, Performance performance)
            throws IOException, TranslateException {
        List<String> classes = dataset.getClasses();
        int channels = dataset.getImageChannels();
        int width =
                dataset.getImageWidth()
                        .orElseThrow(
                                () ->
                                        new IllegalArgumentException(
                                                "The dataset must have a fixed image width"));
        int height =
                dataset.getImageHeight()
                        .orElseThrow(
                                () ->
                                        new IllegalArgumentException(
                                                "The dataset must have a fixed image height"));

        Shape imageShape = new Shape(channels, height, width);

        Dataset[] splitDataset = dataset.randomSplit(TRAIN_SPLIT_RATIO, VALIDATE_SPLIT_RATIO);
        Dataset trainDataset = splitDataset[0];
        Dataset validateDataset = splitDataset[1];

        // Build the block before creating the model so a failure here has nothing to clean up.
        Block block = getSsdTrainBlock(classes.size());
        Model model = Model.newInstance("ObjectDetection");
        boolean trained = false;
        try {
            model.setBlock(block);

            TrainingConfig trainingConfig =
                    new DefaultTrainingConfig(new SingleShotDetectionLoss())
                            .addEvaluator(new SingleShotDetectionAccuracy("classAccuracy"))
                            .addEvaluator(new BoundingBoxError("boundingBoxError"))
                            .addTrainingListeners(TrainingListener.Defaults.basic());

            try (Trainer trainer = model.newTrainer(trainingConfig)) {
                // Initialize with a leading batch dimension of 1 in front of the image shape.
                trainer.initialize(new Shape(1).addAll(imageShape));
                EasyTrain.fit(trainer, NUM_EPOCHS, trainDataset, validateDataset);
            }

            Translator<Image, DetectedObjects> translator =
                    SingleShotDetectionTranslator.builder()
                            .addTransform(new ToTensor())
                            .optSynset(classes)
                            .optThreshold(DETECTION_THRESHOLD)
                            .build();

            ZooModel<Image, DetectedObjects> zooModel = new ZooModel<>(model, translator);
            trained = true;
            return zooModel;
        } finally {
            // Only release the model on failure; on success the returned ZooModel wraps it.
            if (!trained) {
                model.close();
            }
        }
    }

    /**
     * Builds the SingleShotDetection block that {@link #train} fits.
     *
     * <p>The base network is a sequence of three down-sampling blocks with 16, 32 and 64
     * filters. One size pair and one ratio triple is configured per entry of the size table.
     *
     * @param numClasses the number of classes passed to the SingleShotDetection builder
     * @return the assembled block
     */
    private static Block getSsdTrainBlock(int numClasses) {
        int[] numFilters = {16, 32, 64};
        SequentialBlock baseBlock = new SequentialBlock();
        for (int numFilter : numFilters) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(numFilter));
        }

        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        // Keep the ratio list the same length as the size list (five entries each).
        List<List<Float>> ratios = new ArrayList<>();
        for (int i = 0; i < sizes.size(); i++) {
            ratios.add(Arrays.asList(1f, 2f, 0.5f));
        }

        return SingleShotDetection.builder()
                .setNumClasses(numClasses)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arg
.optUsage(usage)
.optLimit(arguments.getLimit())
.optPipeline(pipeline)
.setSampling(arguments.getBatchSize(), true)
.setSampling(1, true)
.build();
pikachuDetection.prepare(new ProgressBar());

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.training;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.PikachuDetection;
import ai.djl.basicmodelzoo.cv.object_detection.yolo.YOLOV3;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.YOLOv3Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;

import java.io.IOException;

/** Example that trains a YOLOV3 block with a single class on the {@link PikachuDetection} dataset. */
public final class TrainPikachuWithYOLOV3 {

    /** Width and height, in pixels, used for the trainer input shape and the loss input shape. */
    private static final int IMAGE_SIZE = 256;

    /**
     * Divisor applied when rescaling the preset anchors. NOTE(review): presumably the input size
     * the preset anchors are expressed in; confirm against YOLOv3Loss.
     */
    private static final int ANCHOR_REFERENCE_SIZE = 416;

    private TrainPikachuWithYOLOV3() {}

    /**
     * Command line entry point; delegates to {@link #runExample(String[])}.
     *
     * @param args the command line arguments
     * @throws IOException if thrown by {@link #runExample(String[])}
     * @throws TranslateException if thrown by {@link #runExample(String[])}
     * @throws MalformedModelException declared for callers; not raised by the code in this class
     */
    public static void main(String[] args)
            throws IOException, TranslateException, MalformedModelException {
        runExample(args);
    }

    /**
     * Parses the arguments and, if parsing produced a result, runs the training.
     *
     * @param args the command line arguments
     * @return the training result, or {@code null} if argument parsing returned {@code null}
     * @throws IOException if thrown while preparing the datasets or fitting
     * @throws TranslateException if thrown while fitting
     */
    public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
        Arguments parsed = new Arguments().parseArgs(args);
        return parsed == null ? null : train(parsed);
    }

    /** Creates the model, datasets and trainer, fits the model and returns the result. */
    private static TrainingResult train(Arguments arguments)
            throws IOException, TranslateException {
        try (Model yolo = Model.newInstance("pikachu-yolov3")) {
            yolo.setBlock(YOLOV3.builder().setNumClasses(1).build());

            RandomAccessDataset trainData = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset testData = getDataset(Dataset.Usage.TEST, arguments);
            DefaultTrainingConfig trainingConfig = setupTrainingConfig(arguments);

            try (Trainer trainer = yolo.newTrainer(trainingConfig)) {
                trainer.setMetrics(new Metrics());
                // One image per batch, three channels, IMAGE_SIZE x IMAGE_SIZE.
                trainer.initialize(new Shape(1, 3, IMAGE_SIZE, IMAGE_SIZE));
                EasyTrain.fit(trainer, arguments.getEpoch(), trainData, testData);
                return trainer.getTrainingResult();
            }
        }
    }

    /** Builds and prepares the Pikachu dataset for the given usage, sampled in batches of 8. */
    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Pipeline toTensor = new Pipeline(new ToTensor());
        PikachuDetection dataset =
                PikachuDetection.builder()
                        .optUsage(usage)
                        .optLimit(arguments.getLimit())
                        .optPipeline(toTensor)
                        .setSampling(8, true)
                        .build();
        dataset.prepare(new ProgressBar());
        return dataset;
    }

    /** Assembles the training configuration: YOLOv3 loss, devices and listeners. */
    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        SaveModelTrainingListener saveListener = newSaveListener(arguments.getOutputDir());
        float[] anchors = rescaledPresetAnchors();

        return new DefaultTrainingConfig(
                        YOLOv3Loss.builder()
                                .setNumClasses(1)
                                .setInputShape(new Shape(IMAGE_SIZE, IMAGE_SIZE))
                                .setAnchorsArray(anchors)
                                .build())
                .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.basic())
                .addTrainingListeners(saveListener);
    }

    /** Creates the save listener that stores the validation loss as the "Loss" model property. */
    private static SaveModelTrainingListener newSaveListener(String outputDir) {
        SaveModelTrainingListener saveListener = new SaveModelTrainingListener(outputDir);
        saveListener.setSaveModelCallback(
                t -> {
                    TrainingResult outcome = t.getTrainingResult();
                    Model trained = t.getModel();
                    trained.setProperty("Loss", String.format("%.5f", outcome.getValidateLoss()));
                });
        return saveListener;
    }

    /** Returns the preset anchors rescaled by IMAGE_SIZE / ANCHOR_REFERENCE_SIZE. */
    private static float[] rescaledPresetAnchors() {
        float[] anchors = YOLOv3Loss.getPresetAnchors();
        for (int idx = 0; idx < anchors.length; idx++) {
            // Rescale each anchor to the 256x256 input used by this example.
            anchors[idx] = anchors[idx] * IMAGE_SIZE / ANCHOR_REFERENCE_SIZE;
        }
        return anchors;
    }
}
Loading