Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples] Prepare for MXNet deprecation #3157

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public Shape[] getOutputShapes(Shape[] inputShapes) {
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes, DataType[] dataTypes) {
try (NDManager manager = NDManager.newBaseManager()) {
try (NDManager manager = NDManager.newBaseManager("PyTorch")) {
NDList list = new NDList();
for (int i = 0; i < inputShapes.length; i++) {
list.add(
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/train_transfer_fresh_fruit.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Block baseBlock) {

DefaultTrainingConfig config = new DefaultTrainingConfig(new SoftmaxCrossEntropy("SoftmaxCrossEntropy"))
.addEvaluator(new Accuracy())
.optDevices(Engine.getInstance().getDevices(1))
.optDevices(Engine.getEngine("PyTorch").getDevices(1))
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
Expand Down Expand Up @@ -68,14 +68,14 @@ public static String predict() throws IOException, TranslateException, ModelExce
.optApplication(Application.NLP.QUESTION_ANSWER)
.setTypes(QAInput.class, String.class)
.optFilter("backbone", "bert")
.optEngine(Engine.getDefaultEngineName())
.optEngine("PyTorch")
.optDevice(Device.cpu())
.optProgress(new ProgressBar())
.build();

try (ZooModel<QAInput, String> model = criteria.loadModel()) {
try (Predictor<QAInput, String> predictor = model.newPredictor()) {
return predictor.predict(input);
}
try (ZooModel<QAInput, String> model = criteria.loadModel();
Predictor<QAInput, String> predictor = model.newPredictor()) {
return predictor.predict(input);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public static Classifications predict()
Criteria.builder()
.optApplication(Application.NLP.SENTIMENT_ANALYSIS)
.setTypes(String.class, Classifications.class)
.optEngine("PyTorch")
// This model was traced on CPU and can only run on CPU
.optDevice(Device.cpu())
.optProgress(new ProgressBar())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl.examples.inference;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.audio.Audio;
Expand Down Expand Up @@ -56,6 +57,7 @@ public static String predict() throws IOException, ModelException, TranslateExce
Criteria.builder()
.setTypes(Audio.class, String.class)
.optModelUrls(url)
.optDevice(Device.cpu()) // this TorchScript model only supports CPU
.optTranslatorFactory(new SpeechRecognitionTranslatorFactory())
.optModelName("wav2vec2.ptl")
.optEngine("PyTorch")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.examples.inference.clip;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
Expand Down Expand Up @@ -45,6 +46,7 @@ public ClipModel() throws ModelException, IOException {
.optModelUrls("https://resources.djl.ai/demo/pytorch/clip.zip")
.optTranslator(new NoopTranslator())
.optEngine("PyTorch")
.optDevice(Device.cpu()) // this TorchScript model only supports CPU
.build();
clip = criteria.loadModel();
imageFeatureExtractor = clip.newPredictor(new ImageTranslator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static Classifications predict() throws IOException, ModelException, Tran
Image img = ImageFactory.getInstance().fromFile(imageFile);

String modelName = "mlp";
try (Model model = Model.newInstance(modelName)) {
try (Model model = Model.newInstance(modelName, "PyTorch")) {
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

// Assumes you have run the TrainMnist.java example and saved the model in the build/model folder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
Expand Down Expand Up @@ -54,19 +53,12 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);

String backbone;
if ("TensorFlow".equals(Engine.getDefaultEngineName())) {
backbone = "mobilenet_v2";
} else {
backbone = "resnet50";
}

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.optApplication(Application.CV.OBJECT_DETECTION)
.setTypes(Image.class, DetectedObjects.class)
.optFilter("backbone", backbone)
.optEngine(Engine.getDefaultEngineName())
.optFilter("backbone", "mobilenet_v2")
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
package ai.djl.examples.inference.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
Expand All @@ -22,7 +22,6 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.DeferredTranslatorFactory;
import ai.djl.translate.TranslateException;
Expand All @@ -39,20 +38,13 @@ public final class RollingBatch {

private RollingBatch() {}

public static void main(String[] args)
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
public static void main(String[] args) throws ModelException, IOException, TranslateException {
String[] ret = seqBatchSchedulerWithPyTorchContrastive();
logger.info("{}", ret[0]);
}

public static String[] seqBatchSchedulerWithPyTorchContrastive()
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
throws ModelException, IOException, TranslateException {
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";

Criteria<NDList, CausalLMOutput> criteria =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.examples.inference.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
Expand Down Expand Up @@ -161,10 +162,7 @@ public static String[] generateTextWithPyTorchBeam()
}

public static String[] generateTextWithOnnxRuntimeBeam()
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
throws ModelException, IOException, TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
long padTokenId = 220;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.basicdataset.tabular.utils.DynamicBuffer;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
Expand Down Expand Up @@ -63,7 +62,6 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -82,9 +80,7 @@ public static void main(String[] args) throws IOException, TranslateException, M

public static Map<String, Float> predict()
throws IOException, TranslateException, ModelException {
Engine engine = Engine.getInstance();
NDManager manager = engine.newBaseManager();
String engineName = engine.getEngineName().toLowerCase(Locale.ROOT);
NDManager manager = NDManager.newBaseManager("MXNet");

// To use local dataset, users can load data as follows
// Repository repository = Repository.newInstance("local_dataset",
Expand All @@ -102,12 +98,13 @@ public static Map<String, Float> predict()
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008#file-m5torch-py

// Here you can also use local file: modelUrl = "LOCAL_PATH/deepar.pt";
String modelUrl = "djl://ai.djl." + engineName + "/deepar/0.0.1/m5forecast";
String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
int predictionLength = 4;
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls(modelUrl)
.optEngine("MXNet")
.optTranslatorFactory(new DeepARTranslatorFactory())
.optArgument("prediction_length", predictionLength)
.optArgument("freq", "W")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.djl.basicdataset.tabular.TabularDataset;
import ai.djl.basicdataset.tabular.TabularResults;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
Expand Down Expand Up @@ -55,7 +54,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
// Construct a tabNet instance
Block tabNet = TabNet.builder().setInputDim(5).setOutDim(1).build();

try (Model model = Model.newInstance("tabNet")) {
try (Model model = Model.newInstance("tabNet", arguments.getEngine())) {
model.setBlock(tabNet);

// get the training and validation dataset
Expand Down Expand Up @@ -103,13 +102,12 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
});

return new DefaultTrainingConfig(new TabNetRegressionLoss())
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
}

private static TabularDataset getDataset(Arguments arguments)
throws IOException, TranslateException {
private static TabularDataset getDataset(Arguments arguments) throws IOException {
AirfoilRandomAccess.Builder airfoilBuilder = AirfoilRandomAccess.builder();

// Only the training dataset is available, so we get the training dataset and split it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.examples.training;

import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.BertCodeDataset;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -59,7 +58,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
dataset.prepare();

// Create model & trainer
try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) {
try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) {

TrainingConfig config = createTrainingConfig(arguments);
try (Trainer trainer = model.newTrainer(config)) {
Expand All @@ -74,15 +73,15 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
}
}

private static Model createBertPretrainingModel(long vocabularySize) {
private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) {
Block block =
new BertPretrainingBlock(
BertBlock.builder()
.micro()
.setTokenDictionarySize(Math.toIntExact(vocabularySize)));
block.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);

Model model = Model.newInstance("Bert Pretraining");
Model model = Model.newInstance("Bert Pretraining", arguments.getEngine());
model.setBlock(block);
return model;
}
Expand All @@ -108,7 +107,7 @@ private static TrainingConfig createTrainingConfig(BertArguments arguments) {
.build();
return new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(Defaults.logging());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import ai.djl.Model;
import ai.djl.basicdataset.nlp.GoEmotions;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.BertGoemotionsDataset;
import ai.djl.modality.nlp.embedding.EmbeddingException;
Expand Down Expand Up @@ -67,7 +66,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
dataset.prepare();

// Create model & trainer
try (Model model = createBertPretrainingModel(dataset.getVocabularySize())) {
try (Model model = createBertPretrainingModel(arguments, dataset.getVocabularySize())) {
TrainingConfig config = createTrainingConfig(arguments);
try (Trainer trainer = model.newTrainer(config)) {
// Initialize training
Expand Down Expand Up @@ -105,19 +104,19 @@ private static TrainingConfig createTrainingConfig(
.build();
return new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addTrainingListeners(TrainingListener.Defaults.logging());
}

private static Model createBertPretrainingModel(long vocabularySize) {
private static Model createBertPretrainingModel(Arguments arguments, long vocabularySize) {
Block block =
new BertPretrainingBlock(
BertBlock.builder()
.micro()
.setTokenDictionarySize(Math.toIntExact(vocabularySize)));
block.setInitializer(new TruncatedNormalInitializer(0.02f), Parameter.Type.WEIGHT);

Model model = Model.newInstance("Bert Pretraining");
Model model = Model.newInstance("Bert Pretraining", arguments.getEngine());
model.setBlock(block);
return model;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
Expand Down Expand Up @@ -63,7 +62,7 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
return null;
}

try (Model model = Model.newInstance("captcha")) {
try (Model model = Model.newInstance("captcha", arguments.getEngine())) {
model.setBlock(getBlock());

// get training and validation dataset
Expand Down Expand Up @@ -107,7 +106,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {

DefaultTrainingConfig config =
new DefaultTrainingConfig(loss)
.optDevices(Engine.getInstance().getDevices(arguments.getMaxGpus()))
.optDevices(arguments.getMaxGpus())
.addEvaluators(loss.getComponents())
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
Expand Down
Loading
Loading