Skip to content

Commit

Permalink
[pytorch] Add NDList to IValue unit test (#1762)
Browse files Browse the repository at this point in the history
The testing code can be used as example as well

Change-Id: I693624fd111afb0a5547d619a03d2cb15831af1b
  • Loading branch information
frankfliu authored Jul 1, 2022
1 parent 22cb398 commit 22a88b0
Showing 1 changed file with 59 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,27 @@
*/
package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.Arrays;

public class IValueUtilsTest {

@Test
public void getInputsTestTupleSyntax() {
try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) {
PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1));
array1.setName("Test()");
PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1));
array2.setName("Test()");
public void testTuple() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array1 = manager.zeros(new Shape(1));
array1.setName("input1()");
NDArray array2 = manager.ones(new Shape(1));
array2.setName("input1()");
NDList input = new NDList(array1, array2);
// the NDList is mapped to (input1: Tuple(Tensor))
input.attach(manager);

IValue[] iValues = IValueUtils.getInputs(input);
Expand All @@ -42,4 +43,54 @@ public void getInputsTestTupleSyntax() {
Arrays.stream(iValues).forEach(IValue::close);
}
}

@Test
public void testMapOfTensor() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array1 = manager.zeros(new Shape(1));
array1.setName("input1.key1");
NDArray array2 = manager.ones(new Shape(1));
array2.setName("input1.key2");
NDArray array3 = manager.zeros(new Shape(1));
array3.setName("input2.key1");
NDArray array4 = manager.ones(new Shape(1));
array4.setName("input2.key2");
NDArray array5 = manager.ones(new Shape(1));
array5.setName("input2.key3");
NDList input = new NDList(array1, array2, array3, array4, array5);
// the NDList is mapped to (input1: Dict(str, Tensor), input2: Dict(str, Tensor))
// the first part of NDArray name is the variable name of the inputs
// the 2nd part is the key of each value in the dict
input.attach(manager);

IValue[] iValues = IValueUtils.getInputs(input);
Assert.assertEquals(iValues.length, 2);
Assert.assertTrue(iValues[0].isMap());
Assert.assertEquals(iValues[1].toIValueMap().size(), 3);

Arrays.stream(iValues).forEach(IValue::close);
}
}

@Test
public void testListOfTensor() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array1 = manager.zeros(new Shape(1));
array1.setName("input1[]");
NDArray array2 = manager.ones(new Shape(1));
array2.setName("input1[]");
NDArray array3 = manager.ones(new Shape(1));
array3.setName("input2");
NDList input = new NDList(array1, array2, array3);
// the NDList is mapped to (input1: list(Tensor), input2: Tensor)
input.attach(manager);

IValue[] iValues = IValueUtils.getInputs(input);
Assert.assertEquals(iValues.length, 2);
Assert.assertTrue(iValues[0].isList());
Assert.assertEquals(iValues[0].toIValueArray().length, 2);

Arrays.stream(iValues).forEach(IValue::close);
}
}
}

0 comments on commit 22a88b0

Please sign in to comment.