Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow indexer to attach specific manager #1688

Merged
merged 4 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ default void set(byte[] data) {
* index
*/
default void set(NDIndex index, NDArray value) {
getNDArrayInternal().getIndexer().set(this, index, value);
getNDArrayInternal().getIndexer(getManager()).set(this, index, value);
}

/**
Expand All @@ -466,7 +466,7 @@ default void set(NDIndex index, NDArray value) {
* @param value the value to replace with
*/
default void set(NDIndex index, Number value) {
getNDArrayInternal().getIndexer().set(this, index, value);
getNDArrayInternal().getIndexer(getManager()).set(this, index, value);
}

/**
Expand Down Expand Up @@ -498,7 +498,7 @@ default void set(NDArray index, Number value) {
* @throws IllegalArgumentException thrown if the index does not correspond to a single element
*/
default void setScalar(NDIndex index, Number value) {
getNDArrayInternal().getIndexer().setScalar(this, index, value);
getNDArrayInternal().getIndexer(getManager()).setScalar(this, index, value);
}

/**
Expand All @@ -508,7 +508,18 @@ default void setScalar(NDIndex index, Number value) {
* @return the partial {@code NDArray}
*/
default NDArray get(NDIndex index) {
return getNDArrayInternal().getIndexer().get(this, index);
return get(getManager(), index);
}

/**
* Returns a partial {@code NDArray}.
*
* @param manager the manager used to create the arrays
* @param index the section of this {@code NDArray} to return
* @return the partial {@code NDArray}
*/
default NDArray get(NDManager manager, NDIndex index) {
return getNDArrayInternal().getIndexer(manager).get(this, index);
}

/**
Expand Down Expand Up @@ -549,6 +560,18 @@ default NDArray get(long... indices) {
return get(new NDIndex(indices));
}

/**
* Returns a partial {@code NDArray}.
*
* @param manager the manager used to create the arrays
* @param indices the indices with each index corresponding to the dimensions and negative
* indices starting from the end
* @return the partial {@code NDArray}
*/
default NDArray get(NDManager manager, long... indices) {
return get(manager, new NDIndex(indices));
}

/**
* Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray
* idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i,
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public void attach(NDManager manager) {
/** {@inheritDoc} */
@Override
public void tempAttach(NDManager manager) {
detach();
NDManager original = this.manager;
detach();
this.manager = manager;
manager.tempAttachInternal(original, getUid(), this);
}
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,10 @@ default NDArray crop(int x, int y, int width, int height) {
/**
* Returns an {@link NDArrayIndexer}.
*
* @param manager the manager used to create the arrays
* @return an {@link NDArrayIndexer}
*/
NDArrayIndexer getIndexer();
NDArrayIndexer getIndexer(NDManager manager);

/**
* Returns elements chosen from the {@code NDArray} or the other {@code NDArray} depending on
Expand Down
6 changes: 2 additions & 4 deletions api/src/main/java/ai/djl/training/dataset/ArrayDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,13 @@ public Record get(NDManager manager, long index) {
NDList datum = new NDList();
NDList label = new NDList();
for (NDArray array : data) {
datum.add(array.get(index));
datum.add(array.get(manager, index));
}
if (labels != null) {
for (NDArray array : labels) {
label.add(array.get(index));
label.add(array.get(manager, index));
}
}
datum.attach(manager);
label.attach(manager);
return new Record(datum, label);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,8 @@ public NDArray randomColorJitter(

/** {@inheritDoc} */
@Override
public NDArrayIndexer getIndexer() {
return new MxNDArrayIndexer(array.getManager());
public NDArrayIndexer getIndexer(NDManager manager) {
return new MxNDArrayIndexer((MxNDManager) manager);
}

////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) {
params.addParam("axis", fullPick.getAxis());
params.addParam("keepdims", true);
params.add("mode", "wrap");
return array.getManager()
.invoke("pick", new NDList(array, fullPick.getIndices()), params)
return manager.invoke("pick", new NDList(array, fullPick.getIndices()), params)
.singletonOrThrow();
}

Expand All @@ -51,7 +50,7 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
params.addTupleParam("end", fullSlice.getMax());
params.addTupleParam("step", fullSlice.getStep());

NDArray result = ((MxNDManager) array.getManager()).invoke("_npi_slice", array, params);
NDArray result = manager.invoke("_npi_slice", array, params);
int[] toSqueeze = fullSlice.getToSqueeze();
if (toSqueeze.length > 0) {
NDArray oldResult = result;
Expand Down Expand Up @@ -83,12 +82,11 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
prepareValue.add(prepareValue.peek().reshape(targetShape));
prepareValue.add(prepareValue.peek().broadcast(fullSlice.getShape()));

array.getManager()
.invoke(
"_npi_slice_assign",
new NDArray[] {array, prepareValue.peek()},
new NDArray[] {array},
params);
manager.invoke(
"_npi_slice_assign",
new NDArray[] {array, prepareValue.peek()},
new NDArray[] {array},
params);
for (NDArray toClean : prepareValue) {
if (toClean != value) {
toClean.close();
Expand All @@ -105,11 +103,7 @@ public void set(NDArray array, NDIndexFullSlice fullSlice, Number value) {
params.addTupleParam("end", fullSlice.getMax());
params.addTupleParam("step", fullSlice.getStep());
params.addParam("scalar", value);
array.getManager()
.invoke(
"_npi_slice_assign_scalar",
new NDArray[] {array},
new NDArray[] {array},
params);
manager.invoke(
"_npi_slice_assign_scalar", new NDArray[] {array}, new NDArray[] {array}, params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ public void set(Buffer data) {

/** {@inheritDoc} */
@Override
public NDArray get(long... indices) {
return JniUtils.getItem(this, indices);
public NDArray get(NDManager manager, long... indices) {
return JniUtils.getItem(this, indices, (PtNDManager) manager);
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -274,8 +274,8 @@ public void attach(NDManager manager) {
/** {@inheritDoc} */
@Override
public void tempAttach(NDManager manager) {
detach();
NDManager original = this.manager;
detach();
this.manager = (PtNDManager) manager;
manager.tempAttachInternal(original, getUid(), this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ public NDArray randomColorJitter(

/** {@inheritDoc} */
@Override
public NDArrayIndexer getIndexer() {
return new PtNDArrayIndexer(array.getManager());
public NDArrayIndexer getIndexer(NDManager manager) {
return new PtNDArrayIndexer((PtNDManager) manager);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
long[] min = fullSlice.getMin();
long[] max = fullSlice.getMax();
long[] step = fullSlice.getStep();
try (PtNDArray res = JniUtils.index(manager.from(array), min, max, step)) {
try (PtNDArray res = JniUtils.index(manager.from(array), min, max, step, manager)) {
return res.squeeze(fullSlice.getToSqueeze());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,13 @@ public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop
}

public static PtNDArray index(
PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices) {
PtNDArray ndArray,
long[] minIndices,
long[] maxIndices,
long[] stepIndices,
PtNDManager manager) {
return new PtNDArray(
ndArray.getManager(),
manager,
PyTorchLibrary.LIB.torchIndex(
ndArray.getHandle(), minIndices, maxIndices, stepIndices));
}
Expand Down Expand Up @@ -413,18 +417,16 @@ public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray
ndArray.getHandle(), value.getHandle(), indicesNd.getHandle());
}

public static PtNDArray getItem(PtNDArray ndArray, long[] indices) {
public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager manager) {
// use a specialized API here
// due to significant performance gain
// for commonly used data loading call
if (indices.length == 1) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0]));
manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0]));
}
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices));
manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices));
}

public static PtNDArray clone(PtNDArray ndArray) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ public void attach(NDManager manager) {
/** {@inheritDoc} */
@Override
public void tempAttach(NDManager manager) {
detach();
NDManager original = this.manager;
detach();
this.manager = (TfNDManager) manager;
manager.tempAttachInternal(original, getUid(), this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
Expand Down Expand Up @@ -555,8 +556,8 @@ public NDArray randomColorJitter(

/** {@inheritDoc} */
@Override
public NDArrayIndexer getIndexer() {
return new TfNDArrayIndexer(array.getManager());
public NDArrayIndexer getIndexer(NDManager manager) {
return new TfNDArrayIndexer((TfNDManager) manager);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,12 @@ public NDArray get(NDArray array, NDIndexFullPick fullPick) {
@Override
public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
array = manager.from(array);
TfNDManager tfManager = (TfNDManager) array.getManager();
int[] toSqueeze = fullSlice.getToSqueeze();
try (NDArray begin = tfManager.create(fullSlice.getMin());
NDArray end = tfManager.create(fullSlice.getMax());
NDArray step = tfManager.create(fullSlice.getStep())) {
try (NDArray begin = manager.create(fullSlice.getMin());
NDArray end = manager.create(fullSlice.getMax());
NDArray step = manager.create(fullSlice.getStep())) {
NDArray result =
tfManager
.opExecutor("StridedSlice")
manager.opExecutor("StridedSlice")
.addInput(array)
.addInput(begin)
.addInput(end)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.integration.tests.ndarray;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import org.testng.Assert;
import org.testng.annotations.Test;

public class NDArrayAttachmentTest {

@Test
public void testReturnResource() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array3x4 = manager.ones(new Shape(3, 4));
try (NDManager subManager = NDManager.newBaseManager()) {
array3x4.tempAttach(subManager);
Assert.assertEquals(array3x4.getManager(), subManager);
}
Assert.assertEquals(array3x4.getManager(), manager);
}
}

@Test
public void testIndexationUsesSpecificManager() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array3x4 = manager.ones(new Shape(3, 4));
array3x4.setName("Test()");
NDArray array4 = array3x4.get(1);
Assert.assertEquals(array4.getManager(), manager);
try (NDManager subManager = NDManager.newBaseManager()) {
NDArray array4sub1 = array3x4.get(subManager, 1);
Assert.assertEquals(array4sub1.getManager(), subManager);
NDArray array4sub2 = array3x4.get(subManager, new NDIndex(1));
Assert.assertEquals(array4sub2.getManager(), subManager);
}
}
}
}