Skip to content

Commit

Permalink
[examples] Enabled training unit tests on macOS M1
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 16, 2024
1 parent 0b026dc commit ab31870
Show file tree
Hide file tree
Showing 13 changed files with 24 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public String getEngineName() {
/** {@inheritDoc} */
@Override
public int getEngineRank() {
String osName = System.getProperty("os.name");
String osArch = System.getProperty("os.arch");
if (osName.startsWith("Mac") && "aarch64".equals(osArch)) {
// MXNet doesn't support macOS M1
return 99;
}
return MxEngine.RANK;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
BertArguments arguments = (BertArguments) new BertArguments().parseArgs(args);

BertCodeDataset dataset =
new BertCodeDataset(arguments.getBatchSize(), arguments.getLimit());
new BertCodeDataset(
arguments.getBatchSize(), arguments.getLimit(), arguments.getEngine());
dataset.prepare();

// Create model & trainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ public class BertCodeDataset implements Dataset {
private long epochLimit;
private NDManager manager;

public BertCodeDataset(int batchSize, long epochLimit) {
public BertCodeDataset(int batchSize, long epochLimit, String engine) {
this.batchSize = batchSize;
this.epochLimit = epochLimit;
this.manager = NDManager.newBaseManager();
this.manager = NDManager.newBaseManager(engine);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,8 +24,6 @@ public class TrainAirfoilWithTabNetTest {

@Test
public void testTrainAirfoilWithTabNet() throws TranslateException, IOException {
TestRequirements.linux();

String[] args = {"-g", "1", "-e", "20", "-b", "32"};
if (!Boolean.getBoolean("nightly")) {
args[3] = "2";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -23,9 +22,7 @@ public class TrainBertOnGoemotionsTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.linux();

String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"};
String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnGoemotions.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -23,9 +22,7 @@ public class TrainBertTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.linux();

String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"};
String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnCode.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ModelException;
import ai.djl.examples.inference.cv.ImageClassification;
import ai.djl.modality.Classifications;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -28,8 +27,6 @@ public class TrainMnistTest {

@Test
public void testTrainMnist() throws ModelException, TranslateException, IOException {
TestRequirements.linux();

double expectedProb;
if (Boolean.getBoolean("nightly")) {
String[] args = new String[] {"-g", "1"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.engine.Engine;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,10 +25,14 @@ public class TrainMnistWithLSTMTest {

@Test
public void testTrainMnistWithLSTM() throws IOException, TranslateException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
String[] args;
Engine engine = Engine.getEngine("PyTorch");
if (engine.getGpuCount() > 0) {
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
} else {
args = new String[] {"-g", "1", "-e", "1", "-m", "2"};
}
TrainingResult result = TrainMnistWithLSTM.runExample(args);
Assert.assertNotNull(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void testTrainSentimentAnalysis()
TestRequirements.nightly();
TestRequirements.gpu("MXNet");

String[] args = new String[] {"-e", "1", "-g", "1", "--engine", "MXNet"};
String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"};
TrainSentimentAnalysis.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TrainSeq2SeqTest {
public void testTrainSeq2Seq() throws IOException, TranslateException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
// TODO: PyTorch -- PtNDArray.sequenceMask not implemented
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
TrainingResult result = TrainSeq2Seq.runExample(args);
Assert.assertNotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.examples.training;

import ai.djl.engine.Engine;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;

import org.testng.Assert;
Expand All @@ -25,8 +24,6 @@ public class TrainTicTacToeTest {

@Test
public void testTrainTicTacToe() throws IOException {
TestRequirements.linux();

Engine engine = Engine.getEngine("PyTorch");
if (Boolean.getBoolean("nightly") && engine.getGpuCount() > 0) {
String[] args = new String[] {"-g", "1", "-e", "6"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TrainTimeSeriesTest {
public void testTrainTimeSeries() throws TranslateException, IOException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
// TODO: PyTorch -- PtNDArray.gammaln not implemented
String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.training.transferlearning.TransferFreshFruit;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -30,8 +29,6 @@ public class TransferFreshFruitTest {
@Test
public void testTransferFreshFruit()
throws ModelException, TranslateException, IOException, URISyntaxException {
TestRequirements.linux();

String[][] args = {{}, {"-p"}};
Engine.getEngine("PyTorch").setRandomSeed(1234);
for (String[] arg : args) {
Expand Down

0 comments on commit ab31870

Please sign in to comment.