Transfer learning with pytorch engine on fresh fruit dataset #2070

Merged: 19 commits, Oct 28, 2022
37 changes: 37 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/transform/OneHot.java
@@ -0,0 +1,37 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.transform;

import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;

/** A {@link Transform} that converts the labels {@link NDArray} to one-hot labels. */
public class OneHot implements Transform {

private int numClass;

/**
* Creates a {@code OneHot} {@link Transform} that converts a sparse label to a one-hot label.
*
* @param numClass number of classes
*/
public OneHot(int numClass) {
this.numClass = numClass;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return array.oneHot(numClass);
}
}
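A quick sketch of what the new transform does (not part of the diff; the NDManager setup is assumed): sparse labels become one-hot rows.

import ai.djl.modality.cv.transform.OneHot;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

try (NDManager manager = NDManager.newBaseManager()) {
    NDArray labels = manager.create(new int[] {0, 2, 1});
    // OneHot delegates to NDArray.oneHot(numClass)
    NDArray oneHot = new OneHot(3).transform(labels);
    // oneHot: [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
}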
api/src/main/java/ai/djl/modality/cv/transform/RandomFlipLeftRight.java
@@ -12,18 +12,37 @@
*/
package ai.djl.modality.cv.transform;

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;
import ai.djl.util.RandomUtils;

import java.util.Random;

/**
* A {@link Transform} that randomly flips the input image left to right with a probability of 0.5.
*/
public class RandomFlipLeftRight implements Transform {
private Integer seed;

/** Creates a new instance of {@code RandomFlipLeftRight}. */
public RandomFlipLeftRight() {}

/**
* Creates a new instance of {@code RandomFlipLeftRight} with the given seed.
*
* @param seed the value of the seed
*/
public RandomFlipLeftRight(int seed) {
this.seed = seed;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return NDImageUtils.randomFlipLeftRight(array);
Random rnd = (seed != null) ? new Random(seed) : RandomUtils.RANDOM;
if (rnd.nextFloat() > 0.5) {
array = array.flip(1);
}
return array;
}
}
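A usage sketch for the two flip transforms (assumed usage, not from this diff). Note that a seeded instance builds a new Random(seed) on every call, so it makes the same flip decision for every image, which is mainly useful for reproducible tests.

import ai.djl.modality.cv.transform.RandomFlipLeftRight;
import ai.djl.modality.cv.transform.RandomFlipTopBottom;
import ai.djl.translate.Pipeline;

Pipeline augmentation =
        new Pipeline()
                .add(new RandomFlipLeftRight()) // flips with probability ~0.5 per image
                .add(new RandomFlipTopBottom(42)); // seeded: same decision on every call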
api/src/main/java/ai/djl/modality/cv/transform/RandomFlipTopBottom.java
@@ -12,18 +12,38 @@
*/
package ai.djl.modality.cv.transform;

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;
import ai.djl.util.RandomUtils;

import java.util.Random;

/**
* A {@link Transform} that randomly flips the input image top to bottom with a probability of 0.5.
*/
public class RandomFlipTopBottom implements Transform {

private Integer seed;

/** Creates a new instance of {@code RandomFlipTopBottom}. */
public RandomFlipTopBottom() {}

/**
* Creates a new instance of {@code RandomFlipTopBottom} with the given seed.
*
* @param seed the value of the seed
*/
public RandomFlipTopBottom(int seed) {
this.seed = seed;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return NDImageUtils.randomFlipTopBottom(array);
Random rnd = (seed != null) ? new Random(seed) : RandomUtils.RANDOM;
if (rnd.nextFloat() > 0.5) {
array = array.flip(0);
}
return array;
}
}
api/src/main/java/ai/djl/modality/cv/transform/RandomResizedCrop.java
@@ -51,6 +51,21 @@ public RandomResizedCrop(
this.maxAspectRatio = maxAspectRatio;
}

/**
* Creates a {@code RandomResizedCrop} {@link Transform} with an area scale range of [0.08, 1.0] and an aspect ratio range of [3/4, 4/3].
*
* @param width the output width of the image
* @param height the output height of the image
*/
public RandomResizedCrop(int width, int height) {
this.width = width;
this.height = height;
this.minAreaScale = 0.08;
this.maxAreaScale = 1.0;
this.minAspectRatio = 3.0 / 4.0;
this.maxAspectRatio = 4.0 / 3.0;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
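The new two-argument constructor fills in the usual ImageNet-style defaults; assuming the existing six-argument constructor keeps the signature implied by the fields above, the two calls below are equivalent.

import ai.djl.modality.cv.transform.RandomResizedCrop;
import ai.djl.translate.Transform;

Transform crop = new RandomResizedCrop(224, 224);
// equivalent to the fully specified form:
Transform cropExplicit =
        new RandomResizedCrop(224, 224, 0.08, 1.0, 3.0 / 4.0, 4.0 / 3.0);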
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/modality/cv/util/NDImageUtils.java
@@ -30,7 +30,7 @@ private NDImageUtils() {}
*
* @param image the image to resize
* @param size the desired size
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(NDArray image, int size) {
return resize(image, size, size, Image.Interpolation.BILINEAR);
@@ -42,7 +42,7 @@ public static NDArray resize(NDArray image, int size) {
* @param image the image to resize
* @param width the desired width
* @param height the desired height
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(NDArray image, int width, int height) {
return resize(image, width, height, Image.Interpolation.BILINEAR);
@@ -55,7 +55,7 @@ public static NDArray resize(NDArray image, int width, int height) {
* @param width the desired width
* @param height the desired height
* @param interpolation the desired interpolation
* @return the resized NDList
* @return the resized NDArray
*/
public static NDArray resize(
NDArray image, int width, int height, Image.Interpolation interpolation) {
@@ -334,6 +334,6 @@ public static boolean isCHW(Shape shape) {
} else if (shape.get(2) == 1 || shape.get(2) == 3) {
return false;
}
throw new IllegalArgumentException("Image is not CHW or HWC");
throw new IllegalArgumentException("Image is neither CHW nor HWC");
}
}
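For reference, a sketch of the three resize overloads whose @return docs were corrected (image is assumed to be an HWC image NDArray):

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;

NDArray square = NDImageUtils.resize(image, 224); // 224 x 224, bilinear
NDArray sized = NDImageUtils.resize(image, 256, 224); // width 256, height 224, bilinear
NDArray nearest = NDImageUtils.resize(image, 256, 224, Image.Interpolation.NEAREST);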
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/nn/Block.java
@@ -194,7 +194,8 @@ default NDList forward(
void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

/**
* Returns a boolean whether the block is initialized.
* Returns whether the block is initialized (the block has input shapes and its parameters
* have non-null arrays).
*
* @return whether the block is initialized
*/
27 changes: 25 additions & 2 deletions api/src/main/java/ai/djl/nn/Parameter.java
@@ -113,6 +113,15 @@ public void setShape(Shape shape) {
this.shape = shape;
}

/**
* Gets the shape of this {@code Parameter}.
*
* @return the shape of this {@code Parameter}
*/
public Shape getShape() {
return shape;
}

/**
* Gets the values of this {@code Parameter} as an {@link NDArray}.
*
@@ -170,6 +179,16 @@ public void setInitializer(Initializer initializer) {
this.initializer = initializer;
}

/**
* Returns the {@link Initializer} of this {@code Parameter}.
*
* @return the initializer of this {@code Parameter}
*/
public Initializer getInitializer() {
return initializer;
}

/**
* Initializes the parameter with the given {@link NDManager}, with given {@link DataType} for
* the given expected input shapes.
@@ -178,9 +197,13 @@ public void setInitializer(Initializer initializer) {
* @param dataType the datatype of the {@code Parameter}
*/
public void initialize(NDManager manager, DataType dataType) {
Objects.requireNonNull(initializer, "No initializer has been set");
Objects.requireNonNull(shape, "No parameter shape has been set");
if (!isInitialized()) {
// Parameters in a PtSymbolBlock are set during model loading, so isInitialized() is
// true and they should not be initialized again.
Objects.requireNonNull(initializer, "No initializer has been set");
// Parameters in a PtSymbolBlock can have a null shape yet still be initialized
// (their array is non-null).
Objects.requireNonNull(shape, "No parameter shape has been set");
array = initializer.initialize(manager, shape, dataType);
array.setName(name);
}
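With the new getShape() and getInitializer() getters, the manual lifecycle of a parameter looks roughly like this (a sketch under assumed DJL APIs; Parameter.builder() and XavierInitializer come from ai.djl.nn and ai.djl.training.initializer):

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.initializer.XavierInitializer;

Parameter weight =
        Parameter.builder().setName("weight").setType(Parameter.Type.WEIGHT).build();
weight.setShape(new Shape(10, 5));
weight.setInitializer(new XavierInitializer());

try (NDManager manager = NDManager.newBaseManager()) {
    weight.initialize(manager, DataType.FLOAT32); // no-op if already initialized
    Shape shape = weight.getShape(); // (10, 5), via the new getter
}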
13 changes: 11 additions & 2 deletions api/src/main/java/ai/djl/nn/core/Linear.java
@@ -35,10 +35,19 @@
* <p>It has the following shapes:
*
* <ul>
* <li>input X: [x1, x2, , xn, input_dim]
* <li>input X: [x1, x2, ..., xn, input_dim]
* <li>weight W: [units, input_dim]
* <li>Bias b: [units]
* <li>output Y: [x1, x2, …, xn, units]
* <li>output Y: [x1, x2, ..., xn, units]
* </ul>
*
* <p>It is most typically used with a simple batched 1D input. In that case, the shape would be:
*
* <ul>
* <li>input X: [batch_num, input_dim]
* <li>weight W: [units, input_dim]
* <li>Bias b: [units]
* <li>output Y: [batch_num, units]
* </ul>
*
* <p>The Linear block should be constructed using {@link Linear.Builder}.
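For the typical batched 1D case described above, construction is a one-liner (a sketch; the units and bias values are illustrative):

import ai.djl.nn.Block;
import ai.djl.nn.core.Linear;

Block fc =
        Linear.builder()
                .setUnits(10) // output Y: [batch_num, 10]
                .optBias(true) // bias b: [10]
                .build();
// input X: [batch_num, input_dim]; weight W is inferred as [10, input_dim]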
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/training/evaluator/Accuracy.java
@@ -61,7 +61,8 @@ protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions)
predictionReduced = predictionReduced.reshape(label.getShape());
} else {
// Multi-class, one-hot label
predictionReduced = prediction;
predictionReduced = prediction.argMax(axis);
label = label.argMax(axis);
}
// result of sum is int64 now
long total = label.size();
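With this fix, Accuracy reduces both the prediction and a one-hot label with argMax along the class axis before comparing them. A behavioral sketch (accumulator API usage assumed):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.evaluator.Accuracy;

Accuracy acc = new Accuracy();
acc.addAccumulator("train");
try (NDManager manager = NDManager.newBaseManager()) {
    NDArray labels = manager.create(new float[][] {{0f, 1f}, {1f, 0f}}); // one-hot
    NDArray preds = manager.create(new float[][] {{0.2f, 0.8f}, {0.9f, 0.1f}});
    acc.updateAccumulator("train", new NDList(labels), new NDList(preds));
}
float accuracy = acc.getAccumulator("train"); // 1.0f: both rows predicted correctly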
api/src/main/java/ai/djl/training/loss/MaskedSoftmaxCrossEntropyLoss.java
@@ -47,9 +47,10 @@ public MaskedSoftmaxCrossEntropyLoss(String name) {
* @param name the name of the loss
* @param weight the weight to apply on the loss value, default 1
* @param classAxis the axis that represents the class probabilities, default -1
* @param sparseLabel whether labels are integer array or probabilities, default true
* @param fromLogit whether predictions are log probabilities or un-normalized numbers, default
* false
* @param sparseLabel whether labels are a 1-D integer array of [batch_size] (true) or 2-D
* probabilities of [batch_size, n-class] (false), default true
* @param fromLogit if true, the predictions are assumed to be pre-softmax numbers (logits) and
* logSoftmax will be applied to them, default false
*/
public MaskedSoftmaxCrossEntropyLoss(
String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
api/src/main/java/ai/djl/training/loss/SoftmaxCrossEntropyLoss.java
@@ -43,6 +43,8 @@ public SoftmaxCrossEntropyLoss() {
* @param name the name of the loss
*/
public SoftmaxCrossEntropyLoss(String name) {
// By default fromLogit=true, meaning the loss takes predictions before softmax
// has been applied (raw logits).
this(name, 1, -1, true, true);
}

@@ -52,10 +54,10 @@ public SoftmaxCrossEntropyLoss(String name) {
* @param name the name of the loss
* @param weight the weight to apply on the loss value, default 1
* @param classAxis the axis that represents the class probabilities, default -1
* @param sparseLabel whether labels are 1-D integer array or 2-D probabilities of [batch_size,
* n-class], default true
* @param fromLogit whether predictions are un-normalized numbers or log probabilities, if true,
* logSoftmax will be applied to input, default true
* @param sparseLabel whether labels are a rank-1 integer array of [batch_size] (true) or rank-2
* one-hot of [batch_size, n-class] (false), default true
* @param fromLogit if true, the predictions are assumed to be pre-softmax numbers (logits) and
* logSoftmax will be applied to them, default true
*/
public SoftmaxCrossEntropyLoss(
String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
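A configuration sketch for the documented flags, which apply analogously to MaskedSoftmaxCrossEntropyLoss: one-hot labels together with raw-logit predictions.

import ai.djl.training.loss.Loss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;

// sparseLabel=false: labels are one-hot [batch_size, n-class]
// fromLogit=true: predictions are pre-softmax logits; logSoftmax is applied internally
Loss loss = new SoftmaxCrossEntropyLoss("SoftmaxCrossEntropy", 1.0f, -1, false, true);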
13 changes: 7 additions & 6 deletions api/src/main/java/ai/djl/training/optimizer/Adagrad.java
@@ -17,6 +17,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.tracker.ParameterTracker;
import ai.djl.training.tracker.Tracker;

import java.util.Map;
@@ -40,7 +41,7 @@
*/
public class Adagrad extends Optimizer {

private Tracker learningRateTracker;
private ParameterTracker learningRateTracker;
private float epsilon;

private Map<String, Map<Device, NDArray>> history;
@@ -61,7 +62,7 @@ protected Adagrad(Builder builder) {
@Override
public void update(String parameterId, NDArray weight, NDArray grad) {
int t = updateCount(parameterId);
float newLearningRate = learningRateTracker.getNewValue(t);
float newLearningRate = learningRateTracker.getNewValue(parameterId, t);
float weightDecay = getWeightDecay();

if (Float.isNaN(newLearningRate)
@@ -98,7 +99,7 @@ public static Builder builder() {
/** The Builder to construct an {@link Adagrad} object. */
public static final class Builder extends OptimizerBuilder<Builder> {

private Tracker learningRateTracker = Tracker.fixed(0.001f);
private ParameterTracker learningRateTracker = Tracker.fixed(0.001f);
private float epsilon = 1e-8f;

Builder() {}
@@ -110,12 +111,12 @@ protected Builder self() {
}

/**
* Sets the {@link Tracker} for this optimizer.
* Sets the {@link ParameterTracker} for this optimizer.
*
* @param learningRateTracker the {@link Tracker} to be set
* @param learningRateTracker the {@link ParameterTracker} to be set
* @return this {@code Builder}
*/
public Builder optLearningRateTracker(Tracker learningRateTracker) {
public Builder optLearningRateTracker(ParameterTracker learningRateTracker) {
this.learningRateTracker = learningRateTracker;
return this;
}
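Because a Tracker remains usable wherever a ParameterTracker is expected (Tracker.fixed(0.001f) still initializes the field above), existing builder code keeps compiling. A construction sketch, assuming Adagrad.Builder exposes a build() method like the other DJL optimizer builders:

import ai.djl.training.optimizer.Adagrad;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;

Optimizer adagrad =
        Adagrad.builder()
                .optLearningRateTracker(Tracker.fixed(0.01f)) // a Tracker is a ParameterTracker
                .build();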