
Commit

Transfer learning with pytorch engine on fresh fruit dataset (deepjavalibrary#2070)

* ATLearning on PyTorch, frozen pretrained model

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
Co-authored-by: Zach Kimberg <kimbergz@amazon.com>
3 people authored and patins1 committed Oct 30, 2022
1 parent 1659a6a commit c1f7860
Showing 33 changed files with 926 additions and 68 deletions.
37 changes: 37 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/transform/OneHot.java
@@ -0,0 +1,37 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.transform;

import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;

/** A {@link Transform} that converts the labels {@link NDArray} to one-hot labels. */
public class OneHot implements Transform {

private int numClass;

/**
* Creates a {@code OneHot} {@link Transform} that converts a sparse label to a one-hot label.
*
* @param numClass number of classes
*/
public OneHot(int numClass) {
this.numClass = numClass;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return array.oneHot(numClass);
}
}
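
A minimal usage sketch of the new transform (the label values and class count below are illustrative, not part of the change):

    import ai.djl.modality.cv.transform.OneHot;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    public class OneHotExample {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                // sparse labels for a 3-class problem
                NDArray labels = manager.create(new int[] {0, 2, 1});
                // each label becomes a one-hot row of length 3, giving shape (3, 3)
                NDArray oneHot = new OneHot(3).transform(labels);
                System.out.println(oneHot);
            }
        }
    }
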
api/src/main/java/ai/djl/modality/cv/transform/RandomFlipLeftRight.java
@@ -12,18 +12,37 @@
*/
package ai.djl.modality.cv.transform;

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;
import ai.djl.util.RandomUtils;

import java.util.Random;

/**
* A {@link Transform} that randomly flips the input image left to right with a probability of 0.5.
*/
public class RandomFlipLeftRight implements Transform {
Integer seed;

/** Creates a new instance of {@code RandomFlipLeftRight}. */
public RandomFlipLeftRight() {}

/**
* Creates a new instance of {@code RandomFlipLeftRight} with the given seed.
*
* @param seed the value of the seed
*/
public RandomFlipLeftRight(int seed) {
this.seed = seed;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return NDImageUtils.randomFlipLeftRight(array);
Random rnd = (seed != null) ? new Random(seed) : RandomUtils.RANDOM;
if (rnd.nextFloat() > 0.5) {
array = array.flip(1); // flip returns a new NDArray, so reassign it
}
return array;
}
}
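
A rough sketch of the seeded constructor; the image here is random placeholder data in HWC layout, so flipping axis 1 corresponds to the width dimension:

    import ai.djl.modality.cv.transform.RandomFlipLeftRight;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;

    public class FlipExample {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray image = manager.randomUniform(0, 255, new Shape(224, 224, 3)); // HWC
                // A fixed seed makes the flip decision reproducible across runs.
                NDArray augmented = new RandomFlipLeftRight(42).transform(image);
                System.out.println(augmented.getShape());
            }
        }
    }
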
api/src/main/java/ai/djl/modality/cv/transform/RandomFlipTopBottom.java
@@ -12,18 +12,38 @@
*/
package ai.djl.modality.cv.transform;

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;
import ai.djl.util.RandomUtils;

import java.util.Random;

/**
* A {@link Transform} that randomly flips the input image top to bottom with a probability of 0.5.
*/
public class RandomFlipTopBottom implements Transform {

Integer seed;

/** Creates a new instance of {@code RandomFlipTopBottom}. */
public RandomFlipTopBottom() {}

/**
* Creates a new instance of {@code RandomFlipTopBottom} with the given seed.
*
* @param seed the value of the seed
*/
public RandomFlipTopBottom(int seed) {
this.seed = seed;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return NDImageUtils.randomFlipTopBottom(array);
Random rnd = (seed != null) ? new Random(seed) : RandomUtils.RANDOM;
if (rnd.nextFloat() > 0.5) {
array = array.flip(0); // flip returns a new NDArray, so reassign it
}
return array;
}
}
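
In practice these transforms are usually composed through a Pipeline rather than called directly; a hedged sketch (ToTensor and the ordering are just one plausible arrangement):

    import ai.djl.modality.cv.transform.RandomFlipLeftRight;
    import ai.djl.modality.cv.transform.RandomFlipTopBottom;
    import ai.djl.modality.cv.transform.ToTensor;
    import ai.djl.translate.Pipeline;

    public class AugmentationPipeline {
        // Applied in order: horizontal flip, vertical flip, then HWC -> CHW float tensor.
        static final Pipeline PIPELINE =
                new Pipeline()
                        .add(new RandomFlipLeftRight())
                        .add(new RandomFlipTopBottom())
                        .add(new ToTensor());
    }
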
api/src/main/java/ai/djl/modality/cv/transform/RandomResizedCrop.java
@@ -51,6 +51,21 @@ public RandomResizedCrop(
this.maxAspectRatio = maxAspectRatio;
}

/**
* Creates a {@code RandomResizedCrop} {@link Transform}.
*
* @param width the output width of the image
* @param height the output height of the image
*/
public RandomResizedCrop(int width, int height) {
this.width = width;
this.height = height;
this.minAreaScale = 0.08;
this.maxAreaScale = 1.0;
this.minAspectRatio = 3.0 / 4.0;
this.maxAspectRatio = 4.0 / 3.0;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
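
A brief fragment showing the new convenience constructor; the 224-pixel output size is an arbitrary example, and the ranges in the comment are the defaults it fills in:

    // Crops to 224x224 using the default area scale [0.08, 1.0]
    // and aspect ratio [3/4, 4/3] set by the two-argument constructor.
    RandomResizedCrop crop = new RandomResizedCrop(224, 224);
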
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/modality/cv/util/NDImageUtils.java
@@ -30,7 +30,7 @@ private NDImageUtils() {}
*
* @param image the image to resize
* @param size the desired size
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(NDArray image, int size) {
return resize(image, size, size, Image.Interpolation.BILINEAR);
@@ -42,7 +42,7 @@ public static NDArray resize(NDArray image, int size) {
* @param image the image to resize
* @param width the desired width
* @param height the desired height
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(NDArray image, int width, int height) {
return resize(image, width, height, Image.Interpolation.BILINEAR);
@@ -55,7 +55,7 @@ public static NDArray resize(NDArray image, int width, int height) {
* @param width the desired width
* @param height the desired height
* @param interpolation the desired interpolation
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(
NDArray image, int width, int height, Image.Interpolation interpolation) {
@@ -334,6 +334,6 @@ public static boolean isCHW(Shape shape) {
} else if (shape.get(2) == 1 || shape.get(2) == 3) {
return false;
}
throw new IllegalArgumentException("Image is not CHW or HWC");
throw new IllegalArgumentException("Image is neither CHW nor HWC");
}
}
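
A short illustrative fragment combining the utilities touched above (the 224-pixel target and the HWC input are assumptions for the example, not part of the change):

    // image is an HWC NDArray, e.g. obtained from Image.toNDArray(manager)
    NDArray resized = NDImageUtils.resize(image, 224);     // bilinear interpolation by default
    boolean chw = NDImageUtils.isCHW(resized.getShape());  // false for an HWC-shaped array
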
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/nn/Block.java
@@ -194,7 +194,8 @@ default NDList forward(
void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

/**
* Returns a boolean whether the block is initialized.
* Returns a boolean whether the block is initialized (the block has an input shape and its
* parameters have non-null arrays).
*
* @return whether the block is initialized
*/
27 changes: 25 additions & 2 deletions api/src/main/java/ai/djl/nn/Parameter.java
@@ -113,6 +113,15 @@ public void setShape(Shape shape) {
this.shape = shape;
}

/**
* Gets the shape of this {@code Parameter}.
*
* @return the shape of this {@code Parameter}
*/
public Shape getShape() {
return shape;
}

/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
@@ -170,6 +179,16 @@ public void setInitializer(Initializer initializer) {
this.initializer = initializer;
}

/**
* Returns the {@link Initializer} of this {@code Parameter}.
*
* @return the initializer of this {@code Parameter}
*/
public Initializer getInitializer() {
return initializer;
}

/**
* Initializes the parameter with the given {@link NDManager}, with given {@link DataType} for
* the given expected input shapes.
@@ -178,9 +197,13 @@ public void setInitializer(Initializer initializer) {
* @param dataType the datatype of the {@code Parameter}
*/
public void initialize(NDManager manager, DataType dataType) {
Objects.requireNonNull(initializer, "No initializer has been set");
Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
// Params in a PtSymbolBlock are set during model loading, so isInitialized() is true;
// they should not be initialized again.
Objects.requireNonNull(initializer, "No initializer has been set");
// Params in a PtSymbolBlock can have a null shape but still be initialized (their array is
// non-null).
Objects.requireNonNull(shape, "No parameter shape has been set");
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
13 changes: 11 additions & 2 deletions api/src/main/java/ai/djl/nn/core/Linear.java
@@ -35,10 +35,19 @@
* <p>It has the following shapes:
*
* <ul>
* <li>input X: [x1, x2, , xn, input_dim]
* <li>input X: [x1, x2, ..., xn, input_dim]
* <li>weight W: [units, input_dim]
* <li>Bias b: [units]
* <li>output Y: [x1, x2, …, xn, units]
* <li>output Y: [x1, x2, ..., xn, units]
* </ul>
*
* <p>It is most typically used with a simple batched 1D input. In that case, the shape would be:
*
* <ul>
* <li>input X: [batch_num, input_dim]
* <li>weight W: [units, input_dim]
* <li>Bias b: [units]
* <li>output Y: [batch_num, units]
* </ul>
*
* <p>The Linear block should be constructed using {@link Linear.Builder}.
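
A small sketch of the batched 1-D case described in the javadoc above (the unit count and dimensions are arbitrary examples):

    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.core.Linear;

    public class LinearShapes {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                // units = 10, so weight W: [10, input_dim] and bias b: [10]
                Linear linear = Linear.builder().setUnits(10).build();
                // batched 1-D input X: [batch_num, input_dim] = [32, 784]
                linear.initialize(manager, DataType.FLOAT32, new Shape(32, 784));
                // output Y would then be [batch_num, units] = [32, 10]
            }
        }
    }
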
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/training/evaluator/Accuracy.java
@@ -61,7 +61,8 @@ protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions)
predictionReduced = predictionReduced.reshape(label.getShape());
} else {
// Multi-class, one-hot label
predictionReduced = prediction;
predictionReduced = prediction.argMax(axis);
label = label.argMax(axis);
}
// result of sum is int64 now
long total = label.size();
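
Illustratively, with a [batch_size, n-class] layout (so axis = 1) and made-up values, the one-hot branch now reduces both tensors to class indices before they are compared:

    // label      (one-hot): [[0, 0, 1], [1, 0, 0]]             -> argMax(1) -> [2, 0]
    // prediction (softmax): [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]] -> argMax(1) -> [2, 0]
    NDArray labelIndices = label.argMax(1);
    NDArray predictionIndices = prediction.argMax(1);
    // accuracy = mean(labelIndices == predictionIndices) = 1.0 in this example
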
api/src/main/java/ai/djl/training/loss/MaskedSoftmaxCrossEntropyLoss.java
@@ -47,9 +47,10 @@ public MaskedSoftmaxCrossEntropyLoss(String name) {
* @param name the name of the loss
* @param weight the weight to apply on the loss value, default 1
* @param classAxis the axis that represents the class probabilities, default -1
* @param sparseLabel whether labels are integer array or probabilities, default true
* @param fromLogit whether predictions are log probabilities or un-normalized numbers, default
* false
* @param sparseLabel whether labels are a 1-D integer array of [batch_size] (true) or 2-D
* probabilities of [batch_size, n-class] (false), default true
* @param fromLogit if true, the inputs are assumed to be un-normalized values to which softmax
* has not yet been applied; logSoftmax will then be applied to the input, default false
*/
public MaskedSoftmaxCrossEntropyLoss(
String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
api/src/main/java/ai/djl/training/loss/SoftmaxCrossEntropyLoss.java
@@ -43,6 +43,8 @@ public SoftmaxCrossEntropyLoss() {
* @param name the name of the loss
*/
public SoftmaxCrossEntropyLoss(String name) {
// By default, fromLogit=true, meaning the prediction is expected to be raw values
// to which softmax has not yet been applied.
this(name, 1, -1, true, true);
}

@@ -52,10 +54,10 @@ public SoftmaxCrossEntropyLoss(String name) {
* @param name the name of the loss
* @param weight the weight to apply on the loss value, default 1
* @param classAxis the axis that represents the class probabilities, default -1
* @param sparseLabel whether labels are 1-D integer array or 2-D probabilities of [batch_size,
* n-class], default true
* @param fromLogit whether predictions are un-normalized numbers or log probabilities, if true,
* logSoftmax will be applied to input, default true
* @param sparseLabel whether labels are a rank-1 integer array of [batch_size] (true) or rank-2
* one-hot of [batch_size, n-class] (false), default true
* @param fromLogit if true, the inputs are assumed to be un-normalized values to which softmax
* has not yet been applied; logSoftmax will then be applied to the input, default true
*/
public SoftmaxCrossEntropyLoss(
String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
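
A hedged construction example: the default pairs with sparse integer labels, while the five-argument form can be used for one-hot labels (for instance produced by the OneHot transform above):

    import ai.djl.training.loss.Loss;
    import ai.djl.training.loss.SoftmaxCrossEntropyLoss;

    public class LossSetup {
        // Default: sparse integer labels, predictions taken as raw logits.
        static final Loss SPARSE = new SoftmaxCrossEntropyLoss();
        // One-hot labels (sparseLabel = false), still from logits (fromLogit = true).
        static final Loss ONE_HOT =
                new SoftmaxCrossEntropyLoss("SoftmaxCrossEntropyLoss", 1, -1, false, true);
    }
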
13 changes: 7 additions & 6 deletions api/src/main/java/ai/djl/training/optimizer/Adagrad.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.tracker.ParameterTracker;
import ai.djl.training.tracker.Tracker;

import java.util.Map;
@@ -40,7 +41,7 @@
*/
public class Adagrad extends Optimizer {

private Tracker learningRateTracker;
private ParameterTracker learningRateTracker;
private float epsilon;

private Map<String, Map<Device, NDArray>> history;
@@ -61,7 +62,7 @@ protected Adagrad(Builder builder) {
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
int t = updateCount(parameterId);
float newLearningRate = learningRateTracker.getNewValue(t);
float newLearningRate = learningRateTracker.getNewValue(parameterId, t);
float weightDecay = getWeightDecay();

if (Float.isNaN(newLearningRate)
@@ -98,7 +99,7 @@ public static Builder builder() {
/** The Builder to construct an {@link Adagrad} object. */
public static final class Builder extends OptimizerBuilder<Builder> {

private Tracker learningRateTracker = Tracker.fixed(0.001f);
private ParameterTracker learningRateTracker = Tracker.fixed(0.001f);
private float epsilon = 1e-8f;

Builder() {}
@@ -110,12 +111,12 @@ protected Builder self() {
}

/**
* Sets the {@link Tracker} for this optimizer.
* Sets the {@link ParameterTracker} for this optimizer.
*
* @param learningRateTracker the {@link Tracker} to be set
* @param learningRateTracker the {@link ParameterTracker} to be set
* @return this {@code Builder}
*/
public Builder optLearningRateTracker(Tracker learningRateTracker) {
public Builder optLearningRateTracker(ParameterTracker learningRateTracker) {
this.learningRateTracker = learningRateTracker;
return this;
}
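
A minimal construction sketch; Tracker.fixed still works here because the diff assigns it to the new ParameterTracker field (the learning rate value is arbitrary):

    import ai.djl.training.optimizer.Adagrad;
    import ai.djl.training.optimizer.Optimizer;
    import ai.djl.training.tracker.Tracker;

    public class OptimizerSetup {
        static final Optimizer ADAGRAD =
                Adagrad.builder()
                        .optLearningRateTracker(Tracker.fixed(0.01f))
                        .build();
    }
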