From 32e5f611365ffb2faa12ca416d5b8dcf91058668 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 5 May 2024 11:51:09 -0700 Subject: [PATCH] [examples] Prepare for MXNet deprecation --- .../ai/djl/pytorch/engine/PtSymbolBlock.java | 2 +- examples/docs/train_transfer_fresh_fruit.md | 2 +- .../examples/inference/BertQaInference.java | 12 ++-- .../examples/inference/SentimentAnalysis.java | 1 + .../examples/inference/SpeechRecognition.java | 2 + .../examples/inference/clip/ClipModel.java | 2 + .../inference/cv/ImageClassification.java | 2 +- .../inference/cv/ObjectDetection.java | 12 +--- .../examples/inference/nlp/RollingBatch.java | 14 +--- .../inference/nlp/TextGeneration.java | 6 +- .../timeseries/M5ForecastingDeepAR.java | 9 +-- .../training/TrainAirfoilWithTabNet.java | 8 +-- .../examples/training/TrainBertOnCode.java | 9 ++- .../training/TrainBertOnGoemotions.java | 9 ++- .../djl/examples/training/TrainCaptcha.java | 5 +- .../ai/djl/examples/training/TrainMnist.java | 7 +- .../examples/training/TrainMnistWithLSTM.java | 7 +- .../djl/examples/training/TrainPikachu.java | 7 +- .../training/TrainPikachuWithYOLOV3.java | 5 +- .../training/TrainSentimentAnalysis.java | 20 ++---- .../djl/examples/training/TrainSeq2Seq.java | 7 +- .../djl/examples/training/TrainTicTacToe.java | 11 +-- .../examples/training/TrainTimeSeries.java | 6 +- .../djl/examples/training/TrainWithHpo.java | 5 +- .../training/TrainWithOptimizers.java | 24 ++----- .../TrainAmazonReviewRanking.java | 69 ++++++------------- .../TrainResnetWithCifar10.java | 19 ++--- .../transferlearning/TransferFreshFruit.java | 4 +- .../djl/examples/training/util/Arguments.java | 25 +++++-- .../inference/ActionRecognitionTest.java | 2 +- .../ai/djl/examples/inference/BertQaTest.java | 2 +- .../ai/djl/examples/inference/BigGANTest.java | 2 +- .../inference/FeatureComparisonTest.java | 3 +- .../inference/FeatureExtractionTest.java | 2 +- .../inference/InstanceSegmentationTest.java | 3 +- 
.../inference/LightFaceDetectionTest.java | 3 +- .../examples/inference/MaskDetectionTest.java | 2 +- .../inference/ObjectDetectionTest.java | 2 +- ...DetectionWithTensorflowSavedModelTest.java | 2 +- .../inference/PoseEstimationTest.java | 2 +- .../inference/RetinaFaceDetectionTest.java | 2 +- .../inference/SentimentAnalysisTest.java | 2 +- .../inference/SpeechRecognitionTest.java | 9 +-- .../examples/inference/StyleTransferTest.java | 11 +-- .../inference/SuperResolutionTest.java | 2 +- .../examples/inference/TimeSeriesTest.java | 2 +- .../UniversalSentenceEncoderTest.java | 2 +- .../inference/Yolov8DetectionTest.java | 2 +- .../inference/clip/ClipModelTest.java | 4 +- .../inference/nlp/TextGenerationTest.java | 4 +- .../training/TrainAirfoilWithTabNetTest.java | 3 +- .../training/TrainAmazonReviewTest.java | 6 +- .../training/TrainBertOnGoemotionsTest.java | 2 +- .../djl/examples/training/TrainBertTest.java | 2 +- .../examples/training/TrainCaptchaTest.java | 5 +- .../djl/examples/training/TrainMnistTest.java | 2 +- .../training/TrainMnistWithLSTMTest.java | 5 +- .../examples/training/TrainPikachuTest.java | 8 ++- .../examples/training/TrainResNetTest.java | 18 ++--- .../training/TrainSentimentAnalysisTest.java | 6 +- .../examples/training/TrainSeq2SeqTest.java | 5 +- .../examples/training/TrainTicTacToeTest.java | 7 +- .../training/TrainTimeSeriesTest.java | 5 +- .../training/TransferFreshFruitTest.java | 9 ++- .../java/ai/djl/testing/TestRequirements.java | 33 ++------- 65 files changed, 214 insertions(+), 278 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 7075cb05efa..cb4de34adac 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -231,7 +231,7 @@ public Shape[] 
getOutputShapes(Shape[] inputShapes) { /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes, DataType[] dataTypes) { - try (NDManager manager = NDManager.newBaseManager()) { + try (NDManager manager = NDManager.newBaseManager("PyTorch")) { NDList list = new NDList(); for (int i = 0; i < inputShapes.length; i++) { list.add( diff --git a/examples/docs/train_transfer_fresh_fruit.md b/examples/docs/train_transfer_fresh_fruit.md index a4802c67733..2f1feb0cf6f 100644 --- a/examples/docs/train_transfer_fresh_fruit.md +++ b/examples/docs/train_transfer_fresh_fruit.md @@ -132,7 +132,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) { DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy")) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(1)) + .optDevices(Engine.getEngine("PyTorch").getDevices(1)) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); ... 
diff --git a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java index 093e159bebb..9538b7c00bb 100644 --- a/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java +++ b/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java @@ -14,8 +14,8 @@ package ai.djl.examples.inference; import ai.djl.Application; +import ai.djl.Device; import ai.djl.ModelException; -import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.nlp.qa.QAInput; import ai.djl.repository.zoo.Criteria; @@ -68,14 +68,14 @@ public static String predict() throws IOException, TranslateException, ModelExce .optApplication(Application.NLP.QUESTION_ANSWER) .setTypes(QAInput.class, String.class) .optFilter("backbone", "bert") - .optEngine(Engine.getDefaultEngineName()) + .optEngine("PyTorch") + .optDevice(Device.cpu()) .optProgress(new ProgressBar()) .build(); - try (ZooModel model = criteria.loadModel()) { - try (Predictor predictor = model.newPredictor()) { - return predictor.predict(input); - } + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + return predictor.predict(input); } } } diff --git a/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java index c2ee112d52e..40045f21a95 100644 --- a/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java @@ -60,6 +60,7 @@ public static Classifications predict() Criteria.builder() .optApplication(Application.NLP.SENTIMENT_ANALYSIS) .setTypes(String.class, Classifications.class) + .optEngine("PyTorch") // This model was traced on CPU and can only run on CPU .optDevice(Device.cpu()) .optProgress(new ProgressBar()) diff --git 
a/examples/src/main/java/ai/djl/examples/inference/SpeechRecognition.java b/examples/src/main/java/ai/djl/examples/inference/SpeechRecognition.java index 5d9a74fadde..1da995a3134 100644 --- a/examples/src/main/java/ai/djl/examples/inference/SpeechRecognition.java +++ b/examples/src/main/java/ai/djl/examples/inference/SpeechRecognition.java @@ -13,6 +13,7 @@ package ai.djl.examples.inference; +import ai.djl.Device; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.audio.Audio; @@ -56,6 +57,7 @@ public static String predict() throws IOException, ModelException, TranslateExce Criteria.builder() .setTypes(Audio.class, String.class) .optModelUrls(url) + .optDevice(Device.cpu()) // torchscript model only support CPU .optTranslatorFactory(new SpeechRecognitionTranslatorFactory()) .optModelName("wav2vec2.ptl") .optEngine("PyTorch") diff --git a/examples/src/main/java/ai/djl/examples/inference/clip/ClipModel.java b/examples/src/main/java/ai/djl/examples/inference/clip/ClipModel.java index 3b7c87ef0de..0b93f224a43 100644 --- a/examples/src/main/java/ai/djl/examples/inference/clip/ClipModel.java +++ b/examples/src/main/java/ai/djl/examples/inference/clip/ClipModel.java @@ -12,6 +12,7 @@ */ package ai.djl.examples.inference.clip; +import ai.djl.Device; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; @@ -45,6 +46,7 @@ public ClipModel() throws ModelException, IOException { .optModelUrls("https://resources.djl.ai/demo/pytorch/clip.zip") .optTranslator(new NoopTranslator()) .optEngine("PyTorch") + .optDevice(Device.cpu()) // torchscript model only support CPU .build(); clip = criteria.loadModel(); imageFeatureExtractor = clip.newPredictor(new ImageTranslator()); diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/ImageClassification.java b/examples/src/main/java/ai/djl/examples/inference/cv/ImageClassification.java index d0ba39e4354..cf9e7b18477 100644 --- 
a/examples/src/main/java/ai/djl/examples/inference/cv/ImageClassification.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/ImageClassification.java @@ -57,7 +57,7 @@ public static Classifications predict() throws IOException, ModelException, Tran Image img = ImageFactory.getInstance().fromFile(imageFile); String modelName = "mlp"; - try (Model model = Model.newInstance(modelName)) { + try (Model model = Model.newInstance(modelName, "PyTorch")) { model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64})); // Assume you have run TrainMnist.java example, and saved model in build/model folder. diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetection.java b/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetection.java index f1844d10f32..287ddb248d6 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetection.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetection.java @@ -14,7 +14,6 @@ import ai.djl.Application; import ai.djl.ModelException; -import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; @@ -54,19 +53,12 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg"); Image img = ImageFactory.getInstance().fromFile(imageFile); - String backbone; - if ("TensorFlow".equals(Engine.getDefaultEngineName())) { - backbone = "mobilenet_v2"; - } else { - backbone = "resnet50"; - } - Criteria criteria = Criteria.builder() .optApplication(Application.CV.OBJECT_DETECTION) .setTypes(Image.class, DetectedObjects.class) - .optFilter("backbone", backbone) - .optEngine(Engine.getDefaultEngineName()) + .optFilter("backbone", "mobilenet_v2") + .optEngine("TensorFlow") .optProgress(new ProgressBar()) .build(); diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java 
b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java index 6751f5448d5..7e937344313 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java @@ -12,7 +12,7 @@ */ package ai.djl.examples.inference.nlp; -import ai.djl.MalformedModelException; +import ai.djl.ModelException; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.djl.inference.Predictor; import ai.djl.modality.nlp.generate.CausalLMOutput; @@ -22,7 +22,6 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.DeferredTranslatorFactory; import ai.djl.translate.TranslateException; @@ -39,20 +38,13 @@ public final class RollingBatch { private RollingBatch() {} - public static void main(String[] args) - throws ModelNotFoundException, - MalformedModelException, - IOException, - TranslateException { + public static void main(String[] args) throws ModelException, IOException, TranslateException { String[] ret = seqBatchSchedulerWithPyTorchContrastive(); logger.info("{}", ret[0]); } public static String[] seqBatchSchedulerWithPyTorchContrastive() - throws ModelNotFoundException, - MalformedModelException, - IOException, - TranslateException { + throws ModelException, IOException, TranslateException { String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; Criteria criteria = diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java index 1de43457610..cac5866a1e4 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java @@ -13,6 +13,7 @@ package ai.djl.examples.inference.nlp; import 
ai.djl.MalformedModelException; +import ai.djl.ModelException; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.djl.inference.Predictor; @@ -161,10 +162,7 @@ public static String[] generateTextWithPyTorchBeam() } public static String[] generateTextWithOnnxRuntimeBeam() - throws ModelNotFoundException, - MalformedModelException, - IOException, - TranslateException { + throws ModelException, IOException, TranslateException { SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); long padTokenId = 220; diff --git a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java index ecb58b7e76d..8d20ff3c981 100644 --- a/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java +++ b/examples/src/main/java/ai/djl/examples/inference/timeseries/M5ForecastingDeepAR.java @@ -18,7 +18,6 @@ import ai.djl.basicdataset.BasicDatasets; import ai.djl.basicdataset.tabular.utils.DynamicBuffer; import ai.djl.basicdataset.tabular.utils.Feature; -import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDArrays; @@ -63,7 +62,6 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -82,9 +80,7 @@ public static void main(String[] args) throws IOException, TranslateException, M public static Map predict() throws IOException, TranslateException, ModelException { - Engine engine = Engine.getInstance(); - NDManager manager = engine.newBaseManager(); - String engineName = engine.getEngineName().toLowerCase(Locale.ROOT); + NDManager manager = NDManager.newBaseManager("MXNet"); // To use local dataset, users can load data as follows // Repository repository = Repository.newInstance("local_dataset", @@ 
-102,12 +98,13 @@ public static Map predict() // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008#file-m5torch-py // Here you can also use local file: modelUrl = "LOCAL_PATH/deepar.pt"; - String modelUrl = "djl://ai.djl." + engineName + "/deepar/0.0.1/m5forecast"; + String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast"; int predictionLength = 4; Criteria criteria = Criteria.builder() .setTypes(TimeSeriesData.class, Forecast.class) .optModelUrls(modelUrl) + .optEngine("MXNet") .optTranslatorFactory(new DeepARTranslatorFactory()) .optArgument("prediction_length", predictionLength) .optArgument("freq", "W") diff --git a/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java b/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java index 06cf66b2a6f..d9ecf6b7f81 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java @@ -18,7 +18,6 @@ import ai.djl.basicdataset.tabular.TabularDataset; import ai.djl.basicdataset.tabular.TabularResults; import ai.djl.basicmodelzoo.tabular.TabNet; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.inference.Predictor; import ai.djl.metric.Metrics; @@ -55,7 +54,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans // Construct a tabNet instance Block tabNet = TabNet.builder().setInputDim(5).setOutDim(1).build(); - try (Model model = Model.newInstance("tabNet")) { + try (Model model = Model.newInstance("tabNet", arguments.getEngine())) { model.setBlock(tabNet); // get the training and validation dataset @@ -103,13 +102,12 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { }); return new DefaultTrainingConfig(new TabNetRegressionLoss()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) 
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } - private static TabularDataset getDataset(Arguments arguments) - throws IOException, TranslateException { + private static TabularDataset getDataset(Arguments arguments) throws IOException { AirfoilRandomAccess.Builder airfoilBuilder = AirfoilRandomAccess.builder(); // only train dataset is available, so we get train dataset and split them diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index aa2b12af420..bac7e187334 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -13,7 +13,6 @@ package ai.djl.examples.training; import ai.djl.Model; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.examples.training.util.BertCodeDataset; import ai.djl.ndarray.types.Shape; @@ -59,7 +58,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans dataset.prepare(); // Create model & trainer - try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) { + try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) { TrainingConfig config = createTrainingConfig(arguments); try (Trainer trainer = model.newTrainer(config)) { @@ -74,7 +73,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans } } - private static Model createBertPretrainingModel(long vocabularySize) { + private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) { Block block = new BertPretrainingBlock( BertBlock.builder() @@ -82,7 +81,7 @@ private static Model createBertPretrainingModel(long vocabularySize) { .setTokenDictionarySize(Math.toIntExact(vocabularySize))); block.setInitializer(new TruncatedNormalInitializer(0.02f), 
Parameter.Type.WEIGHT); - Model model = Model.newInstance("Bert Pretraining"); + Model model = Model.newInstance("Bert Pretraining", arguments.getEngine()); model.setBlock(block); return model; } @@ -108,7 +107,7 @@ private static TrainingConfig createTrainingConfig(BertArguments arguments) { .build(); return new DefaultTrainingConfig(new BertPretrainingLoss()) .optOptimizer(optimizer) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(Defaults.logging()); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnGoemotions.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnGoemotions.java index 3753ebe8c51..65a51408708 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnGoemotions.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnGoemotions.java @@ -14,7 +14,6 @@ import ai.djl.Model; import ai.djl.basicdataset.nlp.GoEmotions; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.examples.training.util.BertGoemotionsDataset; import ai.djl.modality.nlp.embedding.EmbeddingException; @@ -67,7 +66,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans dataset.prepare(); // Create model & trainer - try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) { + try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) { TrainingConfig config = createTrainingConfig(arguments); try (Trainer trainer = model.newTrainer(config)) { // Initialize training @@ -105,11 +104,11 @@ private static TrainingConfig createTrainingConfig( .build(); return new DefaultTrainingConfig(new BertPretrainingLoss()) .optOptimizer(optimizer) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging()); } - private static Model 
createBertPretrainingModel(long vocabularySize) { + private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) { Block block = new BertPretrainingBlock( BertBlock.builder() @@ -117,7 +116,7 @@ private static Model createBertPretrainingModel(long vocabularySize) { .setTokenDictionarySize(Math.toIntExact(vocabularySize))); block.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT); - Model model = Model.newInstance("Bert Pretraining"); + Model model = Model.newInstance("Bert Pretraining", arguments.getEngine()); model.setBlock(block); return model; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java b/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java index 570dd7e6b65..2d2f1f7c2d1 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java @@ -15,7 +15,6 @@ import ai.djl.Model; import ai.djl.basicdataset.cv.classification.CaptchaDataset; import ai.djl.basicmodelzoo.cv.classification.ResNetV1; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; import ai.djl.ndarray.NDArray; @@ -63,7 +62,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans return null; } - try (Model model = Model.newInstance("captcha")) { + try (Model model = Model.newInstance("captcha", arguments.getEngine())) { model.setBlock(getBlock()); // get training and validation dataset @@ -107,7 +106,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { DefaultTrainingConfig config = new DefaultTrainingConfig(loss) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addEvaluators(loss.getComponents()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git 
a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java index 786a71bfbed..9c98b7a6dad 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java @@ -15,9 +15,9 @@ import ai.djl.Model; import ai.djl.basicdataset.cv.classification.Mnist; import ai.djl.basicmodelzoo.basic.Mlp; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.training.DefaultTrainingConfig; @@ -63,7 +63,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans Mnist.NUM_CLASSES, new int[] {128, 64}); - try (Model model = Model.newInstance("mlp")) { + try (Model model = Model.newInstance("mlp", arguments.getEngine())) { model.setBlock(block); // get training and validation dataset @@ -105,7 +105,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { }); return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } @@ -115,6 +115,7 @@ private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arg Mnist mnist = Mnist.builder() .optUsage(usage) + .optManager(NDManager.newBaseManager(arguments.getEngine())) .setSampling(arguments.getBatchSize(), true) .optLimit(arguments.getLimit()) .build(); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index fa6e7641ba4..65f961c177d 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ 
b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -14,9 +14,9 @@ import ai.djl.Model; import ai.djl.basicdataset.cv.classification.Mnist; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.nn.Blocks; @@ -53,7 +53,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans return null; } - try (Model model = Model.newInstance("lstm")) { + try (Model model = Model.newInstance("lstm", arguments.getEngine())) { model.setBlock(getLSTMModel()); // get training and validation dataset @@ -119,7 +119,7 @@ public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } @@ -129,6 +129,7 @@ public static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments argu Mnist mnist = Mnist.builder() .optUsage(usage) + .optManager(NDManager.newBaseManager(arguments.getEngine())) .setSampling(arguments.getBatchSize(), false, true) .optLimit(arguments.getLimit()) .build(); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java b/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java index 534f58d8562..662742860cd 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java @@ -16,7 +16,6 @@ import ai.djl.Model; import ai.djl.basicdataset.cv.PikachuDetection; import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import 
ai.djl.inference.Predictor; import ai.djl.metric.Metrics; @@ -77,7 +76,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans return null; } - try (Model model = Model.newInstance("pikachu-ssd")) { + try (Model model = Model.newInstance("pikachu-ssd", arguments.getEngine())) { model.setBlock(getSsdTrainBlock()); RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments); RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments); @@ -99,7 +98,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans public static int predict(String outputDir, String imageFile) throws IOException, MalformedModelException, TranslateException { - try (Model model = Model.newInstance("pikachu-ssd")) { + try (Model model = Model.newInstance("pikachu-ssd", "PyTorch")) { float detectionThreshold = 0.6f; // load parameters back to original training block model.setBlock(getSsdTrainBlock()); @@ -156,7 +155,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(new SingleShotDetectionLoss()) .addEvaluator(new SingleShotDetectionAccuracy("classAccuracy")) .addEvaluator(new BoundingBoxError("boundingBoxError")) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainPikachuWithYOLOV3.java b/examples/src/main/java/ai/djl/examples/training/TrainPikachuWithYOLOV3.java index adc72e78e88..e8001c67a09 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainPikachuWithYOLOV3.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainPikachuWithYOLOV3.java @@ -16,7 +16,6 @@ import ai.djl.Model; import ai.djl.basicdataset.cv.PikachuDetection; import ai.djl.basicmodelzoo.cv.object_detection.yolo.YOLOV3; -import 
ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; import ai.djl.modality.cv.transform.ToTensor; @@ -50,7 +49,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans return null; } - try (Model model = Model.newInstance("pikachu-yolov3")) { + try (Model model = Model.newInstance("pikachu-yolov3", arguments.getEngine())) { model.setBlock(YOLOV3.builder().setNumClasses(1).build()); RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments); RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST, arguments); @@ -106,7 +105,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { .setInputShape(new Shape(256, 256)) .setAnchorsArray(anchorsArray) .build()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.basic()) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java index d89129e7bf5..ec313b53565 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java @@ -13,12 +13,11 @@ package ai.djl.examples.training; import ai.djl.Application; -import ai.djl.MalformedModelException; import ai.djl.Model; +import ai.djl.ModelException; import ai.djl.basicdataset.nlp.StanfordMovieReview; import ai.djl.basicdataset.utils.FixedBucketSampler; import ai.djl.basicdataset.utils.TextData; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.inference.Predictor; import ai.djl.metric.Metrics; @@ -39,7 +38,6 @@ import ai.djl.nn.core.Linear; import ai.djl.nn.recurrent.LSTM; import ai.djl.repository.zoo.Criteria; -import 
ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -75,19 +73,12 @@ public final class TrainSentimentAnalysis { private TrainSentimentAnalysis() {} - public static void main(String[] args) - throws IOException, - ModelNotFoundException, - MalformedModelException, - TranslateException { + public static void main(String[] args) throws IOException, ModelException, TranslateException { TrainSentimentAnalysis.runExample(args); } public static TrainingResult runExample(String[] args) - throws IOException, - ModelNotFoundException, - MalformedModelException, - TranslateException { + throws IOException, ModelException, TranslateException { Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; @@ -99,10 +90,11 @@ public static TrainingResult runExample(String[] args) .optApplication(Application.NLP.WORD_EMBEDDING) .setTypes(String.class, NDList.class) .optArtifactId("glove") + .optEngine(arguments.getEngine()) .optFilter("dimensions", "50") .build(); - try (Model model = Model.newInstance("stanfordSentimentAnalysis"); + try (Model model = Model.newInstance("sentimentAnalysis", arguments.getEngine()); ZooModel embedding = criteria.loadModel()) { ModelZooTextEmbedding modelZooTextEmbedding = new ModelZooTextEmbedding(embedding); // get training and validation dataset @@ -186,7 +178,7 @@ public static DefaultTrainingConfig setupTrainingConfig( return new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .optExecutorService(executorService) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java 
index bf77b321a76..e8367a39919 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java @@ -18,7 +18,6 @@ import ai.djl.basicdataset.utils.TextData.Configuration; import ai.djl.basicmodelzoo.nlp.SimpleTextDecoder; import ai.djl.basicmodelzoo.nlp.SimpleTextEncoder; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; import ai.djl.modality.nlp.EncoderDecoder; @@ -29,6 +28,7 @@ import ai.djl.modality.nlp.preprocess.SimpleTokenizer; import ai.djl.modality.nlp.preprocess.TextTerminator; import ai.djl.modality.nlp.preprocess.TextTruncator; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.nn.recurrent.LSTM; @@ -66,7 +66,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans } ExecutorService executorService = Executors.newFixedThreadPool(8); - try (Model model = Model.newInstance("seq2seqMTEn-Fr")) { + try (Model model = Model.newInstance("seq2seqMTEn-Fr", arguments.getEngine())) { // get training and validation dataset TextDataset trainingSet = getDataset(Dataset.Usage.TRAIN, arguments, null, null); // Fetch TextEmbedding from dataset @@ -151,7 +151,7 @@ public static DefaultTrainingConfig setupTrainingConfig( return new DefaultTrainingConfig(new MaskedSoftmaxCrossEntropyLoss()) .addEvaluator(new Accuracy("Accuracy", 2)) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .optExecutorService(executorService) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); @@ -179,6 +179,7 @@ public static TextDataset getDataset( .addPad(0, 0, (m) -> m.ones(new Shape(1)), 10) .build()) .optUsage(usage) + .optManager(NDManager.newBaseManager(arguments.getEngine())) .optPrefetchNumber(8) .optLimit(limit); Configuration sourceConfig = diff --git 
a/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java b/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java index 7fdb94d9493..a2e3125a9c6 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java @@ -64,6 +64,7 @@ public static TrainingResult runExample(String[] args) throws IOException { return null; } + String engine = arguments.getEngine(); int epoch = arguments.getEpoch(); int batchSize = arguments.getBatchSize(); int replayBufferSize = 1024; @@ -76,14 +77,15 @@ public static TrainingResult runExample(String[] args) throws IOException { gamesPerEpoch = Math.toIntExact(arguments.getLimit()); } - TicTacToe game = new TicTacToe(NDManager.newBaseManager(), batchSize, replayBufferSize); + TicTacToe game = + new TicTacToe(NDManager.newBaseManager(engine), batchSize, replayBufferSize); Block block = getBlock(); - try (Model model = Model.newInstance("tic-tac-toe")) { + try (Model model = Model.newInstance("tic-tac-toe", engine)) { model.setBlock(block); - DefaultTrainingConfig config = setupTrainingConfig(); + DefaultTrainingConfig config = setupTrainingConfig(arguments); try (Trainer trainer = model.newTrainer(config)) { trainer.initialize( new Shape(batchSize, 9), new Shape(batchSize), new Shape(batchSize)); @@ -162,9 +164,10 @@ public static Block getBlock() { .add(new Mlp(11, 1, new int[] {20, 10})); } - public static DefaultTrainingConfig setupTrainingConfig() { + public static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.l2Loss()) .addTrainingListeners(TrainingListener.Defaults.basic()) + .optDevices(arguments.getMaxGpus()) .optOptimizer( Adam.builder().optLearningRateTracker(Tracker.fixed(0.0001F)).build()); } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java index 
e0143ed524b..b9abb016e5f 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTimeSeries.java @@ -17,7 +17,6 @@ import ai.djl.ModelException; import ai.djl.basicdataset.BasicDatasets; import ai.djl.basicdataset.tabular.utils.Feature; -import ai.djl.engine.Engine; import ai.djl.examples.inference.timeseries.M5ForecastingDeepAR; import ai.djl.examples.training.util.Arguments; import ai.djl.inference.Predictor; @@ -83,9 +82,8 @@ public static void main(String[] args) throws IOException, TranslateException, M } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = new Arguments().parseArgs(args); - try (Model model = Model.newInstance("deepar")) { + try (Model model = Model.newInstance("deepar", arguments.getEngine())) { // specify the model distribution output, for M5 case, NegativeBinomial best describe it DistributionOutput distributionOutput = new NegativeBinomialOutput(); DefaultTrainingConfig config = setupTrainingConfig(arguments, distributionOutput); @@ -212,7 +210,7 @@ private static DefaultTrainingConfig setupTrainingConfig( return new DefaultTrainingConfig(new DistributionLoss("Loss", distributionOutput)) .addEvaluator(new Rmsse(distributionOutput)) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .optInitializer(new XavierInitializer(), Parameter.Type.WEIGHT) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java b/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java index acea661676a..67863556a2e 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java @@ -15,7 +15,6 @@ import ai.djl.Model; import 
ai.djl.basicdataset.cv.classification.Mnist; import ai.djl.basicmodelzoo.basic.Mlp; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; @@ -83,7 +82,7 @@ protected TrainingConfig setupTrainingConfig(HpSet hpVals) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } @@ -116,7 +115,7 @@ protected Model buildModel(HpSet hpVals) { Block block = new Mlp(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, hidden); - Model model = Model.newInstance("mlp"); + Model model = Model.newInstance("mlp", arguments.getEngine()); model.setBlock(block); return model; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java index f2db8c8a7c2..b9f254eea5f 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java @@ -12,12 +12,11 @@ */ package ai.djl.examples.training; -import ai.djl.MalformedModelException; import ai.djl.Model; +import ai.djl.ModelException; import ai.djl.basicdataset.cv.classification.Cifar10; import ai.djl.basicmodelzoo.BasicModelZoo; import ai.djl.basicmodelzoo.cv.classification.ResNetV1; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; import ai.djl.modality.Classifications; @@ -31,7 +30,6 @@ import ai.djl.nn.SymbolBlock; import ai.djl.nn.core.Linear; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; import ai.djl.training.Trainer; @@ 
-63,20 +61,12 @@ public final class TrainWithOptimizers { private TrainWithOptimizers() {} public static void main(String[] args) - throws IOException, - ParseException, - ModelNotFoundException, - MalformedModelException, - TranslateException { + throws IOException, ParseException, ModelException, TranslateException { TrainWithOptimizers.runExample(args); } public static TrainingResult runExample(String[] args) - throws IOException, - ParseException, - ModelNotFoundException, - MalformedModelException, - TranslateException { + throws IOException, ParseException, ModelException, TranslateException { OptimizerArguments arguments = (OptimizerArguments) new OptimizerArguments().parseArgs(args); @@ -106,14 +96,14 @@ public static TrainingResult runExample(String[] args) } } - private static Model getModel(Arguments arguments) - throws IOException, ModelNotFoundException, MalformedModelException { + private static Model getModel(Arguments arguments) throws IOException, ModelException { boolean isSymbolic = arguments.isSymbolic(); boolean preTrained = arguments.isPreTrained(); Map options = arguments.getCriteria(); Criteria.Builder builder = Criteria.builder() .setTypes(Image.class, Classifications.class) + .optEngine(arguments.getEngine()) .optProgress(new ProgressBar()) .optArtifactId("resnet"); if (isSymbolic) { @@ -155,7 +145,7 @@ private static Model getModel(Arguments arguments) return builder.build().loadModel(); } else { // construct new ResNet50 without pre-trained weights - Model model = Model.newInstance("resnetv1"); + Model model = Model.newInstance("resnetv1", arguments.getEngine()); Block resNet50 = ResNetV1.builder() .setImageShape(new Shape(3, Cifar10.IMAGE_HEIGHT, Cifar10.IMAGE_WIDTH)) @@ -182,7 +172,7 @@ private static DefaultTrainingConfig setupTrainingConfig(OptimizerArguments argu return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .optOptimizer(setupOptimizer(arguments)) - 
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java index d87626a1cec..efa0e0611fc 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java @@ -19,7 +19,6 @@ import ai.djl.basicdataset.tabular.utils.DynamicBuffer; import ai.djl.basicdataset.tabular.utils.Feature; import ai.djl.basicdataset.tabular.utils.Featurizer.DataFeaturizer; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; import ai.djl.modality.nlp.DefaultVocabulary; @@ -39,7 +38,6 @@ import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; -import ai.djl.training.ParameterStore; import ai.djl.training.Trainer; import ai.djl.training.TrainingResult; import ai.djl.training.dataset.RandomAccessDataset; @@ -74,24 +72,21 @@ public static TrainingResult runExample(String[] args) return null; } - // MXNet base model - String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip"; - if ("PyTorch".equals(Engine.getDefaultEngineName())) { - modelUrls = - "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip"; - } + String engine = arguments.getEngine(); + String modelUrls = + "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip"; Criteria criteria = Criteria.builder() .optApplication(Application.NLP.WORD_EMBEDDING) .setTypes(NDList.class, NDList.class) .optModelUrls(modelUrls) - .optEngine(Engine.getDefaultEngineName()) + 
.optEngine(engine) .optProgress(new ProgressBar()) .optOption("trainParam", "true") .build(); int maxTokenLength = 64; - try (Model model = Model.newInstance("AmazonReviewRatingClassification"); + try (Model model = Model.newInstance("AmazonReviewRatingClassification", engine); ZooModel embedding = criteria.loadModel()) { // Prepare the vocabulary DefaultVocabulary vocabulary = @@ -148,42 +143,22 @@ private static CsvDataset getDataset( private static Block getBlock(Block embedder) { SequentialBlock classifier = new SequentialBlock(); // text embedding layer - if ("PyTorch".equals(Engine.getDefaultEngineName())) { - LambdaBlock lambda = - new LambdaBlock( - ndList -> { - NDArray data = ndList.singletonOrThrow(); - NDList inputs = new NDList(); - inputs.add(data.toType(DataType.INT64, false)); - inputs.add( - data.getManager().full(data.getShape(), 1, DataType.INT64)); - inputs.add( - data.getManager() - .arange(data.getShape().get(1)) // maxLen - .toType(DataType.INT64, false) - .broadcast(data.getShape())); - return inputs; - }); - classifier.add(lambda); - classifier.add(embedder); - } else { - // MXNet - LambdaBlock lambda = - new LambdaBlock( - ndList -> { - NDArray data = ndList.singletonOrThrow(); - long batchSize = data.getShape().get(0); - float maxLength = data.getShape().get(1); - return embedder.forward( - new ParameterStore(), - new NDList( - data, - data.getManager() - .full(new Shape(batchSize), maxLength)), - true); - }); - classifier.add(lambda); - } + LambdaBlock lambda = + new LambdaBlock( + ndList -> { + NDArray data = ndList.singletonOrThrow(); + NDList inputs = new NDList(); + inputs.add(data.toType(DataType.INT64, false)); + inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64)); + inputs.add( + data.getManager() + .arange(data.getShape().get(1)) // maxLen + .toType(DataType.INT64, false) + .broadcast(data.getShape())); + return inputs; + }); + classifier.add(lambda); + classifier.add(embedder); // Classification layers 
classifier .add(Linear.builder().setUnits(768).build()) // pre classifier @@ -207,7 +182,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { }); return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(1)) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); } diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java index 7acb2f3531f..e6e7e33d996 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java @@ -13,13 +13,11 @@ package ai.djl.examples.training.transferlearning; import ai.djl.Application; -import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.ModelException; import ai.djl.basicdataset.cv.classification.Cifar10; import ai.djl.basicmodelzoo.BasicModelZoo; import ai.djl.basicmodelzoo.cv.classification.ResNetV1; -import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; import ai.djl.inference.Predictor; import ai.djl.metric.Metrics; @@ -29,6 +27,7 @@ import ai.djl.modality.cv.transform.Normalize; import ai.djl.modality.cv.transform.ToTensor; import ai.djl.modality.cv.translator.ImageClassificationTranslator; +import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.nn.Blocks; @@ -36,7 +35,6 @@ import ai.djl.nn.SymbolBlock; import ai.djl.nn.core.Linear; import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -114,15 +112,15 @@ 
public static TrainingResult runExample(String[] args) Path modelPath = Paths.get("build/model"); model.save(modelPath, "resnetv1"); - Classifications classifications = testSaveParameters(model.getBlock(), modelPath); + Classifications classifications = + testSaveParameters(model.getBlock(), modelPath, arguments); logger.info("Predict result: {}", classifications.topK(3)); return result; } } } - private static Model getModel(Arguments arguments) - throws IOException, ModelNotFoundException, MalformedModelException { + private static Model getModel(Arguments arguments) throws IOException, ModelException { boolean isSymbolic = arguments.isSymbolic(); boolean preTrained = arguments.isPreTrained(); Map options = arguments.getCriteria(); @@ -130,6 +128,7 @@ private static Model getModel(Arguments arguments) Criteria.builder() .optApplication(Application.CV.IMAGE_CLASSIFICATION) .setTypes(Image.class, Classifications.class) + .optEngine(arguments.getEngine()) .optProgress(new ProgressBar()) .optArtifactId("resnet"); if (isSymbolic) { @@ -170,7 +169,7 @@ private static Model getModel(Arguments arguments) return builder.build().loadModel(); } else { // construct new ResNet50 without pre-trained weights - Model model = Model.newInstance("resnetv1"); + Model model = Model.newInstance("resnetv1", arguments.getEngine()); Block resNet50 = ResNetV1.builder() .setImageShape(new Shape(3, 32, 32)) @@ -182,7 +181,7 @@ private static Model getModel(Arguments arguments) } } - private static Classifications testSaveParameters(Block block, Path path) + private static Classifications testSaveParameters(Block block, Path path, Arguments arguments) throws IOException, ModelException, TranslateException { String synsetUrl = "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt"; @@ -200,6 +199,7 @@ private static Classifications testSaveParameters(Block block, Path path) Criteria.builder() .setTypes(Image.class, Classifications.class) .optModelPath(path) + 
.optEngine(arguments.getEngine()) .optTranslator(translator) .optBlock(block) .optModelName("resnetv1") @@ -214,7 +214,7 @@ private static Classifications testSaveParameters(Block block, Path path) private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus())) + .optDevices(arguments.getMaxGpus()) .addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir())); } @@ -227,6 +227,7 @@ private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arg Cifar10 cifar10 = Cifar10.builder() .optUsage(usage) + .optManager(NDManager.newBaseManager(arguments.getEngine())) .setSampling(arguments.getBatchSize(), true) .optLimit(arguments.getLimit()) .optPipeline(pipeline) diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java index 6fa1b22167c..ae4d2376083 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java @@ -88,7 +88,7 @@ public static TrainingResult runExample(String[] args) .add(Linear.builder().setUnits(2).build()) .addSingleton(nd -> nd.softmax(1)); - Model model = Model.newInstance("TransferFreshFruit"); + Model model = Model.newInstance("TransferFreshFruit", "PyTorch"); model.setBlock(blocks); // Configure trainer @@ -161,7 +161,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) { DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy")) .addEvaluator(new Accuracy()) - .optDevices(Engine.getInstance().getDevices(1)) + .optDevices(Engine.getEngine("PyTorch").getDevices(1)) 
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir)) .addTrainingListeners(listener); diff --git a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java index bbfa48f6381..5a0048226c4 100644 --- a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java +++ b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java @@ -12,6 +12,7 @@ */ package ai.djl.examples.training.util; +import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.util.JsonUtils; @@ -38,10 +39,10 @@ public class Arguments { protected long limit; protected String modelDir; protected Map criteria; + protected String engine; protected void initialize() { epoch = 2; - maxGpus = Engine.getInstance().getGpuCount(); outputDir = "build/model"; limit = Long.MAX_VALUE; modelDir = null; @@ -52,7 +53,7 @@ protected void setCmd(CommandLine cmd) { epoch = Integer.parseInt(cmd.getOptionValue("epoch")); } if (cmd.hasOption("max-gpus")) { - maxGpus = Math.min(Integer.parseInt(cmd.getOptionValue("max-gpus")), maxGpus); + maxGpus = Integer.parseInt(cmd.getOptionValue("max-gpus")); } if (cmd.hasOption("batch-size")) { batchSize = Integer.parseInt(cmd.getOptionValue("batch-size")); @@ -75,6 +76,11 @@ protected void setCmd(CommandLine cmd) { Type type = new TypeToken>() {}.getType(); criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type); } + if (cmd.hasOption("engine")) { + engine = cmd.getOptionValue("engine"); + } else { + engine = "PyTorch"; + } } public Arguments parseArgs(String[] args) { @@ -162,6 +168,13 @@ public Options getOptions() { .argName("CRITERIA") .desc("The criteria used for the model.") .build()); + options.addOption( + Option.builder() + .longOpt("engine") + .hasArg() + .argName("ENGINE") + .desc("The engine for the model.") + .build()); return options; } @@ -173,8 +186,8 @@ public int getEpoch() { return epoch; } - public int getMaxGpus() { - return 
maxGpus; + public Device[] getMaxGpus() { + return Engine.getEngine(engine).getDevices(maxGpus); } public boolean isSymbolic() { @@ -201,6 +214,10 @@ public Map getCriteria() { return criteria; } + public String getEngine() { + return engine; + } + private void printHelp(String msg, Options options) { HelpFormatter formatter = new HelpFormatter(); formatter.setLeftPadding(1); diff --git a/examples/src/test/java/ai/djl/examples/inference/ActionRecognitionTest.java b/examples/src/test/java/ai/djl/examples/inference/ActionRecognitionTest.java index acf83fbb15a..0534c2a4f40 100644 --- a/examples/src/test/java/ai/djl/examples/inference/ActionRecognitionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/ActionRecognitionTest.java @@ -27,7 +27,7 @@ public class ActionRecognitionTest { @Test public void testActionRecognition() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); Classifications result = ActionRecognition.predict(); Classifications.Classification best = result.best(); diff --git a/examples/src/test/java/ai/djl/examples/inference/BertQaTest.java b/examples/src/test/java/ai/djl/examples/inference/BertQaTest.java index 5b9de95a58a..06194ac513e 100644 --- a/examples/src/test/java/ai/djl/examples/inference/BertQaTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/BertQaTest.java @@ -25,7 +25,7 @@ public class BertQaTest { @Test public void testBertQa() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); String result = BertQaInference.predict(); Assert.assertEquals(result, "december 2004"); diff --git a/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java b/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java index 307a6e915a7..8158db0508e 100644 --- a/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java +++ 
b/examples/src/test/java/ai/djl/examples/inference/BigGANTest.java @@ -27,7 +27,7 @@ public class BigGANTest { @Test public void testBigGAN() throws ModelException, TranslateException, IOException { - TestRequirements.engine("PyTorch"); + TestRequirements.linux(); Image[] generatedImages = BigGAN.generate(); diff --git a/examples/src/test/java/ai/djl/examples/inference/FeatureComparisonTest.java b/examples/src/test/java/ai/djl/examples/inference/FeatureComparisonTest.java index ceadef177b1..7e355ecfac8 100644 --- a/examples/src/test/java/ai/djl/examples/inference/FeatureComparisonTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/FeatureComparisonTest.java @@ -31,7 +31,8 @@ public class FeatureComparisonTest { @Test public void testFeatureComparison() throws ModelException, TranslateException, IOException { - TestRequirements.engine("PyTorch"); + TestRequirements.linux(); + TestRequirements.nightly(); Path imageFile1 = Paths.get("src/test/resources/kana1.jpg"); diff --git a/examples/src/test/java/ai/djl/examples/inference/FeatureExtractionTest.java b/examples/src/test/java/ai/djl/examples/inference/FeatureExtractionTest.java index 56ea8f1bc2a..c5ac6bd1119 100644 --- a/examples/src/test/java/ai/djl/examples/inference/FeatureExtractionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/FeatureExtractionTest.java @@ -30,7 +30,7 @@ public class FeatureExtractionTest { @Test public void testFeatureComparison() throws ModelException, TranslateException, IOException { - TestRequirements.engine("PyTorch"); + TestRequirements.linux(); Path imageFile = Paths.get("src/test/resources/kana1.jpg"); Image img = ImageFactory.getInstance().fromFile(imageFile); diff --git a/examples/src/test/java/ai/djl/examples/inference/InstanceSegmentationTest.java b/examples/src/test/java/ai/djl/examples/inference/InstanceSegmentationTest.java index d17d6dcae05..1cf6c06d7d7 100644 --- a/examples/src/test/java/ai/djl/examples/inference/InstanceSegmentationTest.java +++ 
b/examples/src/test/java/ai/djl/examples/inference/InstanceSegmentationTest.java @@ -28,8 +28,7 @@ public class InstanceSegmentationTest { @Test public void testInstanceSegmentation() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet"); - TestRequirements.nightly(); + TestRequirements.linux(); DetectedObjects result = InstanceSegmentation.predict(); Classifications.Classification best = result.best(); diff --git a/examples/src/test/java/ai/djl/examples/inference/LightFaceDetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/LightFaceDetectionTest.java index bf4baee088d..360a76049f3 100644 --- a/examples/src/test/java/ai/djl/examples/inference/LightFaceDetectionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/LightFaceDetectionTest.java @@ -30,8 +30,7 @@ public class LightFaceDetectionTest { @Test public void testLightFaceDetection() throws ModelException, TranslateException, IOException { - TestRequirements.engine("PyTorch"); - TestRequirements.nightly(); + TestRequirements.linux(); DetectedObjects result = LightFaceDetection.predict(); diff --git a/examples/src/test/java/ai/djl/examples/inference/MaskDetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/MaskDetectionTest.java index 818be1f65ff..551920bff4a 100644 --- a/examples/src/test/java/ai/djl/examples/inference/MaskDetectionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/MaskDetectionTest.java @@ -34,7 +34,7 @@ public class MaskDetectionTest { @Test public void testMaskDetection() throws ModelException, TranslateException, IOException { - TestRequirements.engine("OnnxRuntime"); + TestRequirements.linux(); DetectedObjects result = MaskDetection.predict(); logger.info("{}", result); diff --git a/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionTest.java index 5f612ec8702..a1007cc34c8 100644 --- 
a/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionTest.java @@ -34,7 +34,7 @@ public class ObjectDetectionTest { @Test public void testObjectDetection() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet", "PyTorch", "TensorFlow"); + TestRequirements.linux(); DetectedObjects result = ObjectDetection.predict(); logger.info("{}", result); diff --git a/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionWithTensorflowSavedModelTest.java b/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionWithTensorflowSavedModelTest.java index c3849598ae4..afb3005514f 100644 --- a/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionWithTensorflowSavedModelTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/ObjectDetectionWithTensorflowSavedModelTest.java @@ -33,8 +33,8 @@ public class ObjectDetectionWithTensorflowSavedModelTest { public void testObjectDetection() throws ModelException, TranslateException, IOException { // Only run nightly, this example download the synset file from github, this can cause // throttling and will fail the test. 
+ TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("TensorFlow"); DetectedObjects result = ObjectDetectionWithTensorflowSavedModel.predict(); diff --git a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java index da927720613..d0afdb90be1 100644 --- a/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/PoseEstimationTest.java @@ -28,7 +28,7 @@ public class PoseEstimationTest { @Test public void testPoseEstimation() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); List result = PoseEstimation.predict(); Assert.assertTrue(result.get(0).getJoints().get(0).getConfidence() > 0.6d); diff --git a/examples/src/test/java/ai/djl/examples/inference/RetinaFaceDetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/RetinaFaceDetectionTest.java index 5711e626f43..0060cd3e67d 100644 --- a/examples/src/test/java/ai/djl/examples/inference/RetinaFaceDetectionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/RetinaFaceDetectionTest.java @@ -30,7 +30,7 @@ public class RetinaFaceDetectionTest { @Test public void testRetinaFaceDetection() throws ModelException, TranslateException, IOException { - TestRequirements.engine("PyTorch"); + TestRequirements.linux(); TestRequirements.nightly(); DetectedObjects result = RetinaFaceDetection.predict(); diff --git a/examples/src/test/java/ai/djl/examples/inference/SentimentAnalysisTest.java b/examples/src/test/java/ai/djl/examples/inference/SentimentAnalysisTest.java index 504208e71ee..4091c91525e 100644 --- a/examples/src/test/java/ai/djl/examples/inference/SentimentAnalysisTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/SentimentAnalysisTest.java @@ -26,8 +26,8 @@ public class SentimentAnalysisTest { @Test public void 
testSentimentAnalysis() throws ModelException, TranslateException, IOException { + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("PyTorch"); Classifications result = SentimentAnalysis.predict(); Assert.assertEquals(result.best().getClassName(), "Positive"); diff --git a/examples/src/test/java/ai/djl/examples/inference/SpeechRecognitionTest.java b/examples/src/test/java/ai/djl/examples/inference/SpeechRecognitionTest.java index 4801081ce50..e897381aa5f 100644 --- a/examples/src/test/java/ai/djl/examples/inference/SpeechRecognitionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/SpeechRecognitionTest.java @@ -22,15 +22,12 @@ import java.io.IOException; -import javax.sound.sampled.UnsupportedAudioFileException; - public class SpeechRecognitionTest { @Test - public void testSpeechRecognition() - throws ModelException, TranslateException, IOException, UnsupportedAudioFileException { + public void testSpeechRecognition() throws ModelException, TranslateException, IOException { + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("PyTorch"); try { String result = SpeechRecognition.predict(); @@ -40,7 +37,7 @@ public void testSpeechRecognition() + " DOOR TO YOU SHALL I CALL ON HIM AS I PASS "); } catch (EngineException e) { // wav2vec2.ptl model requires avx2 - if (!"Unknown qengine".equals(e.getMessage())) { + if (!"Unknown engine".equals(e.getMessage())) { throw e; } } diff --git a/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java b/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java index 1981cda7643..645a88b57bb 100644 --- a/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java @@ -12,11 +12,10 @@ */ package ai.djl.examples.inference; -import ai.djl.MalformedModelException; +import ai.djl.ModelException; import ai.djl.examples.inference.cv.StyleTransfer; 
import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; @@ -29,12 +28,8 @@ public class StyleTransferTest { @Test - public void testStyleTransfer() - throws IOException, - ModelNotFoundException, - MalformedModelException, - TranslateException { - TestRequirements.engine("PyTorch"); + public void testStyleTransfer() throws IOException, ModelException, TranslateException { + TestRequirements.linux(); String imagePath = "src/test/resources/mountains.png"; Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); diff --git a/examples/src/test/java/ai/djl/examples/inference/SuperResolutionTest.java b/examples/src/test/java/ai/djl/examples/inference/SuperResolutionTest.java index 81cde8ab876..2c1a363a0bd 100644 --- a/examples/src/test/java/ai/djl/examples/inference/SuperResolutionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/SuperResolutionTest.java @@ -31,7 +31,7 @@ public class SuperResolutionTest { @Test public void testSuperResolution() throws ModelException, TranslateException, IOException { - TestRequirements.engine("TensorFlow"); + TestRequirements.linux(); String imagePath = "src/test/resources/"; Image fox = ImageFactory.getInstance().fromFile(Paths.get(imagePath + "fox.png")); diff --git a/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java b/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java index f6b315cee36..cd37ef957c0 100644 --- a/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/TimeSeriesTest.java @@ -33,7 +33,7 @@ public class TimeSeriesTest { @Test public void testM5Forecasting() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); Map result = M5ForecastingDeepAR.predict(); diff --git 
a/examples/src/test/java/ai/djl/examples/inference/UniversalSentenceEncoderTest.java b/examples/src/test/java/ai/djl/examples/inference/UniversalSentenceEncoderTest.java index 326fd18f41f..9784b47b3ad 100644 --- a/examples/src/test/java/ai/djl/examples/inference/UniversalSentenceEncoderTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/UniversalSentenceEncoderTest.java @@ -27,8 +27,8 @@ public class UniversalSentenceEncoderTest { @Test public void testSentimentAnalysis() throws ModelException, TranslateException, IOException { + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("TensorFlow"); List inputs = new ArrayList<>(); inputs.add("The quick brown fox jumps over the lazy dog."); diff --git a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java index 780d90dc9f2..9fb9cc079ce 100644 --- a/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/Yolov8DetectionTest.java @@ -28,7 +28,7 @@ public class Yolov8DetectionTest { @Test public void testYolov8Detection() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); DetectedObjects result = Yolov8Detection.predict(); diff --git a/examples/src/test/java/ai/djl/examples/inference/clip/ClipModelTest.java b/examples/src/test/java/ai/djl/examples/inference/clip/ClipModelTest.java index 70b0a6e9c48..7f310bf7794 100644 --- a/examples/src/test/java/ai/djl/examples/inference/clip/ClipModelTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/clip/ClipModelTest.java @@ -27,8 +27,8 @@ public class ClipModelTest { @Test public void testClipFeature() throws ModelException, IOException, TranslateException { + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("PyTorch"); try (ClipModel model = new 
ClipModel()) { String input = "This is a nice day"; @@ -44,8 +44,8 @@ public void testClipFeature() throws ModelException, IOException, TranslateExcep @Test public void testClipComparison() throws ModelException, IOException, TranslateException { + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.engine("PyTorch"); String text = "A photo of cats"; String text2 = "A photo of dogs"; diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java index d8393d6c07d..29557174503 100644 --- a/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/TextGenerationTest.java @@ -25,8 +25,8 @@ public class TextGenerationTest { @Test public void testTextGeneration() throws TranslateException, ModelException, IOException { + TestRequirements.linux(); TestRequirements.weekly(); - TestRequirements.engine("PyTorch"); // Beam with Ort String[] output0 = TextGeneration.generateTextWithOnnxRuntimeBeam(); @@ -91,8 +91,8 @@ public void testTextGeneration() throws TranslateException, ModelException, IOEx @Test public void testSeqBatchScheduler() throws TranslateException, ModelException, IOException { + TestRequirements.linux(); TestRequirements.weekly(); - TestRequirements.engine("PyTorch"); String[] output = RollingBatch.seqBatchSchedulerWithPyTorchContrastive(); Assert.assertEquals( output[0], diff --git a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java index f5bb6642cf7..35d831777a3 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java @@ -25,7 +25,8 @@ public class TrainAirfoilWithTabNetTest { @Test public void testTrainAirfoilWithTabNet() 
throws TranslateException, IOException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); + String[] args = {"-g", "1", "-e", "20", "-b", "32"}; if (!Boolean.getBoolean("nightly")) { args[3] = "2"; diff --git a/examples/src/test/java/ai/djl/examples/training/TrainAmazonReviewTest.java b/examples/src/test/java/ai/djl/examples/training/TrainAmazonReviewTest.java index fafaa2c7c6c..5998f72500f 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainAmazonReviewTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainAmazonReviewTest.java @@ -28,11 +28,11 @@ public class TrainAmazonReviewTest { @Test public void testRankTraining() throws ModelException, TranslateException, IOException, URISyntaxException { - TestRequirements.engine("MXNet", "PyTorch"); - TestRequirements.nightly(); + TestRequirements.linux(); + TestRequirements.weekly(); String[] args; - if (Engine.getInstance().getGpuCount() > 0) { + if (Engine.getEngine("PyTorch").getGpuCount() > 0) { args = new String[] {"-e", "1", "-g", "1"}; } else { args = new String[] {"-e", "1", "-m", "2"}; diff --git a/examples/src/test/java/ai/djl/examples/training/TrainBertOnGoemotionsTest.java b/examples/src/test/java/ai/djl/examples/training/TrainBertOnGoemotionsTest.java index 7131fb4e9b5..285a22c6700 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainBertOnGoemotionsTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainBertOnGoemotionsTest.java @@ -23,7 +23,7 @@ public class TrainBertOnGoemotionsTest { @Test public void testTrainBert() throws IOException, TranslateException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"}; TrainBertOnGoemotions.runExample(args); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java b/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java index a5ad700b5f5..cdb101a29aa 100644 --- 
a/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java @@ -23,7 +23,7 @@ public class TrainBertTest { @Test public void testTrainBert() throws IOException, TranslateException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"}; TrainBertOnCode.runExample(args); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java b/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java index eea5d8d43af..a2ccbbf4357 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainCaptchaTest.java @@ -25,9 +25,10 @@ public class TrainCaptchaTest { @Test public void testTrainCaptcha() throws IOException, TranslateException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); - String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2"}; + // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH + String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"}; TrainingResult result = TrainCaptcha.runExample(args); Assert.assertNotNull(result); } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java b/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java index e72d3cf5427..1e54b24c13e 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainMnistTest.java @@ -28,7 +28,7 @@ public class TrainMnistTest { @Test public void testTrainMnist() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); double expectedProb; if (Boolean.getBoolean("nightly")) { diff --git a/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java 
b/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java index 4edd139e140..7b9b0380492 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainMnistWithLSTMTest.java @@ -25,9 +25,10 @@ public class TrainMnistWithLSTMTest { @Test public void testTrainMnistWithLSTM() throws IOException, TranslateException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); - String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2"}; + // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH + String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"}; TrainingResult result = TrainMnistWithLSTM.runExample(args); Assert.assertNotNull(result); } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java index 1a5699836c8..77616cfc623 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainPikachuTest.java @@ -27,20 +27,22 @@ public class TrainPikachuTest { @Test public void testDetection() throws IOException, MalformedModelException, TranslateException { + TestRequirements.linux(); TestRequirements.nightly(); String[] args; float expectedLoss = 0; int expectedMinNumber = 0; int expectedMaxNumber = 0; - if (Engine.getInstance().getGpuCount() > 0) { - args = new String[] {"-e", "20", "-b", "32", "-g", "1"}; + // TODO: implement PyTorch multiBoxTarget or object detection training + if (Engine.getEngine("MXNet").getGpuCount() > 0) { + args = new String[] {"-e", "20", "-b", "32", "-g", "1", "--engine", "MXNet"}; expectedLoss = 2.5e-3f; expectedMaxNumber = 15; expectedMinNumber = 6; } else { // test train 1 epoch and predict workflow works on CPU - args = new String[] {"-e", "1", "-m", "1", "-b", "32"}; + args = new String[] {"-e", "1", "-m", 
"1", "-b", "32", "--engine", "MXNet"}; } // test train TrainingResult result = TrainPikachu.runExample(args); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java index 1d2c8d75d3f..b55bc78f9de 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java @@ -30,11 +30,12 @@ public class TrainResNetTest { @Test public void testTrainResNet() throws ModelException, IOException, TranslateException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. - String[] args = {"-e", "2", "-g", "4", "-m", "10", "-s", "-p"}; + // only MXNet support symbolic model + String[] args = {"-e", "2", "-g", "4", "-m", "10", "-s", "-p", "--engine", "MXNet"}; TrainingResult result = TrainResnetWithCifar10.runExample(args); Assert.assertNotNull(result); @@ -45,13 +46,14 @@ public void testTrainResNetSymbolicNightly() throws ModelException, IOException, TranslateException { TestRequirements.engine("MXNet"); TestRequirements.nightly(); - TestRequirements.gpu(); + TestRequirements.gpu("MXNet"); // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. 
- String[] args = {"-e", "10", "-g", "4", "-s", "-p"}; + // only MXNet support symbolic model + String[] args = {"-e", "10", "-g", "4", "-s", "-p", "--engine", "MXNet"}; - Engine.getInstance().setRandomSeed(SEED); + Engine.getEngine("MXNet").setRandomSeed(SEED); TrainingResult result = TrainResnetWithCifar10.runExample(args); Assert.assertNotNull(result); @@ -64,13 +66,13 @@ public void testTrainResNetSymbolicNightly() public void testTrainResNetImperativeNightly() throws ModelException, IOException, TranslateException { TestRequirements.nightly(); - TestRequirements.gpu(); + TestRequirements.gpu("MXNet"); // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. - String[] args = {"-e", "10", "-g", "4"}; + String[] args = {"-e", "10", "-g", "4", "--engine", "MXNet"}; - Engine.getInstance().setRandomSeed(SEED); + Engine.getEngine("MXNet").setRandomSeed(SEED); TrainingResult result = TrainResnetWithCifar10.runExample(args); Assert.assertNotNull(result); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java b/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java index 43ce2d8b7a4..a07d95b5a5b 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainSentimentAnalysisTest.java @@ -25,11 +25,11 @@ public class TrainSentimentAnalysisTest { @Test public void testTrainSentimentAnalysis() throws ModelException, TranslateException, IOException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.gpu(); + TestRequirements.gpu("MXNet"); - String[] args = new String[] {"-e", "1", "-g", "1"}; + String[] args = new String[] {"-e", "1", "-g", "1", "--engine", "MXNet"}; TrainSentimentAnalysis.runExample(args); } } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainSeq2SeqTest.java 
b/examples/src/test/java/ai/djl/examples/training/TrainSeq2SeqTest.java index df6abf24858..58c46e21488 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainSeq2SeqTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainSeq2SeqTest.java @@ -26,9 +26,10 @@ public class TrainSeq2SeqTest { @Test public void testTrainSeq2Seq() throws IOException, TranslateException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); - String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2"}; + // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH + String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"}; TrainingResult result = TrainSeq2Seq.runExample(args); Assert.assertNotNull(result); } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainTicTacToeTest.java b/examples/src/test/java/ai/djl/examples/training/TrainTicTacToeTest.java index 941cf7cc7b2..6d077552d1f 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainTicTacToeTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainTicTacToeTest.java @@ -25,11 +25,12 @@ public class TrainTicTacToeTest { @Test public void testTrainTicTacToe() throws IOException { - TestRequirements.engine("MXNet", "PyTorch"); + TestRequirements.linux(); - if (Boolean.getBoolean("nightly") && Engine.getInstance().getGpuCount() > 0) { + Engine engine = Engine.getEngine("PyTorch"); + if (Boolean.getBoolean("nightly") && engine.getGpuCount() > 0) { String[] args = new String[] {"-g", "1", "-e", "6"}; - Engine.getInstance().setRandomSeed(1234); + engine.setRandomSeed(1234); TrainingResult result = TrainTicTacToe.runExample(args); Assert.assertNotNull(result); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java index b92c1bb8249..c1b6ffbca51 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java +++ 
b/examples/src/test/java/ai/djl/examples/training/TrainTimeSeriesTest.java @@ -26,9 +26,10 @@ public class TrainTimeSeriesTest { @Test public void testTrainTimeSeries() throws TranslateException, IOException { - TestRequirements.engine("MXNet"); + TestRequirements.linux(); - String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32"}; + // TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH + String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"}; TrainingResult result = TrainTimeSeries.runExample(args); Assert.assertNotNull(result); float loss = result.getTrainLoss(); diff --git a/examples/src/test/java/ai/djl/examples/training/TransferFreshFruitTest.java b/examples/src/test/java/ai/djl/examples/training/TransferFreshFruitTest.java index 5764cabe1ec..b52913b94c7 100644 --- a/examples/src/test/java/ai/djl/examples/training/TransferFreshFruitTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TransferFreshFruitTest.java @@ -30,15 +30,14 @@ public class TransferFreshFruitTest { @Test public void testTransferFreshFruit() throws ModelException, TranslateException, IOException, URISyntaxException { - TestRequirements.engine("PyTorch"); + TestRequirements.linux(); String[][] args = {{}, {"-p"}}; - Engine.getInstance().setRandomSeed(1234); - TrainingResult result; + Engine.getEngine("PyTorch").setRandomSeed(1234); for (String[] arg : args) { - result = TransferFreshFruit.runExample(arg); + TrainingResult result = TransferFreshFruit.runExample(arg); Assert.assertNotNull(result); - Assert.assertTrue(result.getEvaluations().get("validate_Accuracy") > 0.9f); + Assert.assertTrue(result.getEvaluations().get("validate_Accuracy") > 0.76f); Assert.assertTrue(result.getEvaluations().get("train_Accuracy") > 0.9f); } } diff --git a/examples/src/test/java/ai/djl/testing/TestRequirements.java b/examples/src/test/java/ai/djl/testing/TestRequirements.java index 01eef756201..0e7d528cfb3 100644 --- 
a/examples/src/test/java/ai/djl/testing/TestRequirements.java +++ b/examples/src/test/java/ai/djl/testing/TestRequirements.java @@ -14,7 +14,6 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineException; -import ai.djl.util.Utils; import org.testng.SkipException; @@ -44,13 +43,6 @@ public static void weekly() { } } - /** Requires a test not be run in offline mode. */ - public static void notOffline() { - if (Utils.isOfflineMode()) { - throw new SkipException("This test can not run while offline"); - } - } - /** * Requires a test only with the allowed engine(s). * @@ -73,31 +65,20 @@ public static void engine(String... engines) { } /** - * Requires a test have any engines except for those listed. + * Requires that a test runs on Linux. * - * @param engines the engine(s) to not run the test on + *

Avoids running multiple engines on Windows and running PyTorch on macOS x86_64 machines. */ - public static void notEngine(String... engines) { - String engineName = Engine.getDefaultEngineName(); - for (String e : engines) { - if (engineName.equals(e)) { - throw new SkipException( - "This test requires not using the engines: " + Arrays.toString(engines)); - } + public static void linux() { + if (!System.getProperty("os.name").toLowerCase().startsWith("linux")) { + throw new SkipException("This test requires a Linux os."); } } /** Requires a test have at least one gpu. */ - public static void gpu() { - if (Engine.getInstance().getGpuCount() == 0) { + public static void gpu(String engine) { + if (Engine.getEngine(engine).getGpuCount() == 0) { throw new SkipException("This test requires a GPU to run"); } } - - /** Requires that the test runs on OSX or linux, not windows. */ - public static void notWindows() { - if (System.getProperty("os.name").toLowerCase().startsWith("win")) { - throw new SkipException("This test requires a non-windows os."); - } - } }