Commit: unit test
KexinFeng committed Oct 10, 2022
1 parent 80f0e49 commit ed0e7a6
Showing 13 changed files with 190 additions and 54 deletions.
5 changes: 5 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/transform/ToOneHot.java
@@ -20,6 +20,11 @@
public class ToOneHot implements Transform {
private final int numClass;

/**
* Creates a {@code ToOneHot} {@link Transform} that converts a sparse label to a one-hot label.
*
* @param numClass number of classes
*/
public ToOneHot(int numClass) {
this.numClass = numClass;
}
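For reference, a minimal usage sketch of the new transform (values are hypothetical; this assumes the transform delegates to NDImageUtils.toOneHot as its Javadoc describes):

import ai.djl.modality.cv.transform.ToOneHot;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class ToOneHotExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Sparse labels for a 3-class problem: one class index per sample.
            NDArray label = manager.create(new int[] {0, 2, 1});
            NDArray oneHot = new ToOneHot(3).transform(label);
            System.out.println(oneHot.getShape()); // (3, 3): one one-hot row per label
        }
    }
}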
37 changes: 37 additions & 0 deletions api/src/main/java/ai/djl/modality/cv/transform/Transpose.java
@@ -0,0 +1,37 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.transform;

import ai.djl.ndarray.NDArray;
import ai.djl.translate.Transform;

/** A {@link Transform} that transposes the image. */
public class Transpose implements Transform {

private int[] axis;

/**
* Creates a {@code Transpose} {@link Transform} that reorders the axes of an image {@link NDArray}.
*
* @param axis the new order of the axes
*/
public Transpose(int... axis) {
this.axis = axis;
}

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
return array.transpose(axis);
}
}
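A short sketch of the new Transpose transform reordering an HWC image to CHW (shapes are illustrative):

import ai.djl.modality.cv.transform.Transpose;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class TransposeExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray image = manager.ones(new Shape(224, 224, 3)); // HWC layout
            NDArray chw = new Transpose(2, 0, 1).transform(image); // move channels first
            System.out.println(chw.getShape()); // (3, 224, 224)
        }
    }
}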
9 changes: 5 additions & 4 deletions api/src/main/java/ai/djl/modality/cv/util/NDImageUtils.java
@@ -30,7 +30,7 @@ private NDImageUtils() {}
*
* @param image the image to resize
* @param size the desired size
- * @return the resized NDList
+ * @return the resized NDArray
*/
public static NDArray resize(NDArray image, int size) {
return resize(image, size, size, Image.Interpolation.BILINEAR);
@@ -42,7 +42,7 @@ public static NDArray resize(NDArray image, int size) {
* @param image the image to resize
* @param width the desired width
* @param height the desired height
- * @return the resized NDList
+ * @return the resized NDArray
*/
public static NDArray resize(NDArray image, int width, int height) {
return resize(image, width, height, Image.Interpolation.BILINEAR);
@@ -55,7 +55,7 @@ public static NDArray resize(NDArray image, int width, int height) {
* @param width the desired width
* @param height the desired height
* @param interpolation the desired interpolation
- * @return the resized NDList
+ * @return the resized NDArray
*/
public static NDArray resize(
NDArray image, int width, int height, Image.Interpolation interpolation) {
@@ -136,6 +136,7 @@ public static NDArray toTensor(NDArray image) {
* <p>Converts the labels {@link NDArray} to one-hot labels.
*
* @param label the label to convert
* @param numClass the number of classes
* @return the converted label
*/
public static NDArray toOneHot(NDArray label, int numClass) {
@@ -346,6 +347,6 @@ public static boolean isCHW(Shape shape) {
} else if (shape.get(2) == 1 || shape.get(2) == 3) {
return false;
}
throw new IllegalArgumentException("Image is not CHW or HWC");
throw new IllegalArgumentException("Image is neither CHW nor HWC");
}
}
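The resize overloads and the CHW/HWC check compose as in this sketch (a sketch under the assumption that the active engine implements image resize, e.g. PyTorch):

import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class NDImageUtilsExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray image = manager.ones(new Shape(256, 256, 3)); // HWC image
            System.out.println(NDImageUtils.isCHW(image.getShape())); // false
            NDArray resized = NDImageUtils.resize(image, 224); // square, bilinear by default
            System.out.println(resized.getShape()); // (224, 224, 3)
        }
    }
}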
4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/ndarray/internal/NDFormat.java
@@ -61,6 +61,10 @@ public static String format(
sb.append(" hasGradient");
}

if (DEBUG) {
sb.append(" ->");
}

NDFormat format;
DataType dataType = array.getDataType();

11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/nn/Parameter.java
@@ -113,6 +113,11 @@ public void setShape(Shape shape) {
this.shape = shape;
}

/**
* Gets the shape of this {@code Parameter}.
*
* @return the shape of this {@code Parameter}
*/
public Shape getShape() {
return this.shape;
}
@@ -173,6 +178,12 @@ public void setInitializer(Initializer initializer) {
this.initializer = initializer;
}

/**
* Gets the {@link Initializer} of this {@code Parameter}.
*
* @return the initializer of this {@code Parameter}
*/
public Initializer getInitializer() {
return initializer;
}
@@ -55,8 +55,8 @@ public SoftmaxCrossEntropyLoss(String name) {
* @param name the name of the loss
* @param weight the weight to apply on the loss value, default 1
* @param classAxis the axis that represents the class probabilities, default -1
- * @param sparseLabel whether labels are 1-D integer array of [batch_size] (false) or 2-D
-     probabilities of [batch_size, n-class] (true), default true
+ * @param sparseLabel whether labels are rank-1 integer array of [batch_size] (false) or rank-2
+     one-hot of [batch_size, n-class] (true), default true
* @param fromLogit if true, the inputs are assumed to be raw logits (values before softmax);
*     logSoftmax will then be applied to the input, default true
*/
@@ -73,6 +73,7 @@ public SoftmaxCrossEntropyLoss(
/**
* Creates a new instance of {@code SoftmaxCrossEntropyLoss} from the output of softmax.
*
* @param name the name of the loss function
* @param fromSoftmax the input prediction is from the output of softmax, default false
*/
public SoftmaxCrossEntropyLoss(String name, boolean fromSoftmax) {
@@ -97,8 +98,6 @@ public NDArray evaluate(NDList label, NDList prediction) {
NDArray loss;
NDArray lab = label.singletonOrThrow();
if (sparseLabel) {
- // TODO: should support only one label and the transformation of label, if needed, is
- // done outside. Keep this function unique-purposed.
NDIndex pickIndex =
new NDIndex()
.addAllDim(Math.floorMod(classAxis, pred.getShape().dimension()))
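To make the sparseLabel distinction concrete, a hedged sketch using the default configuration (sparseLabel=true and fromLogit=true, per the Javadoc above):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.loss.Loss;

public class SoftmaxCrossEntropyExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // Two samples, three classes; raw logits (fromLogit=true applies logSoftmax).
            NDArray logits = manager.create(new float[][] {{2f, 1f, 0f}, {0f, 1f, 2f}});
            // sparseLabel=true: rank-1 class indices of shape [batch_size].
            NDArray label = manager.create(new int[] {0, 2});
            Loss loss = Loss.softmaxCrossEntropyLoss();
            NDArray value = loss.evaluate(new NDList(label), new NDList(logits));
            System.out.println(value); // mean cross entropy over the batch
        }
    }
}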
@@ -42,6 +42,11 @@ public float getNewValue(int numUpdate) {
return valueMap.getOrDefault(this.parameterId, this.value);
}

/**
* Sets the parameter ID for this tracker.
*
* @param parameterId the parameter ID
*/
public void setParameterId(String parameterId) {
this.parameterId = parameterId;
}
@@ -59,20 +64,40 @@ public static Builder builder() {
public static final class Builder {

private float value;
private Map<String, Float> valueMap = new ConcurrentHashMap<String, Float>();
private Map<String, Float> valueMap = new ConcurrentHashMap<>();

/** Creates a builder for {@link FixedPerVarTracker}. */
private Builder() {}

/**
* Sets the default learning rate.
*
* @param value the default learning rate
* @return this builder
*/
public Builder setDefaultValue(float value) {
this.value = value;
return this;
}

/**
* Adds a parameter ID and the learning rate to use for it.
*
* @param parameterId the parameter ID
* @param value the learning rate for this parameter
* @return this builder
*/
public Builder put(String parameterId, float value) {
this.valueMap.put(parameterId, value);
return this;
}

/**
* Adds multiple parameter IDs and their learning rates.
*
* @param valueMap a map from parameter ID to learning rate
* @return this builder
*/
public Builder putAll(Map<String, Float> valueMap) {
this.valueMap.putAll(valueMap);
return this;
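A hedged sketch of the builder in use; build() and the package location are assumed from DJL's usual builder conventions, and the parameter IDs are hypothetical:

import ai.djl.training.tracker.FixedPerVarTracker;

public class FixedPerVarTrackerExample {
    public static void main(String[] args) {
        FixedPerVarTracker tracker =
                FixedPerVarTracker.builder()
                        .setDefaultValue(0.001f) // fallback learning rate
                        .put("conv1_weight", 0.0001f) // hypothetical parameter id
                        .build();
        // The optimizer is expected to set the parameter id before each update.
        tracker.setParameterId("conv1_weight");
        System.out.println(tracker.getNewValue(0)); // 1.0E-4, from the map
        tracker.setParameterId("fc_weight"); // hypothetical id not in the map
        System.out.println(tracker.getNewValue(0)); // 0.001, the default
    }
}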
@@ -45,7 +45,7 @@ JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_loadExtraDir(
int size = vec_arg.size();
argv.reserve(vec_arg.size());
for (auto& arg : vec_arg) {
- argv.push_back(const_cast<char*>(arg.data()));
+ argv.emplace_back(const_cast<char*>(arg.data()));
}
char** array = argv.data();
std::cout << "Pending Paddle fix to proceed with the option" << std::endl;
@@ -571,30 +571,22 @@ public PtNDArray resize(int width, int height, int interpolation) {
if (result.getDataType() != DataType.FLOAT32) {
result = result.toType(DataType.FLOAT32, true);
}

int dim = result.getShape().dimension();
if (dim == 3) {
result = result.expandDims(0);
}
- boolean orderHWC = false;
- if (result.getShape().get(1) != 3) {
-     // HWC order -> CHW order
-     orderHWC = true;
-     result = result.transpose(0, 3, 1, 2);
- }
+ // change from HWC to CHW order
+ result = result.transpose(0, 3, 1, 2);
result =
JniUtils.interpolate(
-         array.getManager().from(result),
-         new long[] {height, width},
-         getInterpolationMode(interpolation),
-         false);
- if (orderHWC) {
-     result = result.transpose(0, 2, 3, 1);
- }
+         array.getManager().from(result),
+         new long[] {height, width},
+         getInterpolationMode(interpolation),
+         false)
+     .transpose(0, 2, 3, 1);
if (dim == 3) {
result = result.squeeze(0);
}

array.attach(subManager.getParentManager());
result.attach(subManager.getParentManager());
return (PtNDArray) result;
@@ -262,7 +262,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleGetPar
auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
std::vector<jlong> jptrs;
for (const auto& tensor : module_ptr->parameters()) {
- jptrs.push_back(reinterpret_cast<uintptr_t>(new torch::Tensor(tensor)));
+ jptrs.emplace_back(reinterpret_cast<uintptr_t>(new torch::Tensor(tensor)));
}
size_t len = jptrs.size();
jlongArray jarray = env->NewLongArray(len);
@@ -277,7 +277,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_moduleGetP
auto* module_ptr = reinterpret_cast<torch::jit::script::Module*>(jhandle);
std::vector<std::string> jptrs;
for (const auto& named_tensor : module_ptr->named_parameters()) {
- jptrs.push_back(named_tensor.name);
+ jptrs.emplace_back(named_tensor.name);
}
return djl::utils::jni::GetStringArrayFromVec(env, jptrs);
API_END_RETURN()
1 change: 0 additions & 1 deletion examples/build.gradle
@@ -8,7 +8,6 @@ dependencies {
implementation project(":basicdataset")
implementation project(":model-zoo")
implementation project(":extensions:timeseries")
- implementation 'org.testng:testng:7.1.0'

runtimeOnly project(":engines:pytorch:pytorch-model-zoo")
runtimeOnly project(":engines:tensorflow:tensorflow-model-zoo")