diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index cceb5efd494..b27cd0ac1d6 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -40,7 +40,7 @@ public abstract class NDArrayAdapter implements NDArray { protected NDManager manager; protected NDManager alternativeManager; - private NDArray alternativeArray; + protected NDArray alternativeArray; protected Shape shape; protected DataType dataType; diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java index 69eb6914b0e..990fad7128f 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java @@ -88,14 +88,33 @@ public ByteBuffer toByteBuffer(boolean tryDirect) { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { + if (!(replaced instanceof XgbNDArray)) { + throw new IllegalArgumentException( + "The replaced NDArray must be an instance of XgbNDArray."); + } + XgbNDArray array = (XgbNDArray) replaced; + if (isReleased()) { + throw new IllegalArgumentException("This array is already closed"); + } + if (replaced.isReleased()) { + throw new IllegalArgumentException("This target array is already closed"); + } + if (handle != null && handle.get() != 0L) { long pointer = handle.getAndSet(0L); JniUtils.deleteDMatrix(pointer); } - XgbNDArray array = (XgbNDArray) replaced; + if (alternativeArray != null) { + alternativeArray.close(); + } + data = array.data; handle = array.handle; format = array.format; + alternativeArray = array.alternativeArray; + array.handle = null; + array.alternativeArray = null; + array.close(); } /** {@inheritDoc} */