From f780c95ee5f9e0fe9e9502c14cde8a8d3986dc99 Mon Sep 17 00:00:00 2001 From: Ewan Date: Sat, 11 May 2024 15:23:11 +0800 Subject: [PATCH 1/2] [Lgbm] fix LgbmNDArray replaced.close() release data problem --- .../java/ai/djl/ml/lightgbm/LgbmNDArray.java | 33 ++++++++++++------- ...orch_jni_PyTorchLibrary_torch_pointwise.cc | 2 +- .../src/main/native/djl_pytorch_utils.h | 3 +- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java index 5203457b623..c79e0196655 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java @@ -39,8 +39,8 @@ public class LgbmNDArray extends NDArrayAdapter { private AtomicReference handle; private int typeConstant; - private SWIGTYPE_p_float floatData; - private SWIGTYPE_p_double doubleData; + private AtomicReference floatDataRef; + private AtomicReference doubleDataRef; LgbmNDArray( NDManager manager, @@ -53,6 +53,8 @@ public class LgbmNDArray extends NDArrayAdapter { this.format = SparseFormat.DENSE; manager.attachInternal(uid, this); handle = new AtomicReference<>(); + floatDataRef = new AtomicReference<>(); + doubleDataRef = new AtomicReference<>(); } /** {@inheritDoc} */ @@ -82,19 +84,19 @@ public SWIGTYPE_p_void getHandle() { if (getDataType() == DataType.FLOAT32) { typeConstant = lightgbmlibConstants.C_API_DTYPE_FLOAT32; FloatBuffer d1 = toByteBuffer().asFloatBuffer(); - floatData = lightgbmlib.new_floatArray(size); + floatDataRef.set(lightgbmlib.new_floatArray(size)); for (int i = 0; i < size; i++) { - lightgbmlib.floatArray_setitem(floatData, i, d1.get(i)); + lightgbmlib.floatArray_setitem(floatDataRef.get(), i, d1.get(i)); } - handle.set(lightgbmlib.float_to_voidp_ptr(floatData)); + handle.set(lightgbmlib.float_to_voidp_ptr(floatDataRef.get())); } else if (getDataType() == DataType.FLOAT64) { typeConstant = lightgbmlibConstants.C_API_DTYPE_FLOAT64; DoubleBuffer d1 = toByteBuffer().asDoubleBuffer(); - doubleData = lightgbmlib.new_doubleArray(size); + doubleDataRef.set(lightgbmlib.new_doubleArray(size)); for (int i = 0; i < size; i++) { - lightgbmlib.doubleArray_setitem(doubleData, i, d1.get(i)); + lightgbmlib.doubleArray_setitem(doubleDataRef.get(), i, d1.get(i)); } - handle.set(lightgbmlib.double_to_voidp_ptr(doubleData)); + handle.set(lightgbmlib.double_to_voidp_ptr(doubleDataRef.get())); } else { throw new IllegalArgumentException( "The LightGBM operation can only be performed with a Float32 or Float64" @@ -151,18 +153,21 @@ public ByteBuffer toByteBuffer() { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { + LgbmNDArray array = (LgbmNDArray) replaced; + + final SWIGTYPE_p_float floatData = + floatDataRef.getAndSet(array.floatDataRef.getAndSet(null)); if (floatData != null) { lightgbmlib.delete_floatArray(floatData); } + final SWIGTYPE_p_double doubleData = + doubleDataRef.getAndSet(array.doubleDataRef.getAndSet(null)); if (doubleData != null) { lightgbmlib.delete_doubleArray(doubleData); } - LgbmNDArray array = (LgbmNDArray) replaced; + handle.set(array.handle.getAndSet(null)); data = array.data; - handle = array.handle; format = array.format; - floatData = array.floatData; - doubleData = array.doubleData; typeConstant = array.typeConstant; shape = array.shape; dataType = array.dataType; @@ -180,11 +185,15 @@ public void detach() { @Override public void close() { super.close(); + final SWIGTYPE_p_float floatData = floatDataRef.getAndSet(null); if (floatData != null) { lightgbmlib.delete_floatArray(floatData); } + final SWIGTYPE_p_double doubleData = doubleDataRef.getAndSet(null); if (doubleData != null) { lightgbmlib.delete_doubleArray(doubleData); } + handle.set(null); + data = null; } } diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index ccf2616dc65..251966570f7 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -356,7 +356,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv* } JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2( -JNIEnv* env, jobject jthis, jlong jself, jlong jother) { + JNIEnv* env, jobject jthis, jlong jself, jlong jother) { API_BEGIN() const auto* self_ptr = reinterpret_cast(jself); const auto* other_ptr = reinterpret_cast(jother); diff --git a/engines/pytorch/pytorch-native/src/main/native/djl_pytorch_utils.h b/engines/pytorch/pytorch-native/src/main/native/djl_pytorch_utils.h index 9683681804e..0426ac85cb2 100644 --- a/engines/pytorch/pytorch-native/src/main/native/djl_pytorch_utils.h +++ b/engines/pytorch/pytorch-native/src/main/native/djl_pytorch_utils.h @@ -18,8 +18,9 @@ #include #include #include -#include + #include +#include #include "djl_pytorch_jni_log.h" From f00f8f18f2fca76b7c2ddf7d038ca0ce908161df Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 11 May 2024 12:52:36 -0700 Subject: [PATCH 2/2] Fix code code convention --- .../src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java index c79e0196655..310a0d059ad 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmNDArray.java @@ -155,13 +155,11 @@ public ByteBuffer toByteBuffer() { public void intern(NDArray replaced) { LgbmNDArray array = (LgbmNDArray) replaced; - final SWIGTYPE_p_float floatData = - floatDataRef.getAndSet(array.floatDataRef.getAndSet(null)); + SWIGTYPE_p_float floatData = floatDataRef.getAndSet(array.floatDataRef.getAndSet(null)); if (floatData != null) { lightgbmlib.delete_floatArray(floatData); } - final SWIGTYPE_p_double doubleData = - doubleDataRef.getAndSet(array.doubleDataRef.getAndSet(null)); + SWIGTYPE_p_double doubleData = doubleDataRef.getAndSet(array.doubleDataRef.getAndSet(null)); if (doubleData != null) { lightgbmlib.delete_doubleArray(doubleData); } @@ -185,11 +183,11 @@ public void detach() { @Override public void close() { super.close(); - final SWIGTYPE_p_float floatData = floatDataRef.getAndSet(null); + SWIGTYPE_p_float floatData = floatDataRef.getAndSet(null); if (floatData != null) { lightgbmlib.delete_floatArray(floatData); } - final SWIGTYPE_p_double doubleData = doubleDataRef.getAndSet(null); + SWIGTYPE_p_double doubleData = doubleDataRef.getAndSet(null); if (doubleData != null) { lightgbmlib.delete_doubleArray(doubleData); }