Skip to content

Commit

Permalink
[pytorch] Adds yolov8n pose estimation model
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 6, 2024
1 parent 6192fd6 commit 09d4a89
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 91 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Joints.Joint;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A translator for Yolov8 pose estimation models.
 *
 * <p>Post-processing keeps the predictions whose confidence exceeds the configured threshold,
 * rewrites each box from {@code (x, y, w, h)} to the corner form {@code (x - w/2, y - h/2, x + w/2,
 * y + h/2)}, runs non-maximum suppression on the boxes and finally returns one {@link Joints} per
 * surviving detection.
 */
public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {

    /** Upper bound on the number of detections returned by {@link #processOutput}. */
    private static final int MAX_DETECTION = 300;

    /** Number of joints decoded for every detection. */
    private static final int NUM_KEYPOINTS = 17;

    /**
     * Number of consecutive floats that describe one joint. The first is divided by the image
     * width, the second by the image height and the third is handed to {@link Joint} unchanged.
     */
    private static final int KEYPOINT_STRIDE = 3;

    /** Number of keypoint floats stored per detection. */
    private static final int KEYPOINTS_PER_DETECTION = NUM_KEYPOINTS * KEYPOINT_STRIDE;

    private final int width;
    private final int height;
    private final float threshold;
    private final float nmsThreshold;

    /**
     * Creates the Pose Estimation translator from the given builder.
     *
     * @param builder the builder for the translator
     */
    public YoloPoseTranslator(Builder builder) {
        super(builder);
        this.width = builder.width;
        this.height = builder.height;
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /** {@inheritDoc} */
    @Override
    public Joints[] processOutput(TranslatorContext ctx, NDList list) {
        // NOTE(review): assumes the single output is laid out as (4 box + 1 score + 51 keypoint
        // values, numCandidates) without a batch dimension -- confirm against the model export.
        NDArray pred = list.singletonOrThrow();
        // Row 4 holds the confidence; build the keep-mask before the layout is transposed.
        NDArray candidates = pred.get(4).gt(threshold);
        pred = pred.transpose();
        NDArray sub = pred.get("..., :4");
        sub = xywh2xyxy(sub);
        pred = sub.concat(pred.get("..., 4:"), -1);
        pred = pred.get(candidates);

        // Columns [0, 4) are the box corners, column 4 the confidence, the rest the keypoints.
        NDList split = pred.split(new long[] {4, 5}, 1);
        NDArray box = split.get(0);

        int numBox = Math.toIntExact(box.getShape().get(0));

        float[] buf = box.toFloatArray();
        float[] confidences = split.get(1).toFloatArray();
        float[] keypoints = split.get(2).toFloatArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);

        for (int i = 0; i < numBox; ++i) {
            int offset = i * 4;
            float xPos = buf[offset];
            float yPos = buf[offset + 1];
            // Rectangle takes (x, y, width, height), so convert the corners back to a size.
            float w = buf[offset + 2] - xPos;
            float h = buf[offset + 3] - yPos;
            boxes.add(new Rectangle(xPos, yPos, w, h));
            scores.add((double) confidences[i]);
        }

        List<Integer> nms = Rectangle.nms(boxes, scores, nmsThreshold);
        if (nms.size() > MAX_DETECTION) {
            nms = nms.subList(0, MAX_DETECTION);
        }

        Joints[] ret = new Joints[nms.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = new Joints(decodeJoints(keypoints, nms.get(i)));
        }
        return ret;
    }

    /**
     * Decodes the joints of a single detection from the flattened keypoint buffer.
     *
     * @param keypoints the flattened keypoint values of all candidate detections
     * @param detection the row index of the detection to decode
     * @return the joints of the detection, with the first two values of each joint divided by the
     *     image width and height respectively
     */
    private List<Joint> decodeJoints(float[] keypoints, int detection) {
        List<Joint> joints = new ArrayList<>(NUM_KEYPOINTS);
        int base = detection * KEYPOINTS_PER_DETECTION;
        for (int j = 0; j < NUM_KEYPOINTS; ++j) {
            int pos = base + j * KEYPOINT_STRIDE;
            joints.add(
                    new Joints.Joint(
                            keypoints[pos] / width,
                            keypoints[pos + 1] / height,
                            keypoints[pos + 2]));
        }
        return joints;
    }

    /**
     * Converts boxes from {@code (x, y, w, h)} to {@code (x - w/2, y - h/2, x + w/2, y + h/2)}.
     *
     * @param array the boxes, with the four box values along the last axis
     * @return the converted boxes, concatenated along the last axis
     */
    private NDArray xywh2xyxy(NDArray array) {
        NDArray xy = array.get("..., :2");
        NDArray wh = array.get("..., 2:").div(2);
        return xy.sub(wh).concat(xy.add(wh), -1);
    }

    /**
     * Creates a builder to build a {@code YoloPoseTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder to build a {@code YoloPoseTranslator} with specified arguments.
     *
     * @param arguments arguments to specify builder options
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);

        return builder;
    }

    /** The builder for Pose Estimation translator. */
    public static class Builder extends BaseBuilder<Builder> {

        float threshold = 0.25f;
        float nmsThreshold = 0.7f;

        Builder() {}

        /**
         * Sets the threshold for prediction accuracy.
         *
         * <p>Predictions below the threshold will be dropped.
         *
         * @param threshold the threshold for prediction accuracy
         * @return the builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /**
         * Sets the NMS threshold.
         *
         * @param nmsThreshold the NMS threshold
         * @return the builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * {@inheritDoc}
         *
         * <p>Reads the optional {@code threshold} and {@code nmsThreshold} arguments, keeping the
         * current values when an argument is absent.
         */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold));
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         */
        public YoloPoseTranslator build() {
            validate();
            return new YoloPoseTranslator(this);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Joints;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a {@link YoloPoseTranslator} instance. */
public class YoloPoseTranslatorFactory implements TranslatorFactory, Serializable {

    private static final long serialVersionUID = 1L;

    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = createSupportedTypes();

    /**
     * Builds the set of input/output type pairs this factory can serve.
     *
     * @return a set holding the {@code Image -> Joints[]} and {@code Input -> Output} pairs
     */
    private static Set<Pair<Type, Type>> createSupportedTypes() {
        Set<Pair<Type, Type>> types = new HashSet<>();
        types.add(new Pair<>(Image.class, Joints[].class));
        types.add(new Pair<>(Input.class, Output.class));
        return types;
    }

    /** {@inheritDoc} */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    /** {@inheritDoc} */
    @Override
    @SuppressWarnings("unchecked")
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
        // The translator is built up front, before the requested types are inspected.
        YoloPoseTranslator poseTranslator = YoloPoseTranslator.builder(arguments).build();

        boolean imageToJoints = input == Image.class && output == Joints[].class;
        if (imageToJoints) {
            return (Translator<I, O>) poseTranslator;
        }

        boolean serving = input == Input.class && output == Output.class;
        if (serving) {
            // Wrap the pose translator so it can be used with the Input/Output pair.
            return (Translator<I, O>) new ImageServingTranslator(poseTranslator);
        }

        throw new IllegalArgumentException("Unsupported input/output types.");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class PtModelZoo extends ModelZoo {
GROUP_ID,
"Human-Action-Recognition-VIT-Base-patch16-224",
"0.0.1"));
addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1"));
addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
addModel(
REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"metadataVersion": "0.2",
"resourceType": "model",
"application": "cv/pose_estimation",
"groupId": "ai.djl.pytorch",
"artifactId": "yolov8n-pose",
"name": "Yolov8n pose",
"description": "Yolov8n Pose Estimation",
"website": "http://www.djl.ai/engines/pytorch/pytorch-model-zoo",
"licenses": {
"license": {
"name": "The Apache License, Version 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
}
},
"artifacts": [
{
"version": "0.0.1",
"snapshot": false,
"name": "yolov8n-pose",
"properties": {
},
"arguments": {
"width": 640,
"height": 640,
"resize": true,
"threshold": 0.25,
"translatorFactory": "ai.djl.modality.cv.translator.YoloPoseTranslatorFactory"
},
"options": {
"mapLocation": "true"
},
"files": {
"model": {
"uri": "0.0.1/yolov8n-pose.zip",
"sha1Hash": "b2358198586ea35189f30aa5f49285dc1842ec1b",
"name": "",
"size": 11684792
}
}
}
]
}
Loading

0 comments on commit 09d4a89

Please sign in to comment.