Skip to content

Commit

Permalink
[example] Adds PyTorch action recognition model to model zoo
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 1, 2024
1 parent c895909 commit 94054b2
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ public class PtModelZoo extends ModelZoo {
public static final String GROUP_ID = "ai.djl.pytorch";

PtModelZoo() {
addModel(
REPOSITORY.model(
CV.ACTION_RECOGNITION,
GROUP_ID,
"Human-Action-Recognition-VIT-Base-patch16-224",
"0.0.1"));
addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
addModel(
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/action_recognition",
"groupId": "ai.djl.pytorch",
"artifactId": "Human-Action-Recognition-VIT-Base-patch16-224",
"name": "Human-Action-Recognition-VIT-Base-patch16-224",
"description": "Action Recognition models",
"website": "http://www.djl.ai/engines/pytorch/model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "Human-Action-Recognition-VIT-Base-patch16-224",
"properties": {
},
"arguments": {
"width": 224,
"height": 224,
"resize": true,
"scaleFactor": 0.00392156862745098,
"normalize": "0.5,0.5,0.5,0.5,0.5,0.5",
"applySoftmax": true,
"translatorFactory": "ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/Human-Action-Recognition-VIT-Base-patch16-224.zip",
"sha1Hash": "943cf0f9cfab07445489aeb67b9c74c9c85680f4",
"name": "",
"size": 318499747
}
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
* An example of inference using an action recognition model.
Expand All @@ -48,16 +46,16 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Classifications predict() throws IOException, ModelException, TranslateException {
Path imageFile = Paths.get("src/test/resources/action_discus_throw.png");
Image img = ImageFactory.getInstance().fromFile(imageFile);
String url = "https://resources.djl.ai/images/action_dance.jpg";
Image img = ImageFactory.getInstance().fromUrl(url);

// Use DJL MXNet model zoo model
// Use DJL PyTorch model zoo model
Criteria<Image, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelUrls(
"djl://ai.djl.mxnet/action_recognition/0.0.1/inceptionv3_ucf101")
.optEngine("MXNet")
"djl://ai.djl.pytorch/Human-Action-Recognition-VIT-Base-patch16-224")
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.ModelException;
import ai.djl.examples.inference.cv.ActionRecognition;
import ai.djl.modality.Classifications;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;

import org.testng.Assert;
Expand All @@ -27,11 +26,9 @@ public class ActionRecognitionTest {

@Test
public void testActionRecognition() throws ModelException, TranslateException, IOException {
TestRequirements.linux();

Classifications result = ActionRecognition.predict();
Classifications.Classification best = result.best();
Assert.assertEquals(best.getClassName(), "ThrowDiscus");
Assert.assertEquals(best.getClassName(), "Dancing");
Assert.assertTrue(Double.compare(best.getProbability(), 0.9) > 0);
}
}

0 comments on commit 94054b2

Please sign in to comment.