Skip to content

Commit

Permalink
Fixes bug in NDArray.oneHot() API (#1661)
Browse files Browse the repository at this point in the history
* Fixes bug in NDArray.oneHot() API

The on and off value parameters were misplaced.

* Add unit test for onehot encoding

Change-Id: I2ad5ea81a10e5bf60810f6cfe3d72e3c7e34687a

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
pdradx and frankfliu authored May 18, 2022
1 parent e9a14c8 commit 8b85bc3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4683,7 +4683,7 @@ default NDArray oneHot(int depth) {
* href=https://d2l.djl.ai/chapter_linear-networks/softmax-regression.html#classification-problems>Classification-problems</a>
*/
default NDArray oneHot(int depth, DataType dataType) {
return oneHot(depth, 0f, 1f, dataType);
return oneHot(depth, 1f, 0f, dataType);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,7 @@ public void testOneHot() {
manager.create(
new float[][] {{0f, 1f, 0f}, {1f, 0f, 0f}, {0f, 0f, 1f}, {1f, 0f, 0f}});
Assert.assertEquals(array.oneHot(3), expected);
Assert.assertEquals(array.oneHot(3, DataType.FLOAT32), expected);
// test with all parameters
array = manager.create(new int[] {1, 0, 2, 0});
expected = manager.create(new int[][] {{1, 8, 1}, {8, 1, 1}, {1, 1, 8}, {8, 1, 1}});
Expand Down

0 comments on commit 8b85bc3

Please sign in to comment.