Skip to content

Commit

Permalink
Revert "Creates DJL manual engine initialization (#2885)"
Browse files Browse the repository at this point in the history
This reverts commit 6141c48.
  • Loading branch information
frankfliu committed Feb 26, 2024
1 parent 4ae2dd1 commit c81de78
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 126 deletions.
57 changes: 13 additions & 44 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public abstract class Engine {

private static final Map<String, EngineProvider> ALL_ENGINES = new ConcurrentHashMap<>();

private static String defaultEngine = initEngine();
private static final String DEFAULT_ENGINE = initEngine();
private static final Pattern PATTERN =
Pattern.compile("KEY|TOKEN|PASSWORD", Pattern.CASE_INSENSITIVE);

Expand All @@ -69,10 +69,6 @@ public abstract class Engine {
private Integer seed;

private static synchronized String initEngine() {
if (Boolean.parseBoolean(Utils.getenv("DJL_ENGINE_MANUAL_INIT"))) {
return null;
}

ServiceLoader<EngineProvider> loaders = ServiceLoader.load(EngineProvider.class);
for (EngineProvider provider : loaders) {
registerEngine(provider);
Expand All @@ -84,21 +80,21 @@ private static synchronized String initEngine() {
}

String def = System.getProperty("ai.djl.default_engine");
String newDefaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def);
if (newDefaultEngine == null || newDefaultEngine.isEmpty()) {
String defaultEngine = Utils.getenv("DJL_DEFAULT_ENGINE", def);
if (defaultEngine == null || defaultEngine.isEmpty()) {
int rank = Integer.MAX_VALUE;
for (EngineProvider provider : ALL_ENGINES.values()) {
if (provider.getEngineRank() < rank) {
newDefaultEngine = provider.getEngineName();
defaultEngine = provider.getEngineName();
rank = provider.getEngineRank();
}
}
} else if (!ALL_ENGINES.containsKey(newDefaultEngine)) {
throw new EngineException("Unknown default engine: " + newDefaultEngine);
} else if (!ALL_ENGINES.containsKey(defaultEngine)) {
throw new EngineException("Unknown default engine: " + defaultEngine);
}
logger.debug("Found default engine: {}", newDefaultEngine);
Ec2Utils.callHome(newDefaultEngine);
return newDefaultEngine;
logger.debug("Found default engine: {}", defaultEngine);
Ec2Utils.callHome(defaultEngine);
return defaultEngine;
}

/**
Expand Down Expand Up @@ -128,7 +124,7 @@ private static synchronized String initEngine() {
* @return the default Engine name
*/
public static String getDefaultEngineName() {
return System.getProperty("ai.djl.default_engine", defaultEngine);
return System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE);
}

/**
Expand All @@ -138,7 +134,7 @@ public static String getDefaultEngineName() {
* @see EngineProvider
*/
public static Engine getInstance() {
if (defaultEngine == null) {
if (DEFAULT_ENGINE == null) {
throw new EngineException(
"No deep learning engine found."
+ System.lineSeparator()
Expand Down Expand Up @@ -167,29 +163,7 @@ public static boolean hasEngine(String engineName) {
*/
public static void registerEngine(EngineProvider provider) {
logger.debug("Registering EngineProvider: {}", provider.getEngineName());
ALL_ENGINES.put(provider.getEngineName(), provider);
}

/**
* Returns the default engine.
*
* @return the default engine
*/
public static String getDefaultEngine() {
return defaultEngine;
}

/**
* Sets the default engine returned by {@link #getInstance()}.
*
* @param engineName the new default engine's name
*/
public static void setDefaultEngine(String engineName) {
// Requires an engine to be loaded (without exception) before being the default
getEngine(engineName);

logger.debug("Setting new default engine: {}", engineName);
defaultEngine = engineName;
ALL_ENGINES.putIfAbsent(provider.getEngineName(), provider);
}

/**
Expand All @@ -213,12 +187,7 @@ public static Engine getEngine(String engineName) {
if (provider == null) {
throw new IllegalArgumentException("Deep learning engine not found: " + engineName);
}
Engine engine = provider.getEngine();
if (engine == null) {
throw new IllegalStateException(
"The engine " + engineName + " was not able to initialize");
}
return engine;
return provider.getEngine();
}

/**
Expand Down
5 changes: 0 additions & 5 deletions docs/development/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,6 @@ For more information, please refer to [DJL Cache Management](cache_management.md
It happened when you had a wrong version with DJL and Deep Engines.
You can check the combination [here](dependency_management.md) and use DJL BOM to solve the issue.

### 1.6 Manual initialization

If you are using manual engine initialization, you must both register an engine and set it as the default.
This can be done with `Engine.registerEngine(..)` and `Engine.setDefaultEngine(..)`.

## 2. IntelliJ throws the `No Log4j 2 configuration file found.` exception.
The following exception may appear after running the `./gradlew clean` command:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */
public class LgbmEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (LgbmEngineProvider.class) {
if (engine == null) {
engine = LgbmEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = LgbmEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */
public class XgbEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (XgbEngineProvider.class) {
if (engine == null) {
engine = XgbEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = XgbEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */
public class MxEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (MxEngineProvider.class) {
if (engine == null) {
engine = MxEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = MxEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */
public class OrtEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (OrtEngineProvider.class) {
if (engine == null) {
engine = OrtEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = OrtEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */
public class PpEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (PpEngineProvider.class) {
if (engine == null) {
engine = PpEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = PpEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
/** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */
public class PtEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD
private static volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
/** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */
public class TfEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD
private static volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */
public class TrtEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (TrtEngineProvider.class) {
if (engine == null) {
engine = TrtEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = TrtEngine.newInstance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void getVersion() {
try {
Engine engine = Engine.getEngine("TensorRT");
version = engine.getVersion();
} catch (Exception ignore) {
} catch (Throwable ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
Assert.assertEquals(version, "8.4.1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void testNDArray() {
Engine engine;
try {
engine = Engine.getEngine("TensorRT");
} catch (Exception ignore) {
} catch (Throwable ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!engine.defaultDevice().isGpu()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException
Engine engine;
try {
engine = Engine.getEngine("TensorRT");
} catch (Exception ignore) {
} catch (Throwable ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!engine.defaultDevice().isGpu()) {
Expand All @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException
Engine engine;
try {
engine = Engine.getEngine("TensorRT");
} catch (Exception ignore) {
} catch (Throwable ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
if (!engine.defaultDevice().isGpu()) {
Expand Down Expand Up @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate
Engine engine;
try {
engine = Engine.getEngine("TensorRT");
} catch (Exception ignore) {
} catch (Throwable ignore) {
throw new SkipException("Your os configuration doesn't support TensorRT.");
}
Device device = engine.defaultDevice();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
/** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */
public class TfLiteEngineProvider implements EngineProvider {

private volatile Engine engine; // NOPMD

/** {@inheritDoc} */
@Override
public String getEngineName() {
Expand All @@ -35,13 +33,10 @@ public int getEngineRank() {
/** {@inheritDoc} */
@Override
public Engine getEngine() {
if (engine == null) {
synchronized (TfLiteEngineProvider.class) {
if (engine == null) {
engine = TfLiteEngine.newInstance();
}
}
}
return engine;
return InstanceHolder.INSTANCE;
}

private static class InstanceHolder {
static final Engine INSTANCE = TfLiteEngine.newInstance();
}
}

0 comments on commit c81de78

Please sign in to comment.