diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java
new file mode 100644
index 00000000000..8edde7a34d5
--- /dev/null
+++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractor.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.cv.translator;
+
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.TranslatorContext;
+
+import java.util.Map;
+
+/**
+ * A generic {@link ai.djl.translate.Translator} for image feature extraction tasks.
+ *
+ * <p>The first array of the model output is returned as a {@code byte[]}.
+ */
+public class ImageFeatureExtractor extends BaseImageTranslator<byte[]> {
+
+    /**
+     * Constructs an {@code ImageFeatureExtractor} using {@link Builder}.
+     *
+     * @param builder the data to build with
+     */
+    ImageFeatureExtractor(Builder builder) {
+        super(builder);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public byte[] processOutput(TranslatorContext ctx, NDList list) {
+        return list.get(0).toByteArray();
+    }
+
+    /**
+     * Creates a builder to build a {@code ImageFeatureExtractor}.
+     *
+     * @return a new builder
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Creates a builder to build a {@code ImageFeatureExtractor} with specified arguments.
+     *
+     * @param arguments arguments to specify builder options
+     * @return a new builder
+     */
+    public static Builder builder(Map<String, ?> arguments) {
+        Builder builder = new Builder();
+        builder.configPreProcess(arguments);
+        return builder;
+    }
+
+    /** A Builder to construct a {@code ImageFeatureExtractor}. */
+    public static class Builder extends BaseBuilder<Builder> {
+
+        Builder() {}
+
+        /** {@inheritDoc} */
+        @Override
+        protected Builder self() {
+            return this;
+        }
+
+        /**
+         * Builds the {@link ImageFeatureExtractor} with the provided data.
+         *
+         * @return an {@link ImageFeatureExtractor}
+         */
+        public ImageFeatureExtractor build() {
+            validate();
+            return new ImageFeatureExtractor(this);
+        }
+    }
+}
diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java
new file mode 100644
index 00000000000..b89c7a155d2
--- /dev/null
+++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageFeatureExtractorFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.modality.cv.translator;
+
+import ai.djl.Model;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.translator.wrapper.FileTranslator;
+import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator;
+import ai.djl.modality.cv.translator.wrapper.UrlTranslator;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+import ai.djl.util.Pair;
+
+import java.io.InputStream;
+import java.lang.reflect.Type;
+import java.net.URL;
+import java.nio.file.Path;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/** A {@link TranslatorFactory} that creates an {@link ImageFeatureExtractor}. */
+public class ImageFeatureExtractorFactory implements TranslatorFactory {
+
+    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
+
+    static {
+        SUPPORTED_TYPES.add(new Pair<>(Image.class, byte[].class));
+        SUPPORTED_TYPES.add(new Pair<>(Path.class, byte[].class));
+        SUPPORTED_TYPES.add(new Pair<>(URL.class, byte[].class));
+        SUPPORTED_TYPES.add(new Pair<>(InputStream.class, byte[].class));
+        SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Set<Pair<Type, Type>> getSupportedTypes() {
+        return SUPPORTED_TYPES;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    @SuppressWarnings("unchecked")
+    public <I, O> Translator<I, O> newInstance(
+            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
+        ImageFeatureExtractor translator = ImageFeatureExtractor.builder(arguments).build();
+        if (input == Image.class && output == byte[].class) {
+            return (Translator<I, O>) translator;
+        } else if (input == Path.class && output == byte[].class) {
+            return (Translator<I, O>) new FileTranslator<>(translator);
+        } else if (input == URL.class && output == byte[].class) {
+            return (Translator<I, O>) new UrlTranslator<>(translator);
+        } else if (input == InputStream.class && output == byte[].class) {
+            return (Translator<I, O>) new InputStreamTranslator<>(translator);
+        } else if (input == Input.class && output == Output.class) {
+            return (Translator<I, O>) new ImageServingTranslator(translator);
+        }
+        throw new IllegalArgumentException("Unsupported input/output types.");
+    }
+}
diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
index c53aeedd5a2..b49dd6a6230 100644
--- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
+++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java
@@ -33,6 +33,8 @@ public class PtModelZoo extends ModelZoo {
 
     PtModelZoo() {
         addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"));
+        addModel(
+                REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1"));
         addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"));
         addModel(REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"));
         addModel(REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1"));
diff --git a/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet18_embedding/metadata.json b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet18_embedding/metadata.json
new file mode 100644
index 00000000000..dbf438b50b9
--- /dev/null
+++ b/engines/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet18_embedding/metadata.json
@@ -0,0 +1,42 @@
+{
+  "metadataVersion": "0.1",
+  "resourceType": "model",
+  "application": "cv/image_classification",
+  "groupId": "ai.djl.pytorch",
+  "artifactId": "resnet18_embedding",
+  "name": "resnet18_embedding",
+  "description": "A pretrained resnet18 model as an embedding base model",
+  "website": "http://www.djl.ai/engines/pytorch/model-zoo",
+  "licenses": {
+    "license": {
+      "name": "The Apache License, Version 2.0",
+      "url": "https://www.apache.org/licenses/LICENSE-2.0"
+    }
+  },
+  "artifacts": [
+    {
+      "version": "0.0.1",
+      "snapshot": false,
+      "name": "resnet18_embedding",
+      "arguments": {
+        "width": 224,
+        "height": 224,
+        "resize": 256,
+        "centerCrop": true,
+        "normalize": true,
+        "translatorFactory": "ai.djl.modality.cv.translator.ImageFeatureExtractorFactory"
+      },
+      "options": {
+        "mapLocation": "true"
+      },
+      "files": {
+        "model": {
+          "uri": "0.0.1/resnet18_embedding.zip",
+          "name": "",
+          "sha1Hash": "e37db339e87dc13ae0831e45818fca454f526ffb",
+          "size": 41574720
+        }
+      }
+    }
+  ]
+}