
Commit

simple loss function
KexinFeng committed Oct 11, 2022
1 parent a7d957b commit 5dc220c
Showing 4 changed files with 45 additions and 54 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/internal/NDFormat.java
@@ -62,7 +62,7 @@ public static String format(
}

if (DEBUG) {
sb.append(" ->");
sb.append(" ~>");
}

NDFormat format;
13 changes: 0 additions & 13 deletions api/src/main/java/ai/djl/training/loss/Loss.java
@@ -196,19 +196,6 @@ public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(
return new SoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit);
}

- /**
- * Returns a new instance of {@link SoftmaxCrossEntropyLoss}, which assumes softmax already
- * applied.
- *
- * @param name the name of the loss
- * @param fromSoftmax the input prediction is from the output of softmax, default false
- * @return a new instance of {@link SoftmaxCrossEntropyLoss}
- */
- public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(
- String name, boolean fromSoftmax) {
- return new SoftmaxCrossEntropyLoss(name, fromSoftmax);
- }

/**
* Returns a new instance of {@link MaskedSoftmaxCrossEntropyLoss} with default arguments.
*
20 changes: 0 additions & 20 deletions api/src/main/java/ai/djl/training/loss/SoftmaxCrossEntropyLoss.java
@@ -31,7 +31,6 @@ public class SoftmaxCrossEntropyLoss extends Loss {
private int classAxis;
private boolean sparseLabel;
private boolean fromLogit;
- private boolean fromSoftmax;

/** Creates a new instance of {@code SoftmaxCrossEntropyLoss} with default parameters. */
public SoftmaxCrossEntropyLoss() {
@@ -67,22 +66,6 @@ public SoftmaxCrossEntropyLoss(
this.classAxis = classAxis;
this.sparseLabel = sparseLabel;
this.fromLogit = fromLogit;
- this.fromSoftmax = false;
}

- /**
- * Creates a new instance of {@code SoftmaxCrossEntropyLoss} from the output of softmax.
- *
- * @param name the name of the loss function
- * @param fromSoftmax the input prediction is from the output of softmax, default false
- */
- public SoftmaxCrossEntropyLoss(String name, boolean fromSoftmax) {
- super(name);
- this.weight = 1;
- this.classAxis = -1;
- this.sparseLabel = false;
- this.fromLogit = false;
- this.fromSoftmax = fromSoftmax;
- }

/** {@inheritDoc} */
@@ -92,9 +75,6 @@ public NDArray evaluate(NDList label, NDList prediction) {
if (fromLogit) {
pred = pred.logSoftmax(classAxis);
}
- if (fromSoftmax) {
- pred = pred.log();
- }
NDArray loss;
NDArray lab = label.singletonOrThrow();
if (sparseLabel) {
64 changes: 44 additions & 20 deletions
@@ -22,6 +22,7 @@
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.transform.Transpose;
+ import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
@@ -118,7 +119,7 @@ public static TrainingResult runExample(String[] args)
ImageFolder datasetTrain = getData("test", "banana", batchSize);

// Train
- EasyTrain.fit(trainer, 10, datasetTrain, null);
+ EasyTrain.fit(trainer, 6, datasetTrain, null);

// Save model
// model.save("your-model-path");
@@ -128,25 +129,6 @@ public static TrainingResult runExample(String[] args)
return null;
}

- private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
- String outputDir = arguments.getOutputDir();
- SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
- listener.setSaveModelCallback(
- trainer -> {
- TrainingResult result = trainer.getTrainingResult();
- Model model = trainer.getModel();
- float accuracy = result.getValidateEvaluation("Accuracy");
- model.setProperty("Accuracy", String.format("%.5f", accuracy));
- model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
- });

- return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss("SoftmaxCrossEntropy", true))
- .addEvaluator(new Accuracy())
- .optDevices(Engine.getInstance().getDevices(1))
- .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
- .addTrainingListeners(listener);
- }

private static ImageFolder getData(String subfolderName, String fruit, int batchSize)
throws TranslateException, IOException {
// The dataset is from <a
@@ -168,4 +150,46 @@ private static ImageFolder getData(String subfolderName, String fruit, int batchSize)
dataset.prepare();
return dataset;
}

+ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
+ String outputDir = arguments.getOutputDir();
+ SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
+ listener.setSaveModelCallback(
+ trainer -> {
+ TrainingResult result = trainer.getTrainingResult();
+ Model model = trainer.getModel();
+ float accuracy = result.getValidateEvaluation("Accuracy");
+ model.setProperty("Accuracy", String.format("%.5f", accuracy));
+ model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
+ });

+ return new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy"))
+ .addEvaluator(new Accuracy())
+ .optDevices(Engine.getInstance().getDevices(1))
+ .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
+ .addTrainingListeners(listener);
+ }

+ private static class SoftmaxCrossEntropy extends Loss {

+ /**
+ * Creates the softmax cross-entropy loss used by this example.
+ *
+ * @param name the display name of the Loss
+ */
+ public SoftmaxCrossEntropy(String name) {
+ super(name);
+ }

+ /** {@inheritDoc} */
+ @Override
+ public NDArray evaluate(NDList labels, NDList predictions) {
+ // Here the labels are supposed to be one-hot
+ int classAxis = -1;
+ NDArray pred = predictions.singletonOrThrow().log();
+ NDArray lab = labels.singletonOrThrow().reshape(pred.getShape());
+ NDArray loss = pred.mul(lab).neg().sum(new int[] {classAxis}, true);
+ return loss.mean();
+ }
+ }
}
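
For reference, a minimal standalone sketch of the computation the example's new private SoftmaxCrossEntropy performs: the log of softmax outputs multiplied by one-hot labels, negated, summed over the class axis, and averaged over the batch. The demo class name, the NDManager setup, and the toy values are illustrative assumptions, not part of this commit; only the reduction mirrors the evaluate() added above.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class SoftmaxCrossEntropyDemo {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Two samples over three classes; each row already sums to 1,
            // i.e. the predictions are softmax outputs, as the example assumes.
            NDArray pred = manager.create(new float[][] {{0.7f, 0.2f, 0.1f}, {0.1f, 0.8f, 0.1f}});
            // One-hot labels, matching the example's "labels are supposed to be one-hot".
            NDArray label = manager.create(new float[][] {{1f, 0f, 0f}, {0f, 1f, 0f}});

            // Same reduction as evaluate(): -sum(label * log(pred)) over the class
            // axis, then the mean over the batch.
            NDArray loss = pred.log().mul(label).neg().sum(new int[] {-1}, true).mean();
            System.out.println(loss); // scalar near (-ln 0.7 - ln 0.8) / 2 ≈ 0.29
        }
    }
}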
