From 6141c480e7047b17b21e59eccfa7c58b0b916646 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 9 Jan 2024 12:10:00 -0800 Subject: [PATCH] Creates DJL manual engine initialization (#2885) * Creates DJL manual engine initialization fixes #2875 reverts #2876 This adds new support for DJL manual initialization of engines to support `DJL_ENGINE_MANUAL_INIT`. Once done, no engines providers will be found or loaded on startup. Instead, they can be added manually by: ```java PtEngineProvider provider = new PtEngineProvider(); provider.getEngine(); // Optional, throws exception if the provider can not load Engine.registerEngine(provider); Engine.setDefaultEngine(provider.getEngineName()); // Optional, sets as default ``` * Revert "[tensorflow] Revert InstanceHolder for TensorFlow engine (#2884)" This reverts commit 586bb0709dbde3950509f217c80a9c6b829a06fd. * Revert "[api] Replace double-check singlton with lazy initialization (#2826)" This reverts commit 39278672bebe9c5f9d590d73914a2a9e2f53fb91. * Make engines initialized This makes several updates: - engines will now initialize once per instance of EngineProvider rather than re-attempt to initialize - Registering an engine can overwrite the existing one - All engines now use the synchronized form rather than the static instance holder. This allows them to have multiple versions and be local to the instance rather than global (but the instance is saved globally) * Throws Exception on bad getEngine * Removes unnecessary check --- api/src/main/java/ai/djl/engine/Engine.java | 57 ++++++++++++++----- docs/development/troubleshooting.md | 5 ++ .../djl/ml/lightgbm/LgbmEngineProvider.java | 17 ++++-- .../ai/djl/ml/xgboost/XgbEngineProvider.java | 17 ++++-- .../ai/djl/mxnet/engine/MxEngineProvider.java | 17 ++++-- .../onnxruntime/engine/OrtEngineProvider.java | 17 ++++-- .../paddlepaddle/engine/PpEngineProvider.java | 17 ++++-- .../djl/pytorch/engine/PtEngineProvider.java | 8 ++- .../tensorflow/engine/TfEngineProvider.java | 8 ++- .../tensorrt/engine/TrtEngineProvider.java | 17 ++++-- .../ai/djl/tensorrt/engine/TrtEngineTest.java | 2 +- .../djl/tensorrt/engine/TrtNDManagerTest.java | 2 +- .../ai/djl/tensorrt/integration/TrtTest.java | 6 +- .../tflite/engine/TfLiteEngineProvider.java | 17 ++++-- 14 files changed, 148 insertions(+), 59 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8a1fc8871ac..a799c70f600 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -59,7 +59,7 @@ public abstract class Engine { private static final Map ALL_ENGINES = new ConcurrentHashMap<>(); - private static final String DEFAULT_ENGINE = initEngine(); + private static String defaultEngine = initEngine(); private static final Pattern PATTERN = Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE); @@ -69,6 +69,10 @@ public abstract class Engine { private Integer seed; private static synchronized String initEngine() { + if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) { + return null; + } + ServiceLoader loaders = ServiceLoader.load(EngineProvider.class); for (EngineProvider provider : loaders) { registerEngine(provider); @@ -80,21 +84,21 @@ private static synchronized String initEngine() { } String def = System.getProperty("ai.djl.default_engine"); - String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); - if (defaultEngine == null || defaultEngine.isEmpty()) { + String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def); + if (newDefaultEngine == null || newDefaultEngine.isEmpty()) { int rank = Integer.MAX_VALUE; for (EngineProvider provider : ALL_ENGINES.values()) { if (provider.getEngineRank() < rank) { - defaultEngine = provider.getEngineName(); + newDefaultEngine = provider.getEngineName(); rank = provider.getEngineRank(); } } - } else if (!ALL_ENGINES.containsKey(defaultEngine)) { - throw new EngineException("Unknown default engine: " + defaultEngine); + } else if (!ALL_ENGINES.containsKey(newDefaultEngine)) { + throw new EngineException("Unknown default engine: " + newDefaultEngine); } - logger.debug("Found default engine: {}", defaultEngine); - Ec2Utils.callHome(defaultEngine); - return defaultEngine; + logger.debug("Found default engine: {}", newDefaultEngine); + Ec2Utils.callHome(newDefaultEngine); + return newDefaultEngine; } /** @@ -124,7 +128,7 @@ private static synchronized String initEngine() { * @return the default Engine name */ public static String getDefaultEngineName() { - return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE); + return System.getProperty("ai.djl.default_engine", defaultEngine); } /** @@ -134,7 +138,7 @@ public static String getDefaultEngineName() { * @see EngineProvider */ public static Engine getInstance() { - if (DEFAULT_ENGINE == null) { + if (defaultEngine == null) { throw new EngineException( "No deep learning engine found." + System.lineSeparator() @@ -163,7 +167,29 @@ public static boolean hasEngine(String engineName) { */ public static void registerEngine(EngineProvider provider) { logger.debug("Registering EngineProvider: {}", provider.getEngineName()); - ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider); + ALL_ENGINES.put(provider.getEngineName(), provider); + } + + /** + * Returns the default engine. + * + * @return the default engine + */ + public static String getDefaultEngine() { + return defaultEngine; + } + + /** + * Sets the default engine returned by {@link #getInstance()}. + * + * @param engineName the new default engine's name + */ + public static void setDefaultEngine(String engineName) { + // Requires an engine to be loaded (without exception) before being the default + getEngine(engineName); + + logger.debug("Setting new default engine: {}", engineName); + defaultEngine = engineName; } /** @@ -187,7 +213,12 @@ public static Engine getEngine(String engineName) { if (provider == null) { throw new IllegalArgumentException("Deep learning engine not found: " + engineName); } - return provider.getEngine(); + Engine engine = provider.getEngine(); + if (engine == null) { + throw new IllegalStateException( + "The engine " + engineName + " was not able to initialize"); + } + return engine; } /** diff --git a/docs/development/troubleshooting.md b/docs/development/troubleshooting.md index ff03d32648e..1a04592dc12 100644 --- a/docs/development/troubleshooting.md +++ b/docs/development/troubleshooting.md @@ -105,6 +105,11 @@ For more information, please refer to [DJL Cache Management](cache_management.md It happened when you had a wrong version with DJL and Deep Engines. You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue. +### 1.6 Manual initialization + +If you are using manual engine initialization, you must both register an engine and set it as the default. +This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`. + ## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception. The following exception may appear after running the `./gradlew clean` command: diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index f8c84c753ef..583cd8132b2 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,6 +18,9 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = LgbmEngine.newInstance(); + if (!initialized) { + synchronized (LgbmEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = LgbmEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 5859f3f344d..8b534d5196c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,6 +18,9 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = XgbEngine.newInstance(); + if (!initialized) { + synchronized (XgbEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = XgbEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 5f45116f615..2a5ab970560 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,6 +18,9 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = MxEngine.newInstance(); + if (!initialized) { + synchronized (MxEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = MxEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 005c0fa25f1..5616eb80edb 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,6 +18,9 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = OrtEngine.newInstance(); + if (!initialized) { + synchronized (OrtEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = OrtEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index 59e5cd90724..e2fb86974f5 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,6 +18,9 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PpEngine.newInstance(); + if (!initialized) { + synchronized (PpEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = PpEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 42ca3c5b8a5..24be3e91d7a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,7 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (PtEngineProvider.class) { - if (engine == null) { + if (!initialized) { + initialized = true; engine = PtEngine.newInstance(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index ad440a47951..fa7813a49fb 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,7 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { - private static volatile Engine engine; // NOPMD + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD /** {@inheritDoc} */ @Override @@ -35,9 +36,10 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - if (engine == null) { + if (!initialized) { synchronized (TfEngineProvider.class) { - if (engine == null) { + if (!initialized) { + initialized = true; engine = TfEngine.newInstance(); } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index d92ed9e449d..8c90859c6c6 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,6 +18,9 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TrtEngine.newInstance(); + if (!initialized) { + synchronized (TrtEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = TrtEngine.newInstance(); + } + } + } + return engine; } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index efd9d89e509..96066b380e1 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,7 +26,7 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Assert.assertEquals(version, "8.4.1"); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 09001f0e2da..24d734af54c 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 99cbc6f763e..105e057ba0a 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index fb61551a3bf..b46cad53b99 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,6 +18,9 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { + private volatile Engine engine; // NOPMD + private volatile boolean initialized; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +36,14 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfLiteEngine.newInstance(); + if (!initialized) { + synchronized (TfLiteEngineProvider.class) { + if (!initialized) { + initialized = true; + engine = TfLiteEngine.newInstance(); + } + } + } + return engine; } }