Skip to content

Commit

Permalink
yolov3 training
Browse files Browse the repository at this point in the history
  • Loading branch information
warthecatalyst committed Oct 4, 2022
1 parent 25fc6e7 commit afa3241
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 129 deletions.
119 changes: 40 additions & 79 deletions api/src/main/java/ai/djl/training/loss/YOLOv3Loss.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package ai.djl.training.loss;

import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;

public class YOLOv3Loss extends Loss {
/**
* {@code YOLOv3Loss} is an implementation of {@link Loss}. It is used to compute the
* loss while training a YOLOv3 model for object detection. It involves
* computing the targets given the generated anchors, labels and predictions, and then computing the
* sum of class predictions and bounding box predictions.
*/
public final class YOLOv3Loss extends Loss {
// TODO: not finished yet; it still contains some bugs and can only be trained with the PyTorch engine
/*
The PRESETANCHORS shapes come from K-means clustering on the COCO dataset, whose image size is 416*416.
They can be rescaled to any other size such as 256*256 by multiplying each value by 256/416.
Expand All @@ -25,12 +31,7 @@ public class YOLOv3Loss extends Loss {
private float ignoreThreshold;
private NDManager manager;
private static final float EPSILON = 1e-7f;
private int cnt = 0;
/**
* Creates a YOLOv3Loss with a Builder
*
* @param builder a builder to build YOLOv3Loss
*/

private YOLOv3Loss(Builder builder) {
super(builder.name);
this.anchors = builder.anchorsArray;
Expand All @@ -40,16 +41,39 @@ private YOLOv3Loss(Builder builder) {
this.ignoreThreshold = builder.ignoreThreshold;
}

/**
 * Clamps every element of the given NDArray into the range [tMin, tMax].
 *
 * <p>Delegates to {@link NDArray#clip(Number, Number)} rather than combining comparison masks
 * through multiplication. The mask-based form turned infinite inputs into NaN (a zero mask
 * multiplied by Infinity) and depended on multiplying boolean masks with float arrays.
 *
 * @param tList the given NDArray
 * @param tMin the min value
 * @param tMax the max value
 * @return a NDArray where values are set between tMin and tMax
 */
public NDArray clipByTensor(NDArray tList, float tMin, float tMax) {
    return tList.clip(tMin, tMax);
}

/**
 * Computes the element-wise squared error between a prediction and its target.
 *
 * <p>No reduction is applied here: the squared difference is returned as-is, so any summing or
 * averaging is left to the caller.
 *
 * @param prediction the prediction array
 * @param target the target array
 * @return the squared difference between prediction and target
 */
public NDArray MSELoss(NDArray prediction, NDArray target) {
    NDArray difference = prediction.sub(target);
    return difference.pow(2);
}


/**
* Calculates the BCELoss between prediction and target.
*
* @param prediction the prediction array
* @param target the target array
* @return the BCELoss between prediction and target
*/
public NDArray BCELoss(NDArray prediction, NDArray target) {
prediction = clipByTensor(prediction, EPSILON, (float) (1.0 - EPSILON));
return prediction.log().mul(target).add(
Expand All @@ -59,7 +83,6 @@ public NDArray BCELoss(NDArray prediction, NDArray target) {

@Override
public NDArray evaluate(NDList labels, NDList predictions) {
System.out.println("In YOLOv3Loss evaluate initialize: "+cnt++);
manager = predictions.getManager();

/*
Expand All @@ -72,15 +95,14 @@ public NDArray evaluate(NDList labels, NDList predictions) {
NDArray[] lossComponents = new NDArray[3];
for (int i = 0; i < 3; i++) {
lossComponents[i] = evaluateOneOutput(i, predictions.get(i), labels.singletonOrThrow());
System.out.println("finalLoss = "+lossComponents[i]);;
}

// calculate the final loss
return NDArrays.add(lossComponents);
}

/**
* for one NDArray, to compute the loss
* Computes the Loss for one outputLayer.
*
* @param componentIndex the index of the output layer that the current input represents. The
* shapes should be (13*13,26*26,52*52)
Expand All @@ -89,10 +111,6 @@ public NDArray evaluate(NDList labels, NDList predictions) {
* @return the total loss of an output layer
*/
public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labels){
System.out.println("In yolov3 Loss evaluateOneOutput: "+componentIndex);
// System.out.println(Arrays.toString(anchors));
System.out.println("input = "+input);

int batchSize = (int) input.getShape().get(0),
inW = (int) input.getShape().get(2),
inH = (int) input.getShape().get(3);
Expand All @@ -101,23 +119,17 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe
input.reshape(batchSize,3,boxAttr,inW,inH)
.transpose(1, 0, 3, 4, 2); //reshape into (3,batchSize,inW,inH,attrs)

System.out.println("prediction shape = "+prediction.getShape());

//the prediction value of x,y,w,h which shape should be (3,batchSize,inW,inH)
NDArray x = Activation.sigmoid(prediction.get("...,0"));
NDArray y = Activation.sigmoid(prediction.get("...,1"));
NDArray w = prediction.get("...,2");
NDArray h = prediction.get("...,3");

// Confidence of whether there is an object and conditional probability of each class
// it should be reshaped into (batchSize,3

// it should be reshaped into (batchSize,3)
NDArray conf = Activation.sigmoid(prediction.get("...,4")).transpose(1,0,2,3);
NDArray predClass = Activation.sigmoid(prediction.get("...,5:")).transpose(1,0,2,3,4);

System.out.println("conf = "+conf);
System.out.println("predClass = "+predClass);

// get an NDList of groundTruth which contains boxLossScale and groundTruth
NDList truthList = getTarget(labels, inH, inW);

Expand All @@ -127,16 +139,9 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe
*/
NDArray boxLossScale = truthList.get(0).transpose(1,0,2,3), groundTruth = truthList.get(1);

System.out.println("groundTruth = "+groundTruth);

// iou shape should be: (batchSize,3 ,inW,inH)
NDArray iou = calculateIOU(x, y, groundTruth.get("...,0:4"), componentIndex).transpose(1,0,2,3);

System.out.println("iouTruth = ");
for(int j = 0;j<3;j++){
System.out.println(iou.get(0+","+j));
}

// get noObjMask and objMask
NDArray noObjMask =
NDArrays.where(
Expand All @@ -149,16 +154,6 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe
objMask = NDArrays.where(iou.gte(ignoreThreshold/2),objMask,manager.zeros(objMask.getShape())); //to get rid of wrong ones
noObjMask = NDArrays.where(objMask.eq(1f),manager.zeros(noObjMask.getShape()),noObjMask);

System.out.println("objMask = ");
for(int j = 0;j<3;j++){
System.out.println(objMask.get(0+","+j));
}

System.out.println("noObjMask = ");
for(int j = 0;j<3;j++){
System.out.println(noObjMask.get(0+","+j));
}

NDArray xTrue = groundTruth.get("...,0");
NDArray yTrue = groundTruth.get("...,1");
NDArray wTrue = groundTruth.get("...,2");
Expand All @@ -169,18 +164,11 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe
anchors[componentIndex*6+2],
anchors[componentIndex*6+4]}).div(inputShape.get(0));

System.out.println("widths = "+widths);
NDArray width1 = widths.broadcast(inH,inW,batchSize,3).transpose(3,2,1,0);
for(int i = 0;i<3;i++){
System.out.println(width1.get(i));
}

NDArray heights = manager.create(new float[]{anchors[componentIndex*6+1],
anchors[componentIndex*6+3],
anchors[componentIndex*6+5]}).div(inputShape.get(1));

System.out.println("heights = "+heights);

// three loss parts: box Loss, confidence Loss, and class Loss
NDArray boxLoss =
objMask.mul(boxLossScale).mul(
Expand All @@ -200,10 +188,6 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe

NDArray noObjLoss = noObjMask.mul(conf.mul(-1).add(1+EPSILON).log().mul(-1)).sum();

System.out.println("boxLoss = "+boxLoss);
System.out.println("confLoss = "+confLoss);
System.out.println("noObjLoss = "+noObjLoss);

return boxLoss.add(confLoss).add(noObjLoss).div(batchSize);
}

Expand All @@ -216,8 +200,6 @@ public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labe
* @return an NDList of {boxLossScale and groundTruth}
*/
public NDList getTarget(NDArray labels, int inH, int inW) {
System.out.println("in getTarget");

int batchSize = (int) labels.size(0);

// the loss Scale of a box, used to punctuate small boxes
Expand All @@ -237,7 +219,6 @@ public NDList getTarget(NDArray labels, int inH, int inW) {
NDArray groundTruth = manager.zeros(new Shape(inW,inH,boxAttr-1),DataType.FLOAT32);

NDArray picture = labels.get(batch);
System.out.println(picture);
// the shape should be (objectNums,5)
NDArray xgt = picture.get("...,1").add(picture.get("...,3").div(2)).mul(inW);
// Center of x should be X value in labels and add half of the width and multiplies the
Expand Down Expand Up @@ -296,7 +277,6 @@ public NDList getTarget(NDArray labels, int inH, int inW) {
*/
public NDArray calculateIOU(
NDArray predx, NDArray predy, NDArray groundTruth,int componentIndex) {
System.out.println("in calculate IOU");
int inW = (int) predx.getShape().get(2), inH = (int) predx.getShape().get(3);
int strideW = (int) inputShape.get(0)/inW, strideH = (int) inputShape.get(1)/inH;

Expand All @@ -307,56 +287,31 @@ public NDArray calculateIOU(
for(int i = 0;i<3;i++){
NDArray curPredx = predx.get(i),curPredy = predy.get(i);
float width = anchors[componentIndex*6+2*i]/strideW, height = anchors[componentIndex*6+2*i+1]/strideH;
System.out.println("width = "+width);
System.out.println("height = "+height);
System.out.println("curPredx = "+curPredx);
System.out.println("curPredy = "+curPredy);

NDArray predLeft = curPredx.sub(width/2),
predRight = curPredx.add(width/2),
predTop = curPredy.sub(height/2),
predBottom = curPredy.add(height/2);

System.out.println("predLeft = "+predLeft);
System.out.println("predRight = "+predRight);
System.out.println("predTop = "+predTop);
System.out.println("predBottom = "+predBottom);

NDArray truth = groundTruth.get(i);
System.out.println("truth = "+truth);

NDArray trueLeft = truth.get("...,0").sub(truth.get("...,2").mul(inW).div(2)),
trueRight = truth.get("...,0").add(truth.get("...,2").mul(inW).div(2)),
trueTop = truth.get("...,1").sub(truth.get("...,3").mul(inH).div(2)),
trueBottom = truth.get("...,1").add(truth.get("...,3").mul(inH).div(2));

System.out.println("trueLeft = "+trueLeft);
System.out.println("trueRight = "+trueRight);
System.out.println("trueTop = "+trueTop);
System.out.println("trueBottom = "+trueBottom);

NDArray left = NDArrays.maximum(predLeft,trueLeft),
right = NDArrays.minimum(predRight,trueRight),
top = NDArrays.maximum(predTop,trueTop),
bottom = NDArrays.minimum(predBottom,trueBottom);

System.out.println("left = "+left);
System.out.println("right = "+right);
System.out.println("top = "+top);
System.out.println("bottom = "+bottom);

NDArray inter = right.sub(left).mul(bottom.sub(top));
System.out.println("inter = "+inter);

NDArray union =
truth.get("...,2").mul(inW)
.mul(truth.get("...,3").mul(inH))
.add(width * height)
.sub(inter)
.add(EPSILON); //should not be divided by zero

System.out.println("union = "+union);
System.out.println("inter/union = "+inter.div(union));
iouComponent.add(inter.div(union));
}

Expand Down Expand Up @@ -397,6 +352,12 @@ public Builder setNumClasses(int numClasses) {
return this;
}

/**
* Sets the shape of the input picture.
*
* @param inputShape the shape of the input picture
* @return this {@code Builder}
*/
public Builder setInputShape(Shape inputShape) {
this.inputShape = inputShape;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.PikachuDetection;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.basicmodelzoo.cv.object_detection.yolo.YOLOV3;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
Expand All @@ -15,7 +13,6 @@
import ai.djl.modality.cv.MultiBoxDetection;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.modality.cv.translator.YoloV3Translator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -49,7 +46,7 @@
import java.util.Collections;
import java.util.List;

public class TrainPikachuWithYOLOV3 {
public final class TrainPikachuWithYOLOV3 {
private TrainPikachuWithYOLOV3() {}

public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
Expand Down Expand Up @@ -131,7 +128,7 @@ private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arg
.optUsage(usage)
.optLimit(arguments.getLimit())
.optPipeline(pipeline)
.setSampling(1, true)
.setSampling(8, true)
.build();
pikachuDetection.prepare(new ProgressBar());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static void main(String[] args) {
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optDevices(Engine.getInstance().getDevices(2)) // use a simple gpu
.optInitializer(Initializer.ONES, Parameter.Type.WEIGHT);
Block yolov3 = new YOLOV3(YOLOV3.builder());
Block yolov3 = YOLOV3.builder().build();
yolov3.getOutputShapes(new Shape[] {new Shape(1, 3, 416, 416)});
try (Model model = Model.newInstance("YOLOv3")) {
model.setBlock(yolov3);
Expand Down
Loading

0 comments on commit afa3241

Please sign in to comment.