Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples] Enabled training unit tests on macOS M1 #3256

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public String getEngineName() {
/** {@inheritDoc} */
@Override
public int getEngineRank() {
String osName = System.getProperty("os.name");
String osArch = System.getProperty("os.arch");
if (osName.startsWith("Mac") && "aarch64".equals(osArch)) {
// MXNet doesn't support macOS M1
return 99;
}
return MxEngine.RANK;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans
BertArguments arguments = (BertArguments) new BertArguments().parseArgs(args);

BertCodeDataset dataset =
new BertCodeDataset(arguments.getBatchSize(), arguments.getLimit());
new BertCodeDataset(
arguments.getBatchSize(), arguments.getLimit(), arguments.getEngine());
dataset.prepare();

// Create model & trainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ public class BertCodeDataset implements Dataset {
private long epochLimit;
private NDManager manager;

public BertCodeDataset(int batchSize, long epochLimit) {
public BertCodeDataset(int batchSize, long epochLimit, String engine) {
this.batchSize = batchSize;
this.epochLimit = epochLimit;
this.manager = NDManager.newBaseManager();
this.manager = NDManager.newBaseManager(engine);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,8 +24,6 @@ public class TrainAirfoilWithTabNetTest {

@Test
public void testTrainAirfoilWithTabNet() throws TranslateException, IOException {
TestRequirements.linux();

String[] args = {"-g", "1", "-e", "20", "-b", "32"};
if (!Boolean.getBoolean("nightly")) {
args[3] = "2";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -23,9 +22,7 @@ public class TrainBertOnGoemotionsTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.linux();

String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"};
String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnGoemotions.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.annotations.Test;
Expand All @@ -23,9 +22,7 @@ public class TrainBertTest {

@Test
public void testTrainBert() throws IOException, TranslateException {
TestRequirements.linux();

String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"};
String[] args = {"-g", "1", "-m", "1", "-e", "1"};
TrainBertOnCode.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ModelException;
import ai.djl.examples.inference.cv.ImageClassification;
import ai.djl.modality.Classifications;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -28,8 +27,6 @@ public class TrainMnistTest {

@Test
public void testTrainMnist() throws ModelException, TranslateException, IOException {
TestRequirements.linux();

double expectedProb;
if (Boolean.getBoolean("nightly")) {
String[] args = new String[] {"-g", "1"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
package ai.djl.examples.training;

import ai.djl.testing.TestRequirements;
import ai.djl.engine.Engine;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,10 +25,14 @@ public class TrainMnistWithLSTMTest {

@Test
public void testTrainMnistWithLSTM() throws IOException, TranslateException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
String[] args;
Engine engine = Engine.getEngine("PyTorch");
if (engine.getGpuCount() > 0) {
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
} else {
args = new String[] {"-g", "1", "-e", "1", "-m", "2"};
}
TrainingResult result = TrainMnistWithLSTM.runExample(args);
Assert.assertNotNull(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void testTrainSentimentAnalysis()
TestRequirements.nightly();
TestRequirements.gpu("MXNet");

String[] args = new String[] {"-e", "1", "-g", "1", "--engine", "MXNet"};
String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"};
TrainSentimentAnalysis.runExample(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TrainSeq2SeqTest {
public void testTrainSeq2Seq() throws IOException, TranslateException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
// TODO: PyTorch -- PtNDArray.sequenceMask not implemented
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
TrainingResult result = TrainSeq2Seq.runExample(args);
Assert.assertNotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package ai.djl.examples.training;

import ai.djl.engine.Engine;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;

import org.testng.Assert;
Expand All @@ -25,8 +24,6 @@ public class TrainTicTacToeTest {

@Test
public void testTrainTicTacToe() throws IOException {
TestRequirements.linux();

Engine engine = Engine.getEngine("PyTorch");
if (Boolean.getBoolean("nightly") && engine.getGpuCount() > 0) {
String[] args = new String[] {"-g", "1", "-e", "6"};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class TrainTimeSeriesTest {
public void testTrainTimeSeries() throws TranslateException, IOException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
// TODO: PyTorch -- PtNDArray.gammaln not implemented
String[] args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
Expand Down
Loading