From 78f5d2d0b4caf385c3e18e21f04c8abf52020e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=97=AD?= Date: Mon, 26 Feb 2024 10:15:09 +0800 Subject: [PATCH] Fixes cases where the getEngine method in the EngineProvider class returns null when called concurrently. --- .../main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java | 6 ++---- .../src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java | 6 ++---- .../src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java | 6 ++---- .../java/ai/djl/onnxruntime/engine/OrtEngineProvider.java | 6 ++---- .../java/ai/djl/paddlepaddle/engine/PpEngineProvider.java | 6 ++---- .../main/java/ai/djl/pytorch/engine/PtEngineProvider.java | 6 ++---- .../java/ai/djl/tensorflow/engine/TfEngineProvider.java | 6 ++---- .../main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java | 6 ++---- .../java/ai/djl/tflite/engine/TfLiteEngineProvider.java | 6 ++---- 9 files changed, 18 insertions(+), 36 deletions(-) diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index 583cd8132b2d..f1a3a5032cfb 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -19,7 +19,6 @@ public class LgbmEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (LgbmEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = LgbmEngine.newInstance(); } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 8b534d5196c3..0669ebb760b1 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -19,7 +19,6 @@ public class XgbEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (XgbEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = XgbEngine.newInstance(); } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 2a5ab970560d..eb1c99d83be3 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -19,7 +19,6 @@ public class MxEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (MxEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = MxEngine.newInstance(); } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 5616eb80edbe..137a40367402 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -19,7 +19,6 @@ public class OrtEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (OrtEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = OrtEngine.newInstance(); } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index e2fb86974f52..7ad12d5c8f94 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -19,7 +19,6 @@ public class PpEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (PpEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = PpEngine.newInstance(); } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 24be3e91d7ad..6f408fdac485 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -19,7 +19,6 @@ public class PtEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (PtEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = PtEngine.newInstance(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index fa7813a49fbd..fe84c9d3a196 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -19,7 +19,6 @@ public class TfEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (TfEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = TfEngine.newInstance(); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index 8c90859c6c66..389fb74aef6d 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -19,7 +19,6 @@ public class TrtEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (TrtEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = TrtEngine.newInstance(); } } diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index b46cad53b99f..98da96dac330 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -19,7 +19,6 @@ public class TfLiteEngineProvider implements EngineProvider { private volatile Engine engine; // NOPMD - private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -36,10 +35,9 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (!initialized) { + if (engine == null) { synchronized (TfLiteEngineProvider.class) { - if (!initialized) { - initialized = true; + if (engine == null) { engine = TfLiteEngine.newInstance(); } }